# Pix-Agent / app/utils/middleware.py
# Author: Cuong2004
# Commit: ac0f906 ("first commit")
# Size: 4.04 kB
from fastapi import Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import logging
import time
import traceback
import uuid
from .utils import get_local_time
# Named logger for this module; nothing here attaches handlers or sets levels.
logger = logging.getLogger(name=__name__)
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log each request/response pair and tag responses with a request ID.

    For every request this middleware:

    * generates a UUID4 request ID and stores it on ``request.state.request_id``
      so later code handling the same request can reuse it;
    * logs the method, path and client host on entry;
    * logs the status code and elapsed time on exit, and sets the
      ``X-Request-ID`` and ``X-Process-Time`` (seconds) response headers;
    * if the downstream call raises, logs the exception with its traceback and
      returns a generic HTTP 500 JSON body containing the request ID and a
      timestamp from ``get_local_time()``; the same two headers are set on
      that error response.
    """

    async def dispatch(self, request: Request, call_next):
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        client_host = request.client.host if request.client else "unknown"
        logger.info(
            "Request [%s]: %s %s from %s",
            request_id,
            request.method,
            request.url.path,
            client_host,
        )

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # (or go negative) if the system wall clock is adjusted mid-request.
        start_time = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception as e:
            process_time = time.perf_counter() - start_time
            # logger.exception logs at ERROR level and appends the traceback
            # to the same record.
            logger.exception(
                "Error [%s] after %.4fs: %s", request_id, process_time, e
            )
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Internal server error",
                    "request_id": request_id,
                    "timestamp": get_local_time(),
                },
                # Expose the ID on failures too, so clients can correlate the
                # error with server logs exactly as on the success path.
                headers={
                    "X-Request-ID": request_id,
                    "X-Process-Time": str(process_time),
                },
            )

        process_time = time.perf_counter() - start_time
        logger.info(
            "Response [%s]: %s processed in %.4fs",
            request_id,
            response.status_code,
            process_time,
        )
        response.headers["X-Request-ID"] = request_id
        response.headers["X-Process-Time"] = str(process_time)
        return response
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Turn exceptions escaping downstream handlers into a JSON HTTP 500.

    The error body and the ``X-Request-ID`` header reuse
    ``request.state.request_id`` when something upstream has already stored
    one (``RequestLoggingMiddleware`` in this module does); otherwise a fresh
    UUID4 is generated so the response can still be matched to the log entry.
    """

    async def dispatch(self, request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            # Only mint a new UUID when no request ID was assigned upstream.
            request_id = getattr(request.state, "request_id", None)
            if request_id is None:
                request_id = str(uuid.uuid4())

            # logger.exception logs at ERROR level with the traceback attached.
            logger.exception("Uncaught exception [%s]: %s", request_id, e)

            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Internal server error",
                    "request_id": request_id,
                    "timestamp": get_local_time(),
                },
                headers={"X-Request-ID": request_id},
            )
class DatabaseCheckMiddleware(BaseHTTPMiddleware):
    """Run database connectivity checks before forwarding a request.

    Requests whose path is in ``SKIP_PATHS`` are forwarded unchecked. For all
    other requests ``_check_databases`` runs first; if it raises, a JSON
    HTTP 503 is returned and the request is not forwarded. Exceptions raised
    by downstream handlers are deliberately NOT caught here, so they are not
    misreported as database failures and can reach the generic error handling.
    """

    # Paths that never need a database check (root, health probe, API docs).
    # A frozenset gives O(1) membership and can be overridden by subclasses.
    SKIP_PATHS = frozenset({"/", "/health", "/docs", "/redoc", "/openapi.json"})

    async def dispatch(self, request: Request, call_next):
        if request.url.path in self.SKIP_PATHS:
            return await call_next(request)

        # Keep the try limited to the checks themselves: wrapping call_next
        # would turn every application error into a 503 "Database connection
        # failed" response.
        try:
            await self._check_databases()
        except Exception as e:
            logger.exception("Database connection check failed: %s", e)
            return JSONResponse(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                content={
                    "detail": "Database connection failed",
                    "timestamp": get_local_time(),
                },
            )

        return await call_next(request)

    async def _check_databases(self) -> None:
        """Verify database connectivity; raise an exception on failure.

        Currently performs no checks, so it never raises.
        """
        # TODO: Add checks for MongoDB and Pinecone if needed
        # PostgreSQL check is already done in route handler with get_db() method
        return None