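# NOTE(review): the next three lines are repository-page residue, not module code; kept as comments.
# John-Jiang's picture
# init commit
# 5301c48 (presumably the short commit hash - confirm)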
from typing import Any, Dict, List, Optional, Type, Union
from pydantic import BaseModel, ValidationError
from starfish.common.exceptions import PydanticParserError
from starfish.common.logger import get_logger
from .json_parser import JSONParser
# Module-level logger, obtained from the project's get_logger helper with this module's name.
logger = get_logger(__name__)
class PydanticParser:
    """Parse and validate data with Pydantic models.

    Also converts Pydantic models to JSON schemas and to human-readable format
    instructions. Every method is a static method; the class only acts as a
    namespace. Both the Pydantic v2 API (``model_json_schema`` /
    ``model_validate``) and the v1 API (``schema`` / ``parse_obj``) are
    supported and are told apart by attribute lookup on the model class.
    """

    @staticmethod
    def to_json_schema(model: Type[BaseModel]) -> Dict[str, Any]:
        """Convert a Pydantic model to JSON schema.

        Args:
            model: Pydantic model class

        Returns:
            JSON schema dictionary

        Raises:
            TypeError: If model exposes neither ``model_json_schema`` (v2)
                nor ``schema`` (v1), i.e. it is not a Pydantic model
        """
        # Prefer the Pydantic v2 method and fall back to the v1 one.
        for method_name in ("model_json_schema", "schema"):
            schema_method = getattr(model, method_name, None)
            if callable(schema_method):
                return schema_method()
        raise TypeError(f"Expected a Pydantic model class, got {model!r}")

    @staticmethod
    def _process_schema_for_formatting(schema: Dict[str, Any]) -> Dict[str, Any]:
        """Process a Pydantic-generated JSON schema for better format instruction display.

        Local ``$ref`` references found under the root ``properties`` are
        replaced by the definition they point at, producing a flattened schema
        for display purposes. Both ``#/$defs/...`` and ``#/definitions/...``
        references are understood. An ``anyOf`` whose first resolvable
        non-null option is a reference (the shape used for Optional fields) is
        collapsed to that definition as a display simplification. Keys that
        sit next to the reference (e.g. ``description``) are kept and take
        precedence over the definition's own keys.

        A reference that is already being expanded (self-referential or
        mutually recursive models) is left in place as a ``$ref`` so that the
        expansion always terminates.

        The input schema is never modified. The result is a new dictionary;
        values that needed no processing are shared with the input.

        Args:
            schema: The Pydantic JSON schema

        Returns:
            A processed schema with resolved references
        """

        def resolve_ref(ref_path: Any) -> Optional[Dict[str, Any]]:
            # Return the definition a local reference points at, or None when
            # the reference is not local or the definition does not exist.
            if not isinstance(ref_path, str):
                return None
            for section in ("$defs", "definitions"):
                if ref_path.startswith(f"#/{section}/"):
                    definitions = schema.get(section)
                    if not isinstance(definitions, dict):
                        return None
                    target = definitions.get(ref_path.split("/")[-1])
                    return target if isinstance(target, dict) else None
            return None

        def inline(definition, overrides, dropped_key, ref_path, active_refs):
            # Merge the referencing node's extra keys over a copy of the
            # definition, then keep processing with the reference marked as
            # "being expanded" so a cycle cannot recurse forever.
            merged = dict(definition)
            merged.update((key, value) for key, value in overrides.items() if key != dropped_key)
            return process_node(merged, active_refs | {ref_path})

        def process_node(node: Any, active_refs: frozenset) -> Any:
            if not isinstance(node, dict):
                return node
            processed = dict(node)

            # Handle $ref directly
            ref_path = processed.get("$ref")
            definition = resolve_ref(ref_path)
            if definition is not None and ref_path not in active_refs:
                return inline(definition, processed, "$ref", ref_path, active_refs)

            # Handle anyOf with references (used for Optional fields): take the
            # first non-null option whose reference can be resolved.
            for option in processed.get("anyOf") or ():
                if not isinstance(option, dict) or option.get("type") == "null":
                    continue
                option_ref = option.get("$ref")
                definition = resolve_ref(option_ref)
                if definition is not None and option_ref not in active_refs:
                    return inline(definition, processed, "anyOf", option_ref, active_refs)

            # Build a new properties mapping instead of writing into the one
            # shared with the input schema.
            properties = processed.get("properties")
            if isinstance(properties, dict):
                processed["properties"] = {
                    prop_name: process_node(prop_value, active_refs) for prop_name, prop_value in properties.items()
                }

            # Process array items recursively
            if "items" in processed:
                processed["items"] = process_node(processed["items"], active_refs)

            return processed

        # Only the root's properties are processed; other root keys are copied as-is.
        processed_schema = dict(schema)
        root_properties = processed_schema.get("properties")
        if isinstance(root_properties, dict):
            processed_schema["properties"] = {
                prop_name: process_node(prop_value, frozenset()) for prop_name, prop_value in root_properties.items()
            }
        return processed_schema

    @staticmethod
    def parse_dict_or_list(data: Union[Dict[str, Any], List[Dict[str, Any]]], model: Type[BaseModel]) -> Union[BaseModel, List[BaseModel]]:
        """Parse data into Pydantic model instances.

        A list is validated item by item; any other value is handed to the
        model's validator unchanged.

        Args:
            data: Dictionary or list of dictionaries to parse
            model: Pydantic model class to parse into

        Returns:
            Single model instance or list of model instances

        Raises:
            TypeError: If data is a list containing a non-dictionary item
            ValidationError: If Pydantic validation fails
            AttributeError: If model has neither ``model_validate`` nor
                ``parse_obj``. This is deliberately not turned into a
                TypeError, because parse_llm_output treats TypeError as a
                problem with the data and suppresses it in non-strict mode.
        """
        if isinstance(data, list) and not all(isinstance(item, dict) for item in data):
            raise TypeError("All items in list must be dictionaries")

        # Pydantic v2 exposes model_validate; v1 exposes parse_obj.
        validate = model.model_validate if hasattr(model, "model_validate") else model.parse_obj

        if isinstance(data, list):
            return [validate(item) for item in data]
        return validate(data)

    @staticmethod
    def parse_llm_output(
        text: str, model: Type[BaseModel], json_wrapper_key: Optional[str] = None, strict: bool = False
    ) -> Optional[Union[BaseModel, List[BaseModel]]]:
        """Parse LLM output text into Pydantic model instances with configurable error handling.

        The text is first parsed by ``JSONParser.parse_llm_output`` and the
        result is then validated against ``model``. A one-element list is
        unwrapped to the single instance when no ``json_wrapper_key`` is given.

        Args:
            text: Raw text from LLM that may contain JSON
            model: Pydantic model class to parse into
            json_wrapper_key: Optional key that may wrap the actual data
            strict: If True, raise errors. If False, return None and log warning

        Returns:
            Single model instance or list of model instances if successful,
            None if parsing fails in non-strict mode

        Raises:
            PydanticParserError: If validation fails, or a TypeError occurs
                while parsing, in strict mode
            JsonParserError: If JSON parsing fails in strict mode
                (raised by JSONParser, which is defined in another module)
            SchemaValidationError: If JSON schema validation fails in strict mode
                (raised by JSONParser, which is defined in another module)
        """
        try:
            # JSONParser gets the same strict flag, so in strict mode its own
            # errors propagate to the caller unchanged.
            json_data = JSONParser.parse_llm_output(
                text,
                json_wrapper_key=json_wrapper_key,
                strict=strict,
            )
            # None from JSONParser means it already failed in non-strict mode.
            if json_data is None:
                return None

            parsed_data = PydanticParser.parse_dict_or_list(json_data, model)

            # An unwrapped list holding exactly one item is returned as the
            # bare item, matching callers that expect a single object.
            is_single_item_list = isinstance(parsed_data, list) and len(parsed_data) == 1
            if is_single_item_list and not json_wrapper_key:
                return parsed_data[0]
            return parsed_data
        except ValidationError as e:
            # Pydantic rejected the data.
            if strict:
                raise PydanticParserError("Failed to validate against Pydantic model", details={"errors": e.errors()}) from e
            logger.warning(f"Failed to validate LLM response against Pydantic model: {e}")
            logger.debug(f"Validation errors: {e.errors()}")
            return None
        except TypeError as e:
            # Raised by parse_dict_or_list for lists with non-dictionary items.
            if strict:
                raise PydanticParserError(f"Type error during parsing: {e}") from e
            logger.warning(f"Type error in LLM response: {e}")
            return None

    @staticmethod
    def get_format_instructions(model: Type[BaseModel], json_wrapper_key: Optional[str] = None, show_array_items: int = 1) -> str:
        """Format a Pydantic model schema as human-readable instructions.

        The model's JSON schema is generated, its references are resolved for
        display, and the result is passed to
        ``JSONParser.get_format_instructions``.

        Args:
            model: Pydantic model class
            json_wrapper_key: Optional key to wrap the schema in an array
            show_array_items: Number of example items to show in an array wrapper

        Returns:
            Formatted string with schema instructions

        Raises:
            TypeError: If model is not a Pydantic model
        """
        json_schema = PydanticParser.to_json_schema(model)
        # Resolve references so nested models are shown inline.
        processed_schema = PydanticParser._process_schema_for_formatting(json_schema)
        return JSONParser.get_format_instructions(processed_schema, json_wrapper_key, show_array_items)