John-Jiang's picture
init commit
5301c48
import functools
import json
import re
from typing import Any, Dict, List, Set, Tuple
from jinja2 import Environment, StrictUndefined, meta, nodes
from starfish.llm.prompt.prompt_template import COMPLETE_PROMPTS, PARTIAL_PROMPTS
class PromptManager:
    """Manage a Jinja2 prompt template: variable analysis and rendering.

    The full template is built from an optional header, the user template
    (single-brace ``{var}`` placeholders are converted to Jinja2 syntax), an
    optional footer and the class-level ``MANDATE_INSTRUCTION`` suffix.

    Attributes set by ``__init__``:
        template_full: The complete template source that gets rendered.
        all_vars: Every name the template looks up from the render context.
        required_vars: Names emitted at the top level of the template
            (outside any ``{% if %}`` block); they must be supplied.
        optional_vars: All remaining names; they default to ``None``.
    """

    # NOTE: the text lines below start at column 0 on purpose - they are part
    # of the runtime string and must not pick up the class indentation.
    MANDATE_INSTRUCTION = """
{% if is_list_input %}
Additional Instructions:
You are provided with a list named |{{ list_input_variable }}| that contains exactly {{ input_list_length }} elements.
Processing:
1. Process each element according to the provided instructions.
2. Generate and return a JSON array containing exactly {{ input_list_length }} results, preserving the original order.
3. Your output must strictly adhere to the following JSON schema:
{{ schema_instruction }}
{% else %}
You are asked to generate exactly {{ num_records }} records and please return the data in the following JSON format:
{{ schema_instruction }}
{% endif %}
"""

    # Jinja2 statement/comment markers; their presence means the template is
    # already Jinja2 and must not be rewritten.
    _JINJA_BLOCK_RE = re.compile(r"{%.*?%}|{#.*?#}")

    # Content accepted inside single braces for conversion: names, attribute
    # access, indexing, calls and basic operators - but nothing that looks
    # like a JSON object (no quotes, colons, commas or nested braces).
    _SIMPLE_EXPRESSION_RE = re.compile(r"^[a-zA-Z0-9_\.\[\]\(\)\+\-\*\/\|\s\%\<\>\=\!\&]+$")

    def __init__(self, template_str: str, header: str = "", footer: str = ""):
        """Build the full template and analyze its variables once.

        Args:
            template_str: User template, in Jinja2 or single-brace style.
            header: Text placed before the template.
            footer: Text placed after the template.
        """
        # Convert f-string style placeholders to Jinja2 if needed
        template_str = self._convert_to_jinja_format(template_str)
        self.template_full = f"{header}\n{template_str}\n{footer}".strip() + f"\n{self.MANDATE_INSTRUCTION}"
        self._env = Environment(undefined=StrictUndefined)
        self._template = self._env.from_string(self.template_full)
        self._ast = self._env.parse(self.template_full)
        # Analyze variables immediately to avoid repeated processing
        self.all_vars, self.required_vars, self.optional_vars = self._analyze_variables()

    @staticmethod
    def _convert_to_jinja_format(template_str: str) -> str:
        """Convert single-brace ``{var}`` placeholders to Jinja2 ``{{ var }}``.

        The input is returned unchanged when it is blank, when it already
        contains Jinja2 statements/comments (``{% %}``, ``{# #}``), or when it
        contains both ``{{`` and ``}}``. Otherwise each balanced ``{...}``
        group is inspected and converted only if its content looks like a
        plain variable or simple expression. The following are preserved:

        1. Braces immediately followed by a quote or another brace
        2. Groups containing quotes (likely JSON/dict literals)
        3. Empty, unbalanced or otherwise unrecognized brace groups

        Known limitations:
            - Expressions with string literals inside braces are not converted
            - Triple braces (``{{{var}}}``) will cause Jinja syntax errors
            - Whitespace around the expression is normalized to one space

        Args:
            template_str: Template text in either style.

        Returns:
            str: A string formatted for Jinja2.
        """
        if not template_str.strip():
            return template_str
        if PromptManager._JINJA_BLOCK_RE.search(template_str):
            return template_str
        if "{{" in template_str and "}}" in template_str:
            return template_str

        length = len(template_str)
        result: List[str] = []
        i = 0
        while i < length:
            char = template_str[i]

            # Ordinary character, or a "{" that is the very last character
            if char != "{" or i + 1 >= length:
                result.append(char)
                i += 1
                continue

            # "{" followed by a quote looks like a JSON/dict literal and "{{"
            # looks like an escaped brace: emit the brace literally.
            if template_str[i + 1] in "\"'{":
                result.append(char)
                i += 1
                continue

            # Scan forward for the matching closing brace
            j = i + 1
            brace_depth = 1
            has_quotes = False
            while j < length and brace_depth > 0:
                current = template_str[j]
                # j >= 1 here, so looking one character back is always safe
                if current in "\"'" and template_str[j - 1] != "\\":
                    has_quotes = True
                if current == "{":
                    brace_depth += 1
                elif current == "}":
                    brace_depth -= 1
                j += 1

            # Unbalanced brace or empty "{}": treat the "{" as plain text
            if brace_depth != 0 or j - i <= 2:
                result.append(char)
                i += 1
                continue

            var_content = template_str[i + 1 : j - 1].strip()
            convertible = (
                bool(var_content)
                and not has_quotes
                and PromptManager._SIMPLE_EXPRESSION_RE.match(var_content) is not None
            )
            if convertible:
                result.append("{{ " + var_content + " }}")
            else:
                # Whitespace-only, quoted or otherwise doubtful content is
                # preserved exactly as written.
                result.append(template_str[i:j])
            i = j

        return "".join(result)

    @classmethod
    def from_string(cls, template_str: str, header: str = "", footer: str = "") -> "PromptManager":
        """Create a manager from a template string (alternate constructor)."""
        return cls(template_str, header, footer)

    def get_all_variables(self) -> List[str]:
        """Return all template variables, sorted for a stable order."""
        return sorted(self.all_vars)

    def get_prompt(self) -> str:
        """Return the full, unrendered template source."""
        return self.template_full

    def _analyze_variables(self) -> Tuple[Set[str], Set[str], Set[str]]:
        """Classify the template's variables as required or optional.

        A variable is required when it is emitted by an output expression
        that is not nested inside an ``{% if %}`` block. Every other variable
        the template reads from the render context is optional.

        Returns:
            Tuple containing (all_vars, required_vars, optional_vars)
        """
        root_vars: Set[str] = set()

        def extract_variables_from_node(node, result):
            """Add the base variable name(s) referenced by *node* to *result*."""
            if isinstance(node, nodes.Name):
                result.add(node.name)
            elif isinstance(node, nodes.Getattr) and isinstance(node.node, nodes.Name):
                result.add(node.node.name)
            elif isinstance(node, nodes.Filter):
                if isinstance(node.node, nodes.Name):
                    result.add(node.node.name)
            elif hasattr(node, "node"):
                extract_variables_from_node(node.node, result)

        def visit_node(node):
            """Collect variables emitted outside of conditional blocks."""
            if isinstance(node, nodes.If):
                # Nothing beneath a conditional can be required.
                return
            if isinstance(node, nodes.Output):
                for expr in node.nodes:
                    if isinstance(expr, nodes.TemplateData):
                        continue  # plain text
                    extract_variables_from_node(expr, root_vars)
            for child in getattr(node, "body", ()):
                visit_node(child)

        for node in self._ast.body:
            visit_node(node)

        # Names the template resolves from the render context
        all_vars = meta.find_undeclared_variables(self._ast)
        # Intersect so that template-local names (for-loop targets, names
        # assigned with {% set %}) are never demanded from the caller.
        required_vars = root_vars & all_vars
        optional_vars = all_vars - required_vars
        return all_vars, required_vars, optional_vars

    def get_optional_variables(self) -> List[str]:
        """Return the optional variables, sorted for a stable order."""
        return sorted(self.optional_vars)

    def get_required_variables(self) -> List[str]:
        """Return the required variables, sorted for a stable order."""
        return sorted(self.required_vars)

    def render_template(self, variables: Dict[str, Any]) -> str:
        """Render the template with *variables*.

        The first list-valued variable (required names first, then optional,
        each group in alphabetical order) is treated as the list input: it is
        replaced by ``|name| : <json>`` and switches the mandate instruction
        to list-processing mode. Missing optional variables default to
        ``None`` and ``num_records`` defaults to 1. *variables* itself is not
        modified.

        Args:
            variables: Mapping of variable name to value.

        Returns:
            The rendered prompt text.

        Raises:
            ValueError: If a required variable is missing or None
        """
        # Sorted so the reported variable is deterministic when several fail
        for var in sorted(self.required_vars):
            if var == "num_records":
                continue  # num_records has default value of 1, so it is not required
            if var not in variables:
                raise ValueError(f"Required variable '{var}' is missing")
            if variables[var] is None:
                raise ValueError(f"Required variable '{var}' cannot be None")

        # Work on a copy to avoid modifying the caller's mapping
        render_vars = dict(variables)

        is_list_input = False
        list_input_variable = None
        input_list_length = None
        # Sets have no stable order; sorting keeps the chosen list input (and
        # therefore the prompt) reproducible when several lists are supplied.
        for var in sorted(self.required_vars) + sorted(self.optional_vars):
            value = render_vars.get(var)
            if isinstance(value, list):
                is_list_input = True
                list_input_variable = var
                input_list_length = len(value)
                # Prefix with the variable name so the prompt can refer to it
                render_vars[var] = f"|{var}| : {json.dumps(value)}"
                break  # only the first list is used

        # Default missing optional variables to None (num_records is handled below)
        for var in self.optional_vars:
            if var != "num_records":
                render_vars.setdefault(var, None)

        # List processing variables always reflect what was detected above
        render_vars["is_list_input"] = is_list_input
        render_vars["list_input_variable"] = list_input_variable
        render_vars["input_list_length"] = input_list_length
        render_vars.setdefault("num_records", 1)

        return self._template.render(**render_vars)

    def construct_messages(self, variables: Dict[str, Any]) -> List[Dict[str, str]]:
        """Return the rendered template wrapped as a single user message."""
        rendered_template = self.render_template(variables)
        return [{"role": "user", "content": rendered_template}]

    def get_printable_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format constructed messages with visual separators.

        Args:
            messages: Message dicts with ``role`` and ``content`` keys.

        Returns:
            A banner-delimited, human-readable dump of the messages.
        """
        separator = "=" * 80
        lines = ["\n" + separator, "📝 CONSTRUCTED MESSAGES:", separator]
        lines.extend(f"\nRole: {msg['role']}\nContent:\n{msg['content']}" for msg in messages)
        lines.extend(["\n" + separator, "End of prompt", separator + "\n"])
        return "\n".join(lines)

    def print_template(self, variables: Dict[str, Any]) -> None:
        """Render the template and print it as formatted messages."""
        # get_printable_messages expects message dicts, not the raw string
        print(self.get_printable_messages(self.construct_messages(variables)))
@functools.lru_cache(maxsize=32)
def get_prompt(prompt_name: str):
    """Look up a complete preset prompt by name.

    Complete prompts need no additional user template content.

    Args:
        prompt_name: Key into ``COMPLETE_PROMPTS``.

    Returns:
        The entry stored in ``COMPLETE_PROMPTS`` under *prompt_name*.

    Raises:
        ValueError: If *prompt_name* is not a known complete prompt; the
            message lists the available names.
    """
    if prompt_name in COMPLETE_PROMPTS:
        return COMPLETE_PROMPTS[prompt_name]

    known_names = ", ".join(COMPLETE_PROMPTS.keys())
    raise ValueError(f"Unknown complete prompt: '{prompt_name}'. Available options: {known_names}")
@functools.lru_cache(maxsize=32)
def get_partial_prompt(prompt_name: str, template_str: str) -> PromptManager:
    """Wrap user template content in a preset partial prompt.

    Args:
        prompt_name: Key into ``PARTIAL_PROMPTS``.
        template_str: User-provided template placed between the preset's
            header and footer.

    Returns:
        A ``PromptManager`` built from the template and the preset's
        ``header``/``footer`` entries (each defaults to an empty string).

    Raises:
        ValueError: If *prompt_name* is not a known partial prompt; the
            message lists the available names.
    """
    if prompt_name in PARTIAL_PROMPTS:
        preset = PARTIAL_PROMPTS[prompt_name]
        return PromptManager(
            template_str,
            preset.get("header", ""),
            preset.get("footer", ""),
        )

    known_names = ", ".join(PARTIAL_PROMPTS.keys())
    raise ValueError(f"Unknown partial prompt: '{prompt_name}'. Available options: {known_names}")