John-Jiang's picture
init commit
5301c48
raw
history blame
11 kB
import functools
import os
import traceback
import uuid
from typing import Any, Dict, Optional, Tuple
from pydantic import BaseModel, Field, ValidationError
from starfish.common.logger import get_logger
logger = get_logger(__name__)
# Simple configuration flag (can be set from app config).
# Defaults to False for production safety.
#
# Environment variables are always strings, so the raw value must be parsed:
# using it directly would make "false" or "0" truthy and leak tracebacks into
# API responses. Only "1", "true", "yes" and "on" (case-insensitive, surrounding
# whitespace ignored) enable the flag; anything else, or an unset variable,
# leaves it disabled.
INCLUDE_TRACEBACK_IN_RESPONSE = os.environ.get("INCLUDE_TRACEBACK_IN_RESPONSE", "").strip().lower() in {"1", "true", "yes", "on"}
#############################################
# HTTP Status Codes
#############################################
class HTTPStatus:
    """Named integer constants for common HTTP status codes."""

    # 2xx - success
    OK = 200

    # 4xx - client errors
    BAD_REQUEST = 400
    UNAUTHORIZED = 401
    FORBIDDEN = 403
    NOT_FOUND = 404
    UNPROCESSABLE_ENTITY = 422

    # 5xx - server errors
    INTERNAL_SERVER_ERROR = 500
#############################################
# Error Response Model
#############################################
class ErrorResponse(BaseModel):
    """Uniform payload describing a failed API call."""

    # Fixed marker; defaults to "error".
    status: str = "error"
    # Required: identifies this particular error occurrence.
    error_id: str = Field(..., description="Unique identifier for this error occurrence")
    # Human-readable description of what went wrong.
    message: str
    # Name of the error category (callers in this module pass the exception class name).
    error_type: str
    # Optional free-form extra information; omitted (None) by default.
    details: Optional[Dict[str, Any]] = None
#############################################
# Exception Classes
#############################################
class StarfishException(Exception):
    """Root of the Starfish exception hierarchy.

    Subclasses customise behaviour by overriding the class attributes
    ``status_code`` and ``default_message``.

    Attributes:
        message: The supplied message, or ``default_message`` when none (or an
            empty one) was given.
        details: Optional dictionary with extra context, stored as passed in.
        error_id: Random UUID4 string generated for each instance.
    """

    status_code: int = HTTPStatus.INTERNAL_SERVER_ERROR
    default_message: str = "An unexpected error occurred"

    def __init__(self, message: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
        # A missing or empty message falls back to the class-level default.
        text = message if message else self.default_message
        self.message = text
        self.details = details
        self.error_id = str(uuid.uuid4())
        super().__init__(text)

    def __str__(self):
        # Without details the message alone is the string form.
        if not self.details:
            return self.message
        return f"{self.message} - Details: {self.details}"
class ValidationError(StarfishException):
    """Starfish error signalling that input failed validation.

    NOTE: this class reuses the name ``ValidationError`` that the module also
    imports from pydantic; from this definition onward the bare name refers to
    this Starfish class, not to pydantic's.
    """

    default_message = "Validation error"
    status_code = HTTPStatus.UNPROCESSABLE_ENTITY
class PydanticValidationError(ValidationError):
    """Exception raised for Pydantic validation errors.

    This class formats Pydantic validation errors into user-friendly messages
    and preserves the detailed error information for debugging.

    The wrapped error is duck-typed: anything exposing a callable ``errors()``
    that returns a list of dicts with ``type``, ``loc`` and ``msg`` keys is
    formatted field by field; any other object is reduced to ``str(error)``.
    The parameters are annotated ``Any`` because, inside this module, the bare
    name ``ValidationError`` refers to the Starfish exception rather than to
    pydantic's class.
    """

    default_message = "Data validation error"

    @staticmethod
    def format_validation_error(error: Any) -> Tuple[str, Dict[str, Any]]:
        """Format a Pydantic ValidationError into a user-friendly message and details.

        Args:
            error: The Pydantic ValidationError (or look-alike) to format.

        Returns:
            Tuple of (message, details). ``details`` is
            ``{"validation_errors": <error.errors()>}`` when individual errors
            are available and ``{}`` otherwise.
        """
        errors_fn = getattr(error, "errors", None)
        if not callable(errors_fn):
            return str(error), {}

        error_details = errors_fn()
        if not error_details:
            return "Validation error", {}

        # Build one human-readable fragment per reported error.
        field_errors = []
        for err in error_details:
            err_type = err.get("type", "unknown")
            loc = err.get("loc", [])

            # Heuristic for discriminated unions: a location that starts with two
            # strings, e.g. ['vanilla', 'user_input'], is read as (type, field).
            if len(loc) >= 2 and isinstance(loc[0], str) and isinstance(loc[1], str):
                type_name = loc[0]
                field_name = loc[1]
                if err_type == "missing":
                    field_errors.append(f"Field '{field_name}' is required for '{type_name}' type")
                    continue

            # Standard handling for every other error.
            loc_str = ".".join(str(item) for item in loc) if loc else "unknown"
            msg = err.get("msg", "")

            if err_type == "missing":
                field_errors.append(f"'{loc_str}' is required")
            elif err_type == "type_error":
                field_errors.append(f"'{loc_str}' has an invalid type")
            elif err_type == "value_error":
                field_errors.append(f"'{loc_str}' has an invalid value")
            elif err_type.startswith("value_error"):
                field_errors.append(f"'{loc_str}' {msg}")
            elif err_type.startswith("type_error"):
                field_errors.append(f"'{loc_str}' {msg}")
            elif err_type == "extra_forbidden":
                field_errors.append(f"'{loc_str}' is not allowed")
            else:
                field_errors.append(f"'{loc_str}': {msg}")

        # Singular vs. plural prefix depending on how many fragments there are.
        if len(field_errors) == 1:
            message = f"Validation error: {field_errors[0]}"
        else:
            message = f"Validation errors: {', '.join(field_errors)}"
        return message, {"validation_errors": error_details}

    def __init__(self, validation_error: Any, message: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
        """Wrap *validation_error* in a Starfish exception.

        Args:
            validation_error: The Pydantic ValidationError being wrapped.
            message: Optional override for the generated message.
            details: Optional extra context; merged with the formatted
                validation details (the validation details win on key clashes).
        """
        # Always format the error so the validation details are preserved even
        # when the caller supplies a custom message; only the message itself is
        # overridden in that case.
        formatted_message, error_details = self.format_validation_error(validation_error)
        if message is None:
            message = formatted_message

        # Merge error details with provided details.
        if details is None:
            details = error_details
        else:
            details = {**details, **error_details}

        super().__init__(message=message, details=details)
class ParserError(StarfishException):
    """Common parent of the parser-related exceptions."""

    default_message = "Parser error"
    status_code = HTTPStatus.UNPROCESSABLE_ENTITY
class JsonParserError(ParserError):
    """Parser error signalling that JSON could not be parsed."""

    default_message = "JSON parsing error"
class SchemaValidationError(ParserError):
    """Parser error signalling that data does not match its schema.

    When ``details`` carries an ``"errors"`` entry, the string form lists each
    entry on its own ``- `` prefixed line beneath the message.
    """

    default_message = "Schema validation error"

    def __str__(self):
        has_error_list = bool(self.details) and "errors" in self.details
        if not has_error_list:
            # No itemised errors: defer to the parent class formatting.
            return super().__str__()
        bullet_lines = "\n".join(f"- {item}" for item in self.details["errors"])
        return f"{self.message}:\n{bullet_lines}"
class PydanticParserError(ParserError):
    """Parser error signalling that Pydantic parsing or validation failed."""

    default_message = "Pydantic parsing error"
#############################################
# Error Handling Functions
#############################################
def format_error(exc: Exception, include_traceback: bool = INCLUDE_TRACEBACK_IN_RESPONSE) -> Tuple[ErrorResponse, int]:
    """Format an exception into a standardized error response.

    The exception (and its explicit ``__cause__``, if any) is always logged
    together with its traceback. Three kinds of exception are distinguished:

    * ``StarfishException``: uses the exception's own ``error_id``,
      ``status_code``, ``message`` and ``details``.
    * pydantic ``ValidationError``: reported as 422 with the individual
      validation errors under ``details["validation_errors"]``.
    * anything else: reported as 500 with ``str(exc)`` as the message.

    Args:
        exc: The exception to format
        include_traceback: Whether to include traceback in the response details

    Returns:
        Tuple of (error_response, status_code)
    """
    # Imported locally under an alias: at module level the name
    # ``ValidationError`` is rebound to the Starfish exception class, so an
    # isinstance check against the bare name can never match a raw pydantic error.
    from pydantic import ValidationError as PydanticCoreValidationError

    # Get traceback for logging (always) - may optionally include in response
    tb_str = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))

    # Check for exception chaining
    cause = getattr(exc, "__cause__", None)
    cause_tb = None
    if cause:
        cause_tb = "".join(traceback.format_exception(type(cause), cause, cause.__traceback__))
        logger.error(f"Original exception: {type(cause).__name__}: {str(cause)}")
        logger.error(f"Original traceback: {cause_tb}")

    # Log the current exception
    logger.error(f"Exception: {type(exc).__name__}: {str(exc)}")
    logger.error(f"Traceback: {tb_str}")

    if isinstance(exc, StarfishException):
        # Handle Starfish exceptions
        error_id = getattr(exc, "error_id", str(uuid.uuid4()))
        status_code = exc.status_code
        message = exc.message
        error_type = type(exc).__name__
        # Work on a copy so that adding tracebacks below never mutates the
        # details dict owned by the exception instance.
        details = dict(exc.details) if exc.details else {}
    elif isinstance(exc, PydanticCoreValidationError):
        # Handle raw Pydantic validation errors
        error_id = str(uuid.uuid4())
        status_code = HTTPStatus.UNPROCESSABLE_ENTITY
        message = "Validation error"
        error_type = "ValidationError"
        details = {"validation_errors": exc.errors()}
    else:
        # Handle all other exceptions
        error_id = str(uuid.uuid4())
        status_code = HTTPStatus.INTERNAL_SERVER_ERROR
        message = str(exc) or "An unexpected error occurred"
        error_type = type(exc).__name__
        details = {}

    # Only add traceback to details if requested
    if include_traceback:
        details["traceback"] = tb_str
        if cause_tb:
            details["original_traceback"] = cause_tb

    response = ErrorResponse(
        error_id=error_id,
        message=message,
        error_type=error_type,
        details=details if details else None,
    )
    return response, status_code
#############################################
# Utility Decorators
#############################################
def handle_exceptions(return_value=None):
    """Decorator to handle exceptions in both async and sync functions.

    This decorator can be used with any function to catch exceptions,
    log them (via ``format_error``), and return a default value instead of
    raising. Only ``Exception`` subclasses are caught.

    Args:
        return_value: The value to return if an exception occurs

    Returns:
        Decorated function with exception handling. Coroutine functions are
        wrapped by a coroutine function, plain functions by a plain function;
        ``functools.wraps`` preserves the wrapped function's metadata.
    """
    # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction, which is
    # deprecated since Python 3.14. Imported locally so this block does not
    # depend on a module-level import.
    import inspect

    def decorator(func):
        # Handle async functions
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                try:
                    return await func(*args, **kwargs)
                except Exception as exc:
                    # Format and log the error but don't raise
                    format_error(exc, include_traceback=True)
                    return return_value

            return async_wrapper

        # Handle synchronous functions
        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as exc:
                # Format and log the error but don't raise
                format_error(exc, include_traceback=True)
                return return_value

        return sync_wrapper

    return decorator