Spaces:
Running
Running
from typing import Any, Dict, List, Optional, Type, Union | |
from pydantic import BaseModel, ValidationError | |
from starfish.common.exceptions import PydanticParserError | |
from starfish.common.logger import get_logger | |
from .json_parser import JSONParser | |
logger = get_logger(__name__) | |
class PydanticParser:
    """Handles parsing and validation using Pydantic models.

    Provides utilities for converting between Pydantic and JSON schemas.

    The class holds no state and acts purely as a namespace; every method is a
    static method, so it can be called on the class or on an instance.
    """

    @staticmethod
    def to_json_schema(model: Type[BaseModel]) -> Dict[str, Any]:
        """Convert a Pydantic model to JSON schema.

        Args:
            model: Pydantic model class

        Returns:
            JSON schema dictionary

        Note:
            ``model`` is not type-checked here. An object exposing neither
            ``model_json_schema`` nor ``schema`` fails with ``AttributeError``.
        """
        # Pydantic v2 exposes model_json_schema(); fall back to the v1 schema() API.
        if hasattr(model, "model_json_schema"):
            return model.model_json_schema()
        return model.schema()

    @staticmethod
    def _process_schema_for_formatting(schema: Dict[str, Any]) -> Dict[str, Any]:
        """Process a Pydantic-generated JSON schema for better format instruction display.

        Local ``$ref`` references found under the root ``properties`` are replaced
        by the definition they point to, producing a flattened schema for display.
        For ``anyOf`` nodes (used for Optional fields) the first non-null option
        that is a resolvable local reference replaces the whole node. Keys sitting
        next to ``$ref`` / ``anyOf`` (e.g. ``description``) override the same keys
        of the referenced definition.

        Both the ``#/$defs/<Name>`` layout and the ``#/definitions/<Name>`` layout
        are understood. A reference that is already being expanded further up the
        chain (self-referencing model) is left as a plain ``$ref`` instead of
        recursing forever.

        Args:
            schema: The Pydantic JSON schema

        Returns:
            A new dictionary with resolved references. ``schema`` itself, including
            its definitions section, is never modified; keys other than
            ``properties`` are carried over as-is.
        """
        # Reference prefix -> root key that stores the referenced definitions.
        ref_sections = (("#/$defs/", "$defs"), ("#/definitions/", "definitions"))

        def lookup(ref_path: Any) -> Optional[tuple]:
            """Return ``(section, name)`` for a resolvable local reference, else None."""
            if not isinstance(ref_path, str):
                return None
            for prefix, section in ref_sections:
                if ref_path.startswith(prefix):
                    definitions = schema.get(section)
                    name = ref_path.split("/")[-1]
                    if isinstance(definitions, dict) and isinstance(definitions.get(name), dict):
                        return (section, name)
                    return None
            return None

        def resolve(ref_path: Any, overrides: Dict[str, Any], active: frozenset) -> Optional[Dict[str, Any]]:
            """Expand the definition behind ``ref_path`` merged with ``overrides``.

            Returns None when the reference cannot be resolved or is already being
            expanded (cycle), so the caller keeps the node unchanged.
            """
            target = lookup(ref_path)
            if target is None or target in active:
                return None
            section, name = target
            # Definition keys first, then the referencing node's own keys on top.
            merged = {**schema[section][name], **overrides}
            return expand(merged, active | {target})

        def expand(node: Any, active: frozenset) -> Any:
            """Return a flattened copy of ``node`` without touching ``node`` itself."""
            if not isinstance(node, dict):
                return node

            # Handle $ref directly
            if "$ref" in node:
                siblings = {key: value for key, value in node.items() if key != "$ref"}
                resolved = resolve(node["$ref"], siblings, active)
                if resolved is not None:
                    return resolved

            # Handle anyOf with references (used for Optional fields): as a
            # simplification, take the first non-null option that resolves.
            options = node.get("anyOf")
            if isinstance(options, (list, tuple)):
                siblings = {key: value for key, value in node.items() if key != "anyOf"}
                for option in options:
                    if not isinstance(option, dict):
                        continue
                    if option.get("type") != "null" and "$ref" in option:
                        resolved = resolve(option["$ref"], siblings, active)
                        if resolved is not None:
                            return resolved

            flattened = dict(node)
            # Build fresh containers instead of writing into the shared nested dicts.
            nested = flattened.get("properties")
            if isinstance(nested, dict):
                flattened["properties"] = {name: expand(sub, active) for name, sub in nested.items()}
            if "items" in flattened:
                flattened["items"] = expand(flattened["items"], active)
            return flattened

        # Only the root-level properties are walked; everything else is passed through.
        processed_schema = dict(schema)
        root_properties = schema.get("properties")
        if isinstance(root_properties, dict):
            processed_schema["properties"] = {name: expand(sub, frozenset()) for name, sub in root_properties.items()}
        return processed_schema

    @staticmethod
    def parse_dict_or_list(data: Union[Dict[str, Any], List[Dict[str, Any]]], model: Type[BaseModel]) -> Union[BaseModel, List[BaseModel]]:
        """Parse data into Pydantic model instances.

        Args:
            data: Dictionary or list of dictionaries to parse
            model: Pydantic model class to parse into

        Returns:
            Single model instance or list of model instances

        Raises:
            TypeError: If data is a list containing a non-dictionary item
            ValidationError: If Pydantic validation fails
        """
        # Pydantic v2 exposes model_validate(); fall back to the v1 parse_obj() API.
        validate = model.model_validate if hasattr(model, "model_validate") else model.parse_obj

        if isinstance(data, list):
            # Handle list of objects
            if not all(isinstance(item, dict) for item in data):
                raise TypeError("All items in list must be dictionaries")
            return [validate(item) for item in data]

        # Handle single object; any type problem is left to the model's own validation.
        return validate(data)

    @staticmethod
    def parse_llm_output(
        text: str, model: Type[BaseModel], json_wrapper_key: Optional[str] = None, strict: bool = False
    ) -> Optional[Union[BaseModel, List[BaseModel]]]:
        """Parse LLM output text into Pydantic model instances with configurable error handling.

        Args:
            text: Raw text from LLM that may contain JSON
            model: Pydantic model class to parse into
            json_wrapper_key: Optional key that may wrap the actual data
            strict: If True, raise errors. If False, return None and log warning

        Returns:
            Single model instance or list of model instances if successful
            (a one-element list is unwrapped to its single item when no
            json_wrapper_key is given), None if parsing fails in non-strict mode

        Raises:
            PydanticParserError: If model validation or the list type check fails
                in strict mode

        Note:
            Only ``ValidationError`` and ``TypeError`` are handled here; any other
            exception raised by ``JSONParser.parse_llm_output`` propagates unchanged.
        """
        try:
            # Use JSONParser to handle initial JSON parsing (let its errors propagate in strict mode)
            json_data = JSONParser.parse_llm_output(
                text,
                json_wrapper_key=json_wrapper_key,
                strict=strict,  # Pass through the strict parameter
            )

            # If JSONParser returned None (in non-strict mode), return None
            if json_data is None:
                return None

            # Convert to Pydantic model(s)
            parsed_data = PydanticParser.parse_dict_or_list(json_data, model)

            # If the parsed data is a list of one item and not wrapped, return just the item
            # This makes it consistent with how most APIs would expect a single object
            if isinstance(parsed_data, list) and len(parsed_data) == 1 and not json_wrapper_key:
                return parsed_data[0]
            return parsed_data
        except ValidationError as e:
            # Handle Pydantic validation errors
            if strict:
                raise PydanticParserError("Failed to validate against Pydantic model", details={"errors": e.errors()}) from e
            logger.warning(f"Failed to validate LLM response against Pydantic model: {str(e)}")
            logger.debug(f"Validation errors: {e.errors()}")
            return None
        except TypeError as e:
            # Handle type errors from parse_dict_or_list
            if strict:
                raise PydanticParserError(f"Type error during parsing: {str(e)}") from e
            logger.warning(f"Type error in LLM response: {str(e)}")
            return None

    @staticmethod
    def get_format_instructions(model: Type[BaseModel], json_wrapper_key: Optional[str] = None, show_array_items: int = 1) -> str:
        """Format a Pydantic model schema as human-readable instructions.

        The model's JSON schema is generated, its references are flattened with
        ``_process_schema_for_formatting``, and the result is handed to
        ``JSONParser.get_format_instructions``.

        Args:
            model: Pydantic model class
            json_wrapper_key: Optional wrapper key, forwarded unchanged to
                ``JSONParser.get_format_instructions``
            show_array_items: Number of example items, forwarded unchanged to
                ``JSONParser.get_format_instructions``

        Returns:
            Formatted string with schema instructions
        """
        json_schema = PydanticParser.to_json_schema(model)
        # Process the schema to resolve references for better display
        processed_schema = PydanticParser._process_schema_for_formatting(json_schema)
        return JSONParser.get_format_instructions(processed_schema, json_wrapper_key, show_array_items)