Spaces:
Running
Running
import functools | |
import json | |
import re | |
from typing import Any, Dict, List, Set, Tuple | |
from jinja2 import Environment, StrictUndefined, meta, nodes | |
from starfish.llm.prompt.prompt_template import COMPLETE_PROMPTS, PARTIAL_PROMPTS | |
class PromptManager:
    """Manage a Jinja2 prompt template: analyze its variables and render it.

    The user template is wrapped with an optional header/footer and always followed
    by ``MANDATE_INSTRUCTION``. That instruction either asks the model to process a
    list input element by element, or to generate ``num_records`` records in the
    JSON format given by ``schema_instruction``.
    """

    MANDATE_INSTRUCTION = """
{% if is_list_input %}
Additional Instructions:
You are provided with a list named |{{ list_input_variable }}| that contains exactly {{ input_list_length }} elements.
Processing:
1. Process each element according to the provided instructions.
2. Generate and return a JSON array containing exactly {{ input_list_length }} results, preserving the original order.
3. Your output must strictly adhere to the following JSON schema:
{{ schema_instruction }}
{% else %}
You are asked to generate exactly {{ num_records }} records and please return the data in the following JSON format:
{{ schema_instruction }}
{% endif %}
"""

    def __init__(self, template_str: str, header: str = "", footer: str = ""):
        """Build the full template and analyze its variables up front.

        Args:
            template_str: User template, either Jinja2 or Python-style ``{var}``
                placeholders (converted by ``_convert_to_jinja_format``).
            header: Text placed before the template.
            footer: Text placed after the template.
        """
        # Accept Python-style single-brace placeholders by converting them to Jinja2.
        template_str = self._convert_to_jinja_format(template_str)
        self.template_full = f"{header}\n{template_str}\n{footer}".strip() + f"\n{self.MANDATE_INSTRUCTION}"
        # StrictUndefined: any variable left undefined makes rendering raise.
        self._env = Environment(undefined=StrictUndefined)
        self._template = self._env.from_string(self.template_full)
        self._ast = self._env.parse(self.template_full)
        # Analyze once here so the getters don't re-walk the AST.
        self.all_vars, self.required_vars, self.optional_vars = self._analyze_variables()

    @staticmethod
    def _convert_to_jinja_format(template_str: str) -> str:
        """
        Convert Python f-string or string with single braces to Jinja2 template syntax.

        Single-brace interpolation such as ``{name}`` or ``{user.name}`` is rewritten
        to ``{{ name }}`` / ``{{ user.name }}``, while preserving:
        1. Existing Jinja2 syntax: if the string contains a ``{% %}`` tag, a
           ``{# #}`` comment, or both ``{{`` and ``}}``, it is returned unchanged
        2. JSON/dict literals with braces and quotes
        3. Other uses of single braces whose content does not look like a variable
           or simple expression

        Known limitations:
        - Expressions containing string literals inside braces are left unconverted
        - Triple braces ({{{var}}}) will cause Jinja syntax errors
        - Whitespace inside braces is normalized in the conversion process

        Returns:
            str: A string properly formatted for Jinja2
        """
        if not template_str.strip():
            return template_str
        # DOTALL so tags/comments spanning several lines are detected as well;
        # without it "{%\nif x\n%}" would be rewritten into "{{ %\nif x\n% }}".
        if re.search(r"{%.*?%}|{#.*?#}", template_str, re.DOTALL):
            return template_str
        if "{{" in template_str and "}}" in template_str:
            return template_str

        length = len(template_str)
        result = []
        i = 0
        while i < length:
            char = template_str[i]
            # A "{" followed by a quote looks like a JSON/dict literal and "{{" like an
            # escaped brace; both fall through and are copied verbatim.
            if char == "{" and i + 1 < length and template_str[i + 1] not in "\"'{":
                # Find the matching closing brace, noting unescaped quotes on the way.
                j = i + 1
                depth = 1
                has_quotes = False
                while j < length and depth > 0:
                    inner = template_str[j]
                    if inner in "\"'" and template_str[j - 1] != "\\":
                        has_quotes = True
                    if inner == "{":
                        depth += 1
                    elif inner == "}":
                        depth -= 1
                    j += 1
                if depth == 0 and j - i > 2:
                    var_content = template_str[i + 1 : j - 1].strip()
                    # Convert only names, attribute/index access and simple operators;
                    # this rejects things like "key: value" that look like JSON.
                    looks_like_expression = (
                        var_content
                        and not (has_quotes and ('"' in var_content or "'" in var_content))
                        and re.match(r"^[a-zA-Z0-9_\.\[\]\(\)\+\-\*\/\|\s\%\<\>\=\!\&]+$", var_content)
                    )
                    if looks_like_expression:
                        result.append("{{ " + var_content + " }}")
                    else:
                        # Empty braces, quoted content, or anything we're unsure about.
                        result.append(template_str[i:j])
                    i = j
                    continue
            result.append(char)
            i += 1
        return "".join(result)

    @classmethod
    def from_string(cls, template_str: str, header: str = "", footer: str = "") -> "PromptManager":
        """Alternate constructor: create a manager from a template string."""
        return cls(template_str, header, footer)

    def get_all_variables(self) -> List[str]:
        """Return all undeclared variables referenced by the template."""
        return list(self.all_vars)

    def get_prompt(self) -> str:
        """Return the full template text (header, template, footer and mandate instruction)."""
        return self.template_full

    def _analyze_variables(self) -> Tuple[Set[str], Set[str], Set[str]]:
        """Classify the template's undeclared variables as required or optional.

        A variable is *required* when it is the base of an output expression
        (``{{ var }}``, ``{{ var.attr.sub }}``, ``{{ var[0] }}``, ``{{ var | f }}``)
        that is not nested inside any ``{% if %}`` block. Every other undeclared
        variable (e.g. used only in ``{% if %}`` tests or branches) is *optional*.
        Names declared inside the template (loop targets, ``loop``, ``{% set %}``
        names, macro arguments) are never reported.

        Returns:
            Tuple containing (all_vars, required_vars, optional_vars), where
            optional_vars == all_vars - required_vars.
        """
        root_vars: Set[str] = set()

        def base_name(expr):
            # Unwrap ``a.b``, ``a[0]`` and ``a | f`` chains down to the underlying name.
            while isinstance(expr, (nodes.Getattr, nodes.Getitem, nodes.Filter)):
                expr = expr.node
            return expr.name if isinstance(expr, nodes.Name) else None

        def visit(node) -> None:
            if isinstance(node, nodes.If):
                # Nothing under an {% if %} (test, body, elif/else) can be required.
                return
            if isinstance(node, nodes.Output):
                for expr in node.nodes:
                    name = base_name(expr)  # plain TemplateData yields None
                    if name is not None:
                        root_vars.add(name)
            # Descend into block-like nodes ({% for %}, {% with %}, {% block %}, ...).
            for child in getattr(node, "body", None) or ():
                visit(child)

        for top_level in self._ast.body:
            visit(top_level)

        all_vars = meta.find_undeclared_variables(self._ast)
        # Intersect so template-local names (e.g. the ``item`` in
        # ``{% for item in items %}{{ item }}``) are not demanded from the caller.
        required_vars = root_vars & all_vars
        optional_vars = all_vars - required_vars
        return all_vars, required_vars, optional_vars

    def get_optional_variables(self) -> List[str]:
        """Return optional variables (used only in conditionals)."""
        return list(self.optional_vars)

    def get_required_variables(self) -> List[str]:
        """Return required variables (used outside conditionals)."""
        return list(self.required_vars)

    def render_template(self, variables: Dict[str, Any]) -> str:
        """Render the template with ``variables``.

        In addition to the caller's values, this injects the variables used by
        ``MANDATE_INSTRUCTION``:
        - ``is_list_input``
        - ``list_input_variable``
        - ``input_list_length``
        - ``num_records`` (1 when missing or None)

        List inputs: the first template variable whose value is a list is treated
        as the list input. Required variables are checked before optional ones,
        each group in sorted name order. That variable is rendered as
        ``"|name| : <json>"``.

        Raises:
            ValueError: If a required variable (other than ``num_records``) is
                missing or None.
        """
        for var in sorted(self.required_vars):
            if var == "num_records":
                continue  # num_records has default value of 1, so it is not required
            if var not in variables:
                raise ValueError(f"Required variable '{var}' is missing")
            if variables[var] is None:
                raise ValueError(f"Required variable '{var}' cannot be None")

        # Work on a copy so the caller's dict is not modified.
        render_vars = variables.copy()

        is_list_input = False
        list_input_variable = None
        input_list_length = None
        # Sorted so the choice is deterministic when several variables hold lists
        # (set iteration order of strings varies between interpreter runs).
        for var in sorted(self.required_vars) + sorted(self.optional_vars):
            value = render_vars.get(var)
            if isinstance(value, list):
                is_list_input = True
                list_input_variable = var
                input_list_length = len(value)
                # Tag the serialized list with its name so the instruction can refer to it.
                render_vars[var] = f"|{var}| : {json.dumps(value)}"
                break

        # Missing optional variables become None so `{% if var %}` tests are simply
        # falsy instead of hitting StrictUndefined; num_records is defaulted below.
        for var in self.optional_vars:
            if var != "num_records":
                render_vars.setdefault(var, None)

        render_vars["is_list_input"] = is_list_input
        render_vars["list_input_variable"] = list_input_variable
        render_vars["input_list_length"] = input_list_length
        # Default to 1 when not specified (absent or explicitly None).
        if render_vars.get("num_records") is None:
            render_vars["num_records"] = 1
        return self._template.render(**render_vars)

    def construct_messages(self, variables: Dict[str, Any]) -> List[Dict[str, str]]:
        """Render the template and wrap it as a single chat message with role "user"."""
        rendered_template = self.render_template(variables)
        return [{"role": "user", "content": rendered_template}]

    def get_printable_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format constructed messages (dicts with "role"/"content") with visual separators."""
        # NOTE(review): "π" looks like a mis-encoded emoji; confirm the intended glyph.
        lines = ["\n" + "=" * 80, "π CONSTRUCTED MESSAGES:", "=" * 80]
        for msg in messages:
            lines.append(f"\nRole: {msg['role']}\nContent:\n{msg['content']}")
        lines.extend(["\n" + "=" * 80, "End of prompt", "=" * 80 + "\n"])
        return "\n".join(lines)

    def print_template(self, variables: Dict[str, Any]) -> None:
        """Render the template and print it as formatted chat messages."""
        # get_printable_messages expects message dicts; passing the rendered string
        # would iterate its characters and fail on msg['role'].
        messages = self.construct_messages(variables)
        print(self.get_printable_messages(messages))
def get_prompt(prompt_name: str):
    """Look up a complete preset prompt that needs no extra template content.

    Raises:
        ValueError: If ``prompt_name`` is not a key of ``COMPLETE_PROMPTS``; the
            message lists the available names.
    """
    if prompt_name in COMPLETE_PROMPTS:
        return COMPLETE_PROMPTS[prompt_name]
    choices = ", ".join(COMPLETE_PROMPTS.keys())
    raise ValueError(f"Unknown complete prompt: '{prompt_name}'. Available options: {choices}")
def get_partial_prompt(prompt_name: str, template_str: str) -> PromptManager:
    """Wrap ``template_str`` with the header/footer of a named partial prompt.

    Missing "header"/"footer" entries default to empty strings.

    Raises:
        ValueError: If ``prompt_name`` is not a key of ``PARTIAL_PROMPTS``; the
            message lists the available names.
    """
    if prompt_name not in PARTIAL_PROMPTS:
        options = ", ".join(PARTIAL_PROMPTS.keys())
        raise ValueError(f"Unknown partial prompt: '{prompt_name}'. Available options: {options}")
    preset = PARTIAL_PROMPTS[prompt_name]
    return PromptManager(template_str, preset.get("header", ""), preset.get("footer", ""))