from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

from pydantic import BaseModel

from starfish.llm.parser import JSONParser, PydanticParser
from starfish.llm.prompt import PromptManager, get_partial_prompt
from starfish.llm.proxy.litellm_adapter import call_chat_model
from starfish.llm.utils import to_sync

T = TypeVar("T")


class LLMResponse(Generic[T]):
    """Container for LLM response with both raw response and parsed data."""

    def __init__(self, raw_response: Any, parsed_data: Optional[T] = None):
        self.raw = raw_response
        self.data = parsed_data

    def __repr__(self) -> str:
        return f"LLMResponse(raw={type(self.raw)}, data={type(self.data) if self.data else None})"

class StructuredLLM:
    """A builder for LLM-powered functions that can be called with custom parameters.

    This class creates a callable object that handles:
    - Jinja template rendering with dynamic parameters
    - LLM API calls
    - Response parsing according to the provided schema

    Provides async (`run`, `__call__`) and sync (`run_sync`) execution methods.
    Use `process_list_input=True` in the run methods to process a list argument.

    Note: The prompt parameter accepts both Jinja2 templates and Python f-string-like
    syntax with single braces {variable}. Single braces are automatically converted
    to proper Jinja2 syntax with double braces {{ variable }}, which simplifies
    template writing for users familiar with Python's f-string syntax.
    """

    def __init__(
        self,
        model_name: str,
        prompt: str,
        model_kwargs: Optional[Dict[str, Any]] = None,
        output_schema: Optional[Union[List[Dict[str, Any]], Dict[str, Any], type]] = None,
        prompt_template: Optional[str] = None,
        strict_parsing: bool = False,
        type_check: bool = False,
    ):
        """Initialize the LLM builder.

        Args:
            model_name: Name of the LLM model to use (e.g., 'openai/gpt-4o-mini').
            prompt: Template string in either Jinja2 format or Python-style single-brace
                format. Single-brace format like "Hello {name}" is automatically converted
                to the Jinja2 form "Hello {{ name }}"; existing Jinja2 templates are preserved.
            model_kwargs: Additional arguments to pass to the LLM.
            output_schema: Schema for response parsing (JSON list/dict or Pydantic model).
            prompt_template: Optional name of a partial prompt template to wrap around the prompt.
            strict_parsing: If True, raise errors on parsing failures; if False, return None for data.
            type_check: If True, check field types against the schema; if False, skip type validation.
        """
        # Model settings
        self.model_name = model_name
        self.model_kwargs = model_kwargs or {}

        # Initialize the prompt manager
        if prompt_template:
            self.prompt_manager = get_partial_prompt(prompt_name=prompt_template, template_str=prompt)
        else:
            self.prompt_manager = PromptManager(prompt)

        # Extract template variables
        self.prompt_variables = self.prompt_manager.get_all_variables()
        self.required_prompt_variables = self.prompt_manager.get_required_variables()
        self.optional_prompt_variables = self.prompt_manager.get_optional_variables()
        self.prompt = self.prompt_manager.get_prompt()

        # Schema processing
        self.output_schema = output_schema
        self.strict_parsing = strict_parsing
        self.type_check = type_check
        self.is_pydantic = isinstance(output_schema, type) and issubclass(output_schema, BaseModel)
        if self.output_schema:
            if self.is_pydantic:
                self.json_schema = PydanticParser.to_json_schema(output_schema)
            else:
                self.json_schema = JSONParser.convert_to_schema(output_schema)
        else:
            # Keep the attribute defined so later accesses cannot raise AttributeError
            self.json_schema = None

    def _get_schema_instructions(self) -> Optional[str]:
        """Get formatted schema instructions if a schema is provided."""
        if not self.output_schema:
            return None
        if self.is_pydantic:
            return PydanticParser.get_format_instructions(self.output_schema)
        return JSONParser.get_format_instructions(self.json_schema)

    async def __call__(self, **kwargs) -> LLMResponse:
        """A convenience wrapper around the run() method."""
        return await self.run(**kwargs)

    def _prepare_prompt_inputs(self, **kwargs) -> Dict[str, Any]:
        """Prepare the prompt inputs for the LLM."""
        # Keep only the kwargs that appear among the template's variables
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in self.prompt_variables}
        # Add schema instructions if available
        schema_str = self._get_schema_instructions()
        if schema_str:
            filtered_kwargs["schema_instruction"] = schema_str
        return filtered_kwargs

    def render_prompt(self, **kwargs) -> List[Dict[str, Any]]:
        """Render the prompt template into chat messages."""
        # Add schema instructions if needed
        prompt_inputs = self._prepare_prompt_inputs(**kwargs)
        # Render the prompt template
        try:
            messages = self.prompt_manager.construct_messages(prompt_inputs)
        except ValueError as e:
            raise ValueError(f"Error rendering prompt template: {e}") from e
        return messages

    def render_prompt_printable(self, **kwargs) -> str:
        """Return a printable rendering of the prompt with the provided parameters."""
        messages = self.render_prompt(**kwargs)
        return self.prompt_manager.get_printable_messages(messages)

    async def run(self, **kwargs) -> LLMResponse:
        """Main async method to run the LLM with the provided parameters."""
        # Render the prompt template
        messages = self.render_prompt(**kwargs)

        # Call the LLM
        raw_response = await call_chat_model(model_name=self.model_name, messages=messages, model_kwargs=self.model_kwargs)

        # Parse the response
        response_text = raw_response.choices[0].message.content
        if not self.output_schema:
            return LLMResponse(raw_response, response_text)
        parsed_data = JSONParser.parse_llm_output(response_text, schema=self.json_schema, strict=self.strict_parsing, type_check=self.type_check)
        return LLMResponse(raw_response, parsed_data)

    @to_sync
    async def run_sync(self, **kwargs) -> LLMResponse:
        """Synchronously call the LLM with the provided parameters.

        When used in Jupyter notebooks, make sure to apply nest_asyncio first.
        """
        return await self.run(**kwargs)
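

# A minimal usage sketch, not part of the library: the model name, template
# variable, and Pydantic schema below are illustrative assumptions.
if __name__ == "__main__":
    class CityFact(BaseModel):
        fact: str

    fact_llm = StructuredLLM(
        model_name="openai/gpt-4o-mini",
        prompt="Share one interesting fact about {city}.",
        output_schema=CityFact,
    )
    # run_sync wraps the async run() via @to_sync; in Jupyter, apply
    # nest_asyncio before calling it.
    response = fact_llm.run_sync(city="Paris")
    print(response.data)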