File size: 9,496 Bytes
5301c48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
from typing import Any, Dict, List, Optional, Type, Union

from pydantic import BaseModel, ValidationError

from starfish.common.exceptions import PydanticParserError
from starfish.common.logger import get_logger

from .json_parser import JSONParser

logger = get_logger(__name__)


class PydanticParser:
    """Handles parsing and validation using Pydantic models.

    Provides utilities for converting between Pydantic and JSON schemas.
    """

    @staticmethod
    def to_json_schema(model: Type[BaseModel]) -> Dict[str, Any]:
        """Convert a Pydantic model to JSON schema.

        Args:
            model: Pydantic model class

        Returns:
            JSON schema dictionary

        Raises:
            TypeError: If model is not a Pydantic model
        """
        # Handle both Pydantic v1 and v2
        if hasattr(model, "model_json_schema"):
            # Pydantic v2
            return model.model_json_schema()
        else:
            # Pydantic v1
            return model.schema()

    @staticmethod
    def _process_schema_for_formatting(schema: Dict[str, Any]) -> Dict[str, Any]:
        """Process a Pydantic-generated JSON schema for better format instruction display.

        This resolves $ref references to make a flattened schema for display purposes.

        Args:
            schema: The Pydantic JSON schema

        Returns:
            A processed schema with resolved references
        """
        # Create a copy to avoid modifying the original schema
        processed_schema = schema.copy()

        # Process the schema recursively
        def process_node(node):
            if not isinstance(node, dict):
                return node

            processed_node = node.copy()

            # Handle $ref directly
            if "$ref" in processed_node:
                ref_path = processed_node["$ref"]
                if ref_path.startswith("#/$defs/"):
                    def_name = ref_path.split("/")[-1]
                    if "$defs" in schema and def_name in schema["$defs"]:
                        # Replace the reference with the actual definition
                        ref_definition = schema["$defs"][def_name].copy()

                        # Preserve any additional properties like description
                        for key, value in processed_node.items():
                            if key != "$ref":
                                ref_definition[key] = value

                        # Process the referenced definition recursively
                        processed_node = process_node(ref_definition)

            # Handle anyOf with references (used for Optional fields)
            if "anyOf" in processed_node:
                # For format instructions, we'll just take the first non-null type
                # from anyOf as a simplification
                for item in processed_node["anyOf"]:
                    if item.get("type") != "null" and "$ref" in item:
                        ref_path = item["$ref"]
                        if ref_path.startswith("#/$defs/"):
                            def_name = ref_path.split("/")[-1]
                            if "$defs" in schema and def_name in schema["$defs"]:
                                # Replace anyOf with the referenced definition
                                ref_definition = schema["$defs"][def_name].copy()

                                # Preserve any additional properties like description from the parent
                                for key, value in processed_node.items():
                                    if key != "anyOf":
                                        ref_definition[key] = value

                                # Process the referenced definition recursively
                                processed_node = process_node(ref_definition)
                                break

            # Process properties recursively
            if "properties" in processed_node:
                for prop_name, prop_value in list(processed_node["properties"].items()):
                    processed_node["properties"][prop_name] = process_node(prop_value)

            # Process array items recursively
            if "items" in processed_node:
                processed_node["items"] = process_node(processed_node["items"])

            return processed_node

        # Start the recursive processing at the root level
        if "properties" in processed_schema:
            for prop_name, prop_value in list(processed_schema["properties"].items()):
                processed_schema["properties"][prop_name] = process_node(prop_value)

        return processed_schema

    @staticmethod
    def parse_dict_or_list(data: Union[Dict[str, Any], List[Dict[str, Any]]], model: Type[BaseModel]) -> Union[BaseModel, List[BaseModel]]:
        """Parse data into Pydantic model instances.

        Args:
            data: Dictionary or list of dictionaries to parse
            model: Pydantic model class to parse into

        Returns:
            Single model instance or list of model instances

        Raises:
            TypeError: If model is not a Pydantic model or data has invalid type
            ValidationError: If Pydantic validation fails
        """
        if isinstance(data, list):
            # Handle list of objects
            if not all(isinstance(item, dict) for item in data):
                raise TypeError("All items in list must be dictionaries")

            if hasattr(model, "model_validate"):
                # Pydantic v2
                return [model.model_validate(item) for item in data]
            else:
                # Pydantic v1
                return [model.parse_obj(item) for item in data]
        else:
            # Handle single object
            if hasattr(model, "model_validate"):
                # Pydantic v2
                return model.model_validate(data)
            else:
                # Pydantic v1
                return model.parse_obj(data)

    @staticmethod
    def parse_llm_output(
        text: str, model: Type[BaseModel], json_wrapper_key: Optional[str] = None, strict: bool = False
    ) -> Optional[Union[BaseModel, List[BaseModel]]]:
        """Turn raw LLM text into validated Pydantic model instance(s).

        The text is first decoded by ``JSONParser.parse_llm_output`` and the
        resulting data is then validated with ``parse_dict_or_list``.

        Args:
            text: Raw text from LLM that may contain JSON
            model: Pydantic model class to parse into
            json_wrapper_key: Optional key that may wrap the actual data
            strict: If True, raise errors. If False, return None and log warning

        Returns:
            Single model instance or list of model instances if successful.
            A one-element list is unwrapped to its single instance when no
            json_wrapper_key is given. None if the JSON step returned None or
            if validation fails in non-strict mode.

        Raises:
            PydanticParserError: In strict mode, when a ValidationError or
                TypeError occurs during parsing. Other exceptions raised by
                JSONParser are not caught here and propagate unchanged.
        """
        try:
            # JSON extraction is delegated; strict is forwarded so JSONParser
            # applies the same error policy as this method
            raw_payload = JSONParser.parse_llm_output(
                text,
                json_wrapper_key=json_wrapper_key,
                strict=strict,
            )

            # Nothing usable came back from the JSON step
            if raw_payload is None:
                return None

            validated = PydanticParser.parse_dict_or_list(raw_payload, model)

            # An unwrapped single-element list collapses to the bare object,
            # matching what callers expecting one object would receive
            is_lone_item = isinstance(validated, list) and len(validated) == 1 and not json_wrapper_key
            return validated[0] if is_lone_item else validated

        except ValidationError as e:
            # Pydantic rejected the data
            if not strict:
                logger.warning(f"Failed to validate LLM response against Pydantic model: {str(e)}")
                logger.debug(f"Validation errors: {e.errors()}")
                return None
            raise PydanticParserError("Failed to validate against Pydantic model", details={"errors": e.errors()}) from e
        except TypeError as e:
            # Raised e.g. by parse_dict_or_list for lists with non-dict items
            if not strict:
                logger.warning(f"Type error in LLM response: {str(e)}")
                return None
            raise PydanticParserError(f"Type error during parsing: {str(e)}") from e

    @staticmethod
    def get_format_instructions(model: Type[BaseModel], json_wrapper_key: Optional[str] = None, show_array_items: int = 1) -> str:
        """Render a Pydantic model's schema as human-readable format instructions.

        The model is converted to a JSON schema, its references are resolved
        for display, and the result is formatted by
        ``JSONParser.get_format_instructions``.

        Args:
            model: Pydantic model class
            json_wrapper_key: Optional key to wrap the schema in an array
            show_array_items: Number of example items to show in an array wrapper

        Returns:
            Formatted string with schema instructions
        """
        # Resolve references up front so the formatter receives a flattened schema
        display_schema = PydanticParser._process_schema_for_formatting(PydanticParser.to_json_schema(model))
        return JSONParser.get_format_instructions(display_schema, json_wrapper_key, show_array_items)