File size: 4,042 Bytes
ac0f906 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
from fastapi import Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import logging
import time
import traceback
import uuid
from .utils import get_local_time
# Module-level logger named after this module (``__name__``).
# No handlers or levels are set here; this line only obtains the logger.
logger = logging.getLogger(__name__)
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware to log requests and responses.

    Each request is tagged with a fresh UUID4 stored on
    ``request.state.request_id``. The request line, the response status and
    the elapsed processing time are written to the module logger.
    """

    async def dispatch(self, request: Request, call_next):
        """Log the request, run the rest of the stack, and log the outcome.

        Args:
            request: Incoming request. ``request.state.request_id`` is set
                on it before the request is forwarded.
            call_next: Awaitable callable that forwards ``request`` to the
                next layer and returns its response.

        Returns:
            The downstream response with ``X-Request-ID`` and
            ``X-Process-Time`` headers added, or a 500 JSON response carrying
            the same headers if the downstream layer raised.
        """
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        # request.client may be None, so fall back to a placeholder.
        client_host = request.client.host if request.client else "unknown"
        logger.info(
            "Request [%s]: %s %s from %s",
            request_id,
            request.method,
            request.url.path,
            client_host,
        )

        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception as exc:
            process_time = time.perf_counter() - start_time
            # logger.exception records the message and the active traceback
            # as a single log record.
            logger.exception(
                "Error [%s] after %.4fs: %s", request_id, process_time, exc
            )
            # The exception is converted into a generic 500 response rather
            # than re-raised; internal details stay in the log only.
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Internal server error",
                    "request_id": request_id,
                    # Project helper from .utils; its format is not visible here.
                    "timestamp": get_local_time(),
                },
                headers={
                    "X-Request-ID": request_id,
                    "X-Process-Time": str(process_time),
                },
            )

        process_time = time.perf_counter() - start_time
        logger.info(
            "Response [%s]: %s processed in %.4fs",
            request_id,
            response.status_code,
            process_time,
        )

        # Expose the correlation id and timing to the client.
        response.headers["X-Request-ID"] = request_id
        response.headers["X-Process-Time"] = str(process_time)
        return response
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Middleware to handle uncaught exceptions in the application.

    Any exception raised while the rest of the stack handles the request is
    logged with its traceback and replaced by a generic 500 JSON response.
    """

    async def dispatch(self, request: Request, call_next):
        """Forward the request and convert uncaught exceptions to a 500.

        Args:
            request: Incoming request. If ``request.state.request_id`` is
                present it is reused to correlate the error.
            call_next: Awaitable callable that forwards ``request`` to the
                next layer and returns its response.

        Returns:
            The downstream response unchanged, or a 500 JSON response with
            ``detail``, ``request_id`` and ``timestamp`` fields on failure.
        """
        try:
            return await call_next(request)
        except Exception as exc:
            # Reuse the id already on request.state when there is one;
            # only mint a new UUID when none is present.
            request_id = getattr(request.state, "request_id", None)
            if request_id is None:
                request_id = str(uuid.uuid4())

            # logger.exception records the message and the active traceback
            # as a single log record.
            logger.exception("Uncaught exception [%s]: %s", request_id, exc)

            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Internal server error",
                    "request_id": request_id,
                    # Project helper from .utils; its format is not visible here.
                    "timestamp": get_local_time(),
                },
            )
class DatabaseCheckMiddleware(BaseHTTPMiddleware):
    """Middleware to check database connections before each request.

    Requests whose path is in ``SKIP_PATHS`` are forwarded without any check.
    For every other path ``_check_connections`` runs first; if it raises, a
    503 JSON response is returned and the request is not forwarded.
    """

    # Paths that don't need database checks. Built once at class level
    # instead of on every request.
    SKIP_PATHS = frozenset({"/", "/health", "/docs", "/redoc", "/openapi.json"})

    async def _check_connections(self) -> None:
        """Verify backing stores are reachable; raise on failure.

        Currently performs no checks.
        """
        # TODO: Add checks for MongoDB and Pinecone if needed
        # PostgreSQL check is already done in route handler with get_db() method

    async def dispatch(self, request: Request, call_next):
        """Run the connection check, then forward the request.

        Args:
            request: Incoming request; only ``request.url.path`` is read.
            call_next: Awaitable callable that forwards ``request`` to the
                next layer and returns its response.

        Returns:
            The downstream response, or a 503 JSON response with ``detail``
            and ``timestamp`` fields when the connection check raises.
        """
        if request.url.path in self.SKIP_PATHS:
            return await call_next(request)

        # Only the check itself is guarded. Exceptions raised by downstream
        # handlers propagate unchanged so they are not misreported as a
        # database outage.
        try:
            await self._check_connections()
        except Exception as exc:
            logger.error("Database connection check failed: %s", exc)
            return JSONResponse(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                content={
                    "detail": "Database connection failed",
                    # Project helper from .utils; its format is not visible here.
                    "timestamp": get_local_time(),
                },
            )

        return await call_next(request)