|
from fastapi import Request, status |
|
from fastapi.responses import JSONResponse |
|
from starlette.middleware.base import BaseHTTPMiddleware |
|
import logging |
|
import time |
|
import traceback |
|
import uuid |
|
from .utils import get_local_time |
|
|
|
|
|
# Module-level logger shared by every middleware class in this file.
logger = logging.getLogger(__name__)
|
|
|
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log each request and its response, tagging both with a request ID.

    A UUID4 string is generated per request, stored on
    ``request.state.request_id`` and returned to the client in the
    ``X-Request-ID`` header together with ``X-Process-Time`` (seconds).
    """

    async def dispatch(self, request: Request, call_next):
        """Run the downstream handler, logging timing and outcome.

        Args:
            request: The incoming request; ``request.state.request_id`` is
                set on it before the downstream handler runs.
            call_next: Awaitable callable that forwards the request to the
                next middleware/route and returns its response.

        Returns:
            The downstream response with ``X-Request-ID`` and
            ``X-Process-Time`` headers added, or a 500 ``JSONResponse``
            carrying the same headers if the downstream call raised.
        """
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        # request.client may be None, so fall back to a placeholder host.
        client_host = request.client.host if request.client else "unknown"
        logger.info(
            "Request [%s]: %s %s from %s",
            request_id,
            request.method,
            request.url.path,
            client_host,
        )

        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump when the system clock is adjusted.
        start_time = time.perf_counter()

        try:
            response = await call_next(request)
        except Exception as exc:
            process_time = time.perf_counter() - start_time
            # logger.exception attaches the active traceback to this single
            # record, keeping message and stack trace together in the log.
            logger.exception(
                "Error [%s] after %.4fs: %s", request_id, process_time, exc
            )
            # NOTE(review): assumes get_local_time() returns a
            # JSON-serializable value - confirm against .utils.
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Internal server error",
                    "request_id": request_id,
                    "timestamp": get_local_time(),
                },
                # Mirror the success-path headers so failed requests can be
                # correlated with the log entries as well.
                headers={
                    "X-Request-ID": request_id,
                    "X-Process-Time": str(process_time),
                },
            )
        else:
            process_time = time.perf_counter() - start_time
            logger.info(
                "Response [%s]: %s processed in %.4fs",
                request_id,
                response.status_code,
                process_time,
            )
            response.headers["X-Request-ID"] = request_id
            response.headers["X-Process-Time"] = str(process_time)
            return response
|
|
|
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Convert uncaught downstream exceptions into a JSON 500 response."""

    async def dispatch(self, request: Request, call_next):
        """Forward the request and trap any exception it raises.

        Args:
            request: The incoming request. If ``request.state.request_id``
                is already set it is reused in the log and error payload.
            call_next: Awaitable callable that forwards the request to the
                next middleware/route and returns its response.

        Returns:
            The downstream response unchanged, or a 500 ``JSONResponse``
            with ``detail``, ``request_id`` and ``timestamp`` fields (and an
            ``X-Request-ID`` header) if the downstream call raised.
        """
        try:
            return await call_next(request)
        except Exception as exc:
            # Reuse the ID if one is already on request.state; only
            # generate a new UUID when none is available.
            request_id = getattr(request.state, "request_id", None)
            if request_id is None:
                request_id = str(uuid.uuid4())

            # logger.exception attaches the active traceback to this single
            # record, keeping message and stack trace together in the log.
            logger.exception("Uncaught exception [%s]: %s", request_id, exc)

            # NOTE(review): assumes get_local_time() returns a
            # JSON-serializable value - confirm against .utils.
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Internal server error",
                    "request_id": request_id,
                    "timestamp": get_local_time(),
                },
                # Expose the ID as a header too so clients and proxies can
                # correlate the failure without parsing the body.
                headers={"X-Request-ID": str(request_id)},
            )
|
|
|
class DatabaseCheckMiddleware(BaseHTTPMiddleware):
    """Answer with 503 when handling a non-exempt request raises.

    Requests whose path is in ``_SKIP_PATHS`` are forwarded without any
    exception handling. For every other path, an exception raised while
    processing the request is logged and turned into a
    "Database connection failed" 503 response.
    """

    # Paths forwarded without the 503 guard. Built once at class creation
    # (instead of per request) as a frozenset for constant-time membership.
    _SKIP_PATHS = frozenset({"/", "/health", "/docs", "/redoc", "/openapi.json"})

    async def dispatch(self, request: Request, call_next):
        """Forward the request, mapping downstream failures to a 503.

        Args:
            request: The incoming request; only ``request.url.path`` is
                inspected here.
            call_next: Awaitable callable that forwards the request to the
                next middleware/route and returns its response.

        Returns:
            The downstream response, or a 503 ``JSONResponse`` with
            ``detail`` and ``timestamp`` fields if the downstream call
            raised on a non-exempt path.
        """
        if request.url.path in self._SKIP_PATHS:
            return await call_next(request)

        # NOTE(review): no explicit database probe runs here, so ANY
        # exception from the downstream handler is reported to the client as
        # a database failure. Confirm whether a real connectivity check was
        # meant to precede call_next.
        try:
            return await call_next(request)
        except Exception as exc:
            # logger.exception records the traceback, which the message
            # text alone does not carry, so the failure origin stays
            # diagnosable.
            logger.exception("Database connection check failed: %s", exc)

            # NOTE(review): assumes get_local_time() returns a
            # JSON-serializable value - confirm against .utils.
            return JSONResponse(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                content={
                    "detail": "Database connection failed",
                    "timestamp": get_local_time(),
                },
            )