from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

from pydantic import BaseModel

from starfish.llm.parser import JSONParser, PydanticParser
from starfish.llm.prompt import PromptManager, get_partial_prompt
from starfish.llm.proxy.litellm_adapter import call_chat_model
from starfish.llm.utils import to_sync

T = TypeVar("T")


class LLMResponse(Generic[T]):
    """Container for LLM response with both raw response and parsed data."""

    def __init__(self, raw_response: Any, parsed_data: Optional[T] = None):
        self.raw = raw_response
        self.data = parsed_data

    def __repr__(self) -> str:
        # Explicit None check so falsy parsed data (e.g. [] or {}) still reports its type
        return f"LLMResponse(raw={type(self.raw)}, data={type(self.data) if self.data is not None else None})"


class StructuredLLM:
    """A builder for LLM-powered functions that can be called with custom parameters.

    This class creates a callable object that handles:
    - Jinja template rendering with dynamic parameters
    - LLM API calls
    - Response parsing according to provided schema

    Provides async (`run`, `__call__`) and sync (`run_sync`) execution methods.

    Note: The prompt parameter accepts both Jinja2 templates and Python f-string-like
    syntax with single braces {variable}. Single braces are automatically converted
    to proper Jinja2 syntax with double braces {{ variable }}. This feature simplifies
    template writing for users familiar with Python's f-string syntax.
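
    Example (an illustrative sketch, not a canonical recipe; the model name and
    schema field keys are assumptions about the caller's setup):

        qa = StructuredLLM(
            model_name="openai/gpt-4o-mini",
            prompt="Answer briefly: {question}",
            output_schema=[{"name": "answer", "type": "str"}],
        )
        response = qa.run_sync(question="What is 2 + 2?")
        print(response.data)  # parsed output (or None if parsing failed)
        print(response.raw)   # full provider response object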
    """

    def __init__(
        self,
        model_name: str,
        prompt: str,
        model_kwargs: Optional[Dict[str, Any]] = None,
        output_schema: Optional[Union[List[Dict[str, Any]], Dict[str, Any], type]] = None,
        prompt_template: Optional[str] = None,
        strict_parsing: bool = False,
        type_check: bool = False,
    ):
        """Initialize the LLM builder.

        Args:
            model_name: Name of the LLM model to use (e.g., 'openai/gpt-4o-mini')
            prompt: Template string in either Jinja2 format or Python-style single-brace format.
                   Single-brace format like "Hello {name}" will be automatically converted to
                   Jinja2 format "Hello {{ name }}". Existing Jinja2 templates are preserved.
            model_kwargs: Additional arguments to pass to the LLM
            output_schema: Schema for response parsing (JSON list/dict or Pydantic model)
            prompt_template: Optional name of partial prompt template to wrap around the prompt
            strict_parsing: If True, raise errors on parsing failures. If False, return None for data
            type_check: If True, check field types against schema. If False, skip type validation.
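
        Example schemas (illustrative; the field-spec keys assume the
        list-of-dicts format consumed by JSONParser.convert_to_schema)::

            # JSON field list
            output_schema = [{"name": "city", "type": "str"}]

            # or any pydantic model
            class City(BaseModel):
                city: str
                population: int

            output_schema = City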
        """
        # Model settings
        self.model_name = model_name
        self.model_kwargs = model_kwargs or {}

        # Initialize prompt manager
        if prompt_template:
            self.prompt_manager = get_partial_prompt(prompt_name=prompt_template, template_str=prompt)
        else:
            self.prompt_manager = PromptManager(prompt)

        # Extract template variables
        self.prompt_variables = self.prompt_manager.get_all_variables()
        self.required_prompt_variables = self.prompt_manager.get_required_variables()
        self.optional_prompt_variables = self.prompt_manager.get_optional_variables()
        self.prompt = self.prompt_manager.get_prompt()

        # Schema processing
        self.output_schema = output_schema
        self.strict_parsing = strict_parsing
        self.type_check = type_check
        self.is_pydantic = isinstance(output_schema, type) and issubclass(output_schema, BaseModel)

        if self.output_schema:
            if self.is_pydantic:
                self.json_schema = PydanticParser.to_json_schema(output_schema)
            else:
                self.json_schema = JSONParser.convert_to_schema(output_schema)
        else:
            # Keep the attribute defined so later access never raises AttributeError
            self.json_schema = None

    def _get_schema_instructions(self) -> Optional[str]:
        """Get formatted schema instructions if schema is provided."""
        if not self.output_schema:
            return None

        if self.is_pydantic:
            return PydanticParser.get_format_instructions(self.output_schema)
        return JSONParser.get_format_instructions(self.json_schema)

    async def __call__(self, **kwargs) -> LLMResponse:
        """A convenience wrapper around the run() method."""
        return await self.run(**kwargs)

    def _prepare_prompt_inputs(self, **kwargs) -> Dict[str, Any]:
        """Filter kwargs down to known template variables and inject schema instructions."""
        # Filter keys that are in prompt_variables
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in self.prompt_variables}

        # Add schema instructions if available
        schema_str = self._get_schema_instructions()
        if schema_str:
            filtered_kwargs["schema_instruction"] = schema_str

        return filtered_kwargs

    def render_prompt(self, **kwargs) -> List[Dict[str, Any]]:
        """Render the prompt template into chat messages for the LLM call."""
        # Add schema instructions if needed
        prompt_inputs = self._prepare_prompt_inputs(**kwargs)
        # Render the prompt template
        try:
            messages = self.prompt_manager.construct_messages(prompt_inputs)
        except ValueError as e:
            raise ValueError(f"Error rendering prompt template: {e}") from e
        return messages

    def render_prompt_printable(self, **kwargs) -> str:
        """Print the prompt template with the provided parameters."""
        messages = self.render_prompt(**kwargs)
        return self.prompt_manager.get_printable_messages(messages)

    async def run(self, **kwargs) -> LLMResponse:
        """Main async method to run the LLM with the provided parameters."""
        # Render the prompt template
        messages = self.render_prompt(**kwargs)

        # Call the LLM
        raw_response = await call_chat_model(model_name=self.model_name, messages=messages, model_kwargs=self.model_kwargs)

        # Parse the response
        response_text = raw_response.choices[0].message.content
        if not self.output_schema:
            return LLMResponse(raw_response, response_text)

        # Pydantic and dict/list schemas alike are validated against the derived JSON schema
        parsed_data = JSONParser.parse_llm_output(response_text, schema=self.json_schema, strict=self.strict_parsing, type_check=self.type_check)

        return LLMResponse(raw_response, parsed_data)

    @to_sync
    async def run_sync(self, **kwargs) -> LLMResponse:
        """Synchronously call the LLM with the provided parameters.

        When used in Jupyter notebooks, make sure to apply nest_asyncio.
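
        Example (illustrative; ``nest_asyncio`` is only needed when an event
        loop is already running, as in Jupyter)::

            import nest_asyncio
            nest_asyncio.apply()

            response = my_llm.run_sync(question="...")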
        """
        return await self.run(**kwargs)
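

if __name__ == "__main__":
    # Minimal runnable sketch, not part of the library API. Assumes provider
    # credentials for litellm are configured (e.g. OPENAI_API_KEY) and that
    # the model name below is available to the caller.
    import asyncio

    async def _demo() -> None:
        extractor = StructuredLLM(
            model_name="openai/gpt-4o-mini",
            prompt="Name the capital of {country}.",
            output_schema=[{"name": "capital", "type": "str"}],
        )
        response = await extractor.run(country="France")
        print(response.data)

    asyncio.run(_demo())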