Spaces:
Running
Running
import functools | |
import os | |
import traceback | |
import uuid | |
from typing import Any, Dict, Optional, Tuple | |
from pydantic import BaseModel, Field, ValidationError | |
from starfish.common.logger import get_logger | |
logger = get_logger(__name__) | |
# Simple configuration flag (can be set from app config).
# Defaults to False for production safety.
# Environment variables are always strings, so the value is parsed explicitly:
# taken as-is, a setting such as "false" or "0" would be a non-empty (truthy)
# string and would silently enable tracebacks in responses.
INCLUDE_TRACEBACK_IN_RESPONSE = os.environ.get("INCLUDE_TRACEBACK_IN_RESPONSE", "").strip().lower() in {"1", "true", "yes", "on"}
############################################# | |
# HTTP Status Codes | |
############################################# | |
class HTTPStatus:
    """Namespace of the HTTP status codes used by the error-handling code.

    The members are plain ``int`` class attributes, so they can be compared
    with, or returned as, ordinary integers.
    """

    # Success
    OK = 200

    # Client errors
    BAD_REQUEST = 400
    UNAUTHORIZED = 401
    FORBIDDEN = 403
    NOT_FOUND = 404
    UNPROCESSABLE_ENTITY = 422

    # Server errors
    INTERNAL_SERVER_ERROR = 500
############################################# | |
# Error Response Model | |
############################################# | |
class ErrorResponse(BaseModel):
    """Uniform payload describing an API error.

    Field order is kept as declared because it determines the order of the
    serialized output.
    """

    # Constant marker identifying the payload as an error.
    status: str = "error"
    # Required; identifies this particular error occurrence.
    error_id: str = Field(..., description="Unique identifier for this error occurrence")
    # Human-readable description of what went wrong.
    message: str
    # Name of the exception class (or category) that produced the error.
    error_type: str
    # Optional structured extras; omitted (None) when there is nothing to add.
    details: Optional[Dict[str, Any]] = None
############################################# | |
# Exception Classes | |
############################################# | |
class StarfishException(Exception):
    """Root of the Starfish exception hierarchy.

    Subclasses customise behaviour by overriding the class attributes
    ``status_code`` and ``default_message``.

    Attributes:
        message: The message passed in, or ``default_message`` when the
            caller supplied none (or an empty one).
        details: Optional dictionary of structured extra information.
        error_id: A freshly generated UUID4 string for this instance.
    """

    status_code: int = HTTPStatus.INTERNAL_SERVER_ERROR
    default_message: str = "An unexpected error occurred"

    def __init__(self, message: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
        # Any falsy message (None or "") falls back to the class default.
        self.message = message if message else self.default_message
        self.details = details
        self.error_id = str(uuid.uuid4())
        super().__init__(self.message)

    def __str__(self):
        """Return the message, followed by the details when there are any."""
        if not self.details:
            return self.message
        return f"{self.message} - Details: {self.details}"
class ValidationError(StarfishException):
    """Starfish exception signalling that input failed validation.

    Maps to HTTP 422 (Unprocessable Entity).
    """

    # NOTE: defining this class rebinds the module-level name
    # ``ValidationError``, which the file's import section also imports from
    # pydantic. After this point the bare name refers to this Starfish class.
    status_code = HTTPStatus.UNPROCESSABLE_ENTITY
    default_message = "Validation error"
class PydanticValidationError(ValidationError):
    """Exception raised for Pydantic validation errors.

    Wraps a Pydantic validation error, turning its error list into a
    user-friendly message while keeping the raw error entries available in
    ``details["validation_errors"]`` for debugging.
    """

    default_message = "Data validation error"

    @staticmethod
    def format_validation_error(error: Any) -> Tuple[str, Dict[str, Any]]:
        """Format a Pydantic validation error into a message and details.

        Declared as a static method because it takes no ``self``; this lets
        it be called both on the class and on an instance (as ``__init__``
        does).

        Args:
            error: The validation error to format. Anything without a
                callable ``errors`` attribute is simply stringified.

        Returns:
            Tuple of ``(message, details)``. ``details`` is empty when no
            structured errors are available, otherwise it is
            ``{"validation_errors": <list returned by error.errors()>}``.
        """
        errors_fn = getattr(error, "errors", None)
        if not callable(errors_fn):
            return str(error), {}

        error_details = errors_fn()
        if not error_details:
            return "Validation error", {}

        # Build one human-readable fragment per reported error.
        field_errors = []
        for err in error_details:
            err_type = err.get("type", "unknown")
            loc = err.get("loc", [])

            # Two leading string components may indicate a discriminated
            # union location such as ('vanilla', 'user_input'); a missing
            # field is then reported relative to the union member.
            if len(loc) >= 2 and isinstance(loc[0], str) and isinstance(loc[1], str):
                type_name, field_name = loc[0], loc[1]
                if err_type == "missing":
                    field_errors.append(f"Field '{field_name}' is required for '{type_name}' type")
                    continue

            # Standard handling for every other error.
            loc_str = ".".join(str(item) for item in loc) if loc else "unknown"
            msg = err.get("msg", "")

            if err_type == "missing":
                field_errors.append(f"'{loc_str}' is required")
            elif err_type == "type_error":
                field_errors.append(f"'{loc_str}' has an invalid type")
            elif err_type == "value_error":
                field_errors.append(f"'{loc_str}' has an invalid value")
            elif err_type.startswith(("value_error", "type_error")):
                # Specialised sub-types: reuse the message from the error entry.
                field_errors.append(f"'{loc_str}' {msg}")
            elif err_type == "extra_forbidden":
                field_errors.append(f"'{loc_str}' is not allowed")
            else:
                field_errors.append(f"'{loc_str}': {msg}")

        # Singular or plural prefix depending on the number of errors.
        if len(field_errors) == 1:
            message = f"Validation error: {field_errors[0]}"
        else:
            message = f"Validation errors: {', '.join(field_errors)}"

        return message, {"validation_errors": error_details}

    def __init__(self, validation_error: Any, message: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
        """Create the exception from an underlying validation error.

        Args:
            validation_error: The validation error being wrapped.
            message: Optional explicit message. When omitted, one is derived
                from ``validation_error``.
            details: Optional extra details. When the message is derived,
                the formatted error details are merged in (and win on key
                clashes).
        """
        if message is None:
            message, error_details = self.format_validation_error(validation_error)
            # ``error_details`` exists only on this path, so the merge is
            # done here rather than after the ``if``.
            details = error_details if details is None else {**details, **error_details}

        super().__init__(message=message, details=details)
class ParserError(StarfishException):
    """Common parent of the parser-related exceptions.

    Maps to HTTP 422 (Unprocessable Entity).
    """

    status_code = HTTPStatus.UNPROCESSABLE_ENTITY
    default_message = "Parser error"
class JsonParserError(ParserError):
    """Parser error signalling that JSON could not be parsed."""

    default_message = "JSON parsing error"
class SchemaValidationError(ParserError):
    """Parser error signalling that data does not conform to a schema."""

    default_message = "Schema validation error"

    def __str__(self):
        """Render the message, bulleting ``details["errors"]`` when present.

        Falls back to the inherited string form when there are no details
        or the details carry no ``"errors"`` entry.
        """
        details = self.details
        if not (details and "errors" in details):
            return super().__str__()

        bullet_lines = "\n".join(f"- {item}" for item in details["errors"])
        return f"{self.message}:\n{bullet_lines}"
class PydanticParserError(ParserError):
    """Parser error signalling that Pydantic parsing or validation failed."""

    default_message = "Pydantic parsing error"
############################################# | |
# Error Handling Functions | |
############################################# | |
def _attach_tracebacks(details: Dict[str, Any], include_traceback: bool, tb_str: str, cause_tb: Optional[str]) -> Dict[str, Any]:
    """Add traceback text to *details* in place when requested, then return it."""
    if include_traceback:
        details["traceback"] = tb_str
        if cause_tb:
            details["original_traceback"] = cause_tb
    return details


def format_error(exc: Exception, include_traceback: bool = INCLUDE_TRACEBACK_IN_RESPONSE) -> Tuple[ErrorResponse, int]:
    """Format an exception into a standardized error response.

    The exception (and its explicit ``__cause__``, if any) is always logged
    together with its traceback. Tracebacks are copied into the response
    details only when ``include_traceback`` is true.

    Args:
        exc: The exception to format.
        include_traceback: Whether to include traceback text in the
            response details.

    Returns:
        Tuple of ``(error_response, status_code)``:
          * Starfish exceptions use their own ``status_code``, ``message``,
            ``error_id`` and ``details``.
          * Pydantic validation errors map to 422 and carry
            ``exc.errors()`` under ``"validation_errors"``.
          * Anything else maps to 500.
    """
    # The bare module-level name ``ValidationError`` is rebound to the
    # Starfish exception class of the same name, so testing against it here
    # can never match a real Pydantic error. Import the Pydantic class under
    # an unambiguous alias instead.
    from pydantic import ValidationError as PydanticCoreValidationError

    # Traceback is always produced for logging; it is only optionally
    # included in the response.
    tb_str = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))

    # Explicit exception chaining ("raise ... from ...").
    cause = getattr(exc, "__cause__", None)
    cause_tb = None
    if cause:
        cause_tb = "".join(traceback.format_exception(type(cause), cause, cause.__traceback__))
        logger.error(f"Original exception: {type(cause).__name__}: {str(cause)}")
        logger.error(f"Original traceback: {cause_tb}")

    # Log the current exception.
    logger.error(f"Exception: {type(exc).__name__}: {str(exc)}")
    logger.error(f"Traceback: {tb_str}")

    # Starfish exceptions carry their own status, id, message and details.
    if isinstance(exc, StarfishException):
        error_id = getattr(exc, "error_id", str(uuid.uuid4()))
        # Work on a copy so traceback keys never leak into the exception's
        # own ``details`` dictionary.
        details = _attach_tracebacks(dict(exc.details or {}), include_traceback, tb_str, cause_tb)
        response = ErrorResponse(
            error_id=error_id,
            message=exc.message,
            error_type=type(exc).__name__,
            details=details or None,
        )
        return response, exc.status_code

    # Raw Pydantic validation errors.
    if isinstance(exc, PydanticCoreValidationError):
        details = _attach_tracebacks({"validation_errors": exc.errors()}, include_traceback, tb_str, cause_tb)
        response = ErrorResponse(
            error_id=str(uuid.uuid4()),
            message="Validation error",
            error_type="ValidationError",
            details=details,
        )
        return response, HTTPStatus.UNPROCESSABLE_ENTITY

    # Every other exception becomes a generic internal server error.
    details = _attach_tracebacks({}, include_traceback, tb_str, cause_tb)
    response = ErrorResponse(
        error_id=str(uuid.uuid4()),
        message=str(exc) or "An unexpected error occurred",
        error_type=type(exc).__name__,
        details=details or None,
    )
    return response, HTTPStatus.INTERNAL_SERVER_ERROR
############################################# | |
# Utility Decorators | |
############################################# | |
def handle_exceptions(return_value=None):
    """Decorator factory that swallows exceptions and returns a default.

    Works with both ``async def`` and regular functions. When the wrapped
    callable raises an ``Exception``, the error is passed to
    ``format_error`` (which logs it) and ``return_value`` is returned
    instead of propagating the exception.

    Args:
        return_value: The value to return if an exception occurs.

    Returns:
        A decorator producing a wrapper of the same kind (sync or async) as
        the decorated function, with the function's metadata preserved.
    """
    # Local import keeps this block self-contained.
    import inspect

    def decorator(func):
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)  # keep __name__, __doc__, etc. of the original
            async def async_wrapper(*args, **kwargs):
                try:
                    return await func(*args, **kwargs)
                except Exception as exc:
                    # Format and log the error, but do not re-raise.
                    format_error(exc, include_traceback=True)
                    return return_value

            return async_wrapper

        @functools.wraps(func)  # keep __name__, __doc__, etc. of the original
        def sync_wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as exc:
                # Format and log the error, but do not re-raise.
                format_error(exc, include_traceback=True)
                return return_value

        return sync_wrapper

    return decorator