File size: 15,463 Bytes
5301c48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
import functools
import json
import re
from typing import Any, Dict, List, Set, Tuple

from jinja2 import Environment, StrictUndefined, meta, nodes

from starfish.llm.prompt.prompt_template import COMPLETE_PROMPTS, PARTIAL_PROMPTS


class PromptManager:
    """Manages Jinja2 template processing with variable analysis and rendering.

    The user template (optionally wrapped by a header and a footer) is converted from
    single-brace f-string style to Jinja2 syntax when needed, and the fixed
    ``MANDATE_INSTRUCTION`` block describing the expected JSON output is appended.
    Template variables are classified once, at construction time:

    * required -- printed (``{{ ... }}``) outside every ``{% if %}`` block and not
      bound by the template itself (loop targets, ``{% set %}``, macro arguments);
    * optional -- every other variable the template expects from its caller.
    """

    MANDATE_INSTRUCTION = """
{% if is_list_input %}
Additional Instructions:

You are provided with a list named |{{ list_input_variable }}| that contains exactly {{ input_list_length }} elements.

Processing:

1. Process each element according to the provided instructions.
2. Generate and return a JSON array containing exactly {{ input_list_length }} results, preserving the original order.
3. Your output must strictly adhere to the following JSON schema:
{{ schema_instruction }}

{% else %}
You are asked to generate exactly {{ num_records }} records and please return the data in the following JSON format:
{{ schema_instruction }}
{% endif %}
"""

    def __init__(self, template_str: str, header: str = "", footer: str = ""):
        """Assemble the full template and analyze its variables immediately.

        Args:
            template_str: Prompt body, written in Jinja2 or single-brace f-string style.
            header: Text placed before the body.
            footer: Text placed after the body.

        Raises:
            jinja2.TemplateSyntaxError: Propagated from Jinja2 when the assembled
                template does not parse.
        """
        # Convert f-string format to Jinja2 if needed
        template_str = self._convert_to_jinja_format(template_str)

        self.template_full = f"{header}\n{template_str}\n{footer}".strip() + f"\n{self.MANDATE_INSTRUCTION}"
        self._env = Environment(undefined=StrictUndefined)
        self._template = self._env.from_string(self.template_full)
        self._ast = self._env.parse(self.template_full)

        # Classify variables once so each render does not re-walk the AST.
        self.all_vars, self.required_vars, self.optional_vars = self._analyze_variables()

    @staticmethod
    def _convert_to_jinja_format(template_str: str) -> str:
        """
        Convert Python f-string or string with single braces to Jinja2 template syntax.

        This method safely detects and converts Python-style interpolation with single braces
        to Jinja2 double-brace format, while preserving:

        1. Existing Jinja2 syntax (double braces, control structures, comments)
        2. JSON/dict literals with braces and quotes
        3. Other valid uses of single braces that aren't meant for interpolation

        A brace group is converted only when its stripped content consists solely of
        identifier characters, dots, brackets, parentheses, whitespace and the operators
        ``+ - * / | % < > = ! &``; anything else (quotes, colons, nested braces) is kept.

        Known limitations:
        - Complex expressions with string literals inside braces are left unconverted
        - Triple braces ({{{var}}}) will cause Jinja syntax errors
        - Whitespace inside braces is normalized in the conversion process

        Returns:
            str: A string properly formatted for Jinja2
        """
        if not template_str.strip():
            return template_str

        # Templates already using Jinja2 blocks/comments or {{ }} are returned untouched.
        if re.search(r"{%.*?%}|{#.*?#}", template_str):
            return template_str
        if "{{" in template_str and "}}" in template_str:
            return template_str

        size = len(template_str)
        out = []
        i = 0
        while i < size:
            ch = template_str[i]

            # A "{" that ends the string, or is followed by a quote (JSON/dict literal)
            # or by another "{", is copied through as-is.
            if ch != "{" or i + 1 >= size or template_str[i + 1] in "\"'{":
                out.append(ch)
                i += 1
                continue

            # Find the matching closing brace, honouring nesting.
            j = i + 1
            depth = 1
            while j < size and depth > 0:
                if template_str[j] == "{":
                    depth += 1
                elif template_str[j] == "}":
                    depth -= 1
                j += 1

            # Unbalanced or empty "{}": emit the "{" alone and keep scanning after it.
            if depth != 0 or j - i <= 2:
                out.append(ch)
                i += 1
                continue

            var_content = template_str[i + 1 : j - 1].strip()
            # The character class excludes quotes and braces, so JSON-like objects and
            # expressions containing string literals are preserved verbatim.
            if re.match(r"^[a-zA-Z0-9_\.\[\]\(\)\+\-\*\/\|\s\%\<\>\=\!\&]+$", var_content):
                out.append("{{ " + var_content + " }}")
            else:
                out.append(template_str[i:j])
            i = j

        return "".join(out)

    @classmethod
    def from_string(cls, template_str: str, header: str = "", footer: str = "") -> "PromptManager":
        """Create from template string."""
        return cls(template_str, header, footer)

    def get_all_variables(self) -> List[str]:
        """Return all variables the template expects from its caller."""
        return list(self.all_vars)

    def get_prompt(self) -> str:
        """Return the full, unrendered template text (header, body, footer and mandate)."""
        return self.template_full

    def _analyze_variables(self) -> Tuple[Set[str], Set[str], Set[str]]:
        """Classify the template's variables as required or optional.

        Returns:
            Tuple containing (all_vars, required_vars, optional_vars), where
            all_vars is what Jinja2 reports as undeclared, required_vars are those
            printed outside any ``{% if %}`` block, and optional_vars is the rest.
        """
        root_vars: Set[str] = set()

        def collect_base_name(expr, result: Set[str]) -> None:
            # Follow attribute / item / filter chains down to the object being
            # accessed ({{ a.b.c }}, {{ a[0] }}, {{ a | upper }} all yield "a").
            # Filter arguments and subscripts are intentionally not collected.
            while isinstance(expr, (nodes.Getattr, nodes.Getitem, nodes.Filter)):
                expr = expr.node
            if isinstance(expr, nodes.Name):
                result.add(expr.name)

        def visit(node) -> None:
            if isinstance(node, nodes.If):
                # Nothing under a conditional can be required; it ends up optional.
                return
            if isinstance(node, nodes.Output):
                for expr in node.nodes:
                    if not isinstance(expr, nodes.TemplateData):
                        collect_base_name(expr, root_vars)
            if hasattr(node, "body"):
                for child in node.body:
                    visit(child)

        for top_level in self._ast.body:
            visit(top_level)

        # All variables in the template that must come from the caller
        all_vars = meta.find_undeclared_variables(self._ast)

        # Names bound by the template itself (loop targets, `loop`, {% set %},
        # macro args) are not in all_vars and must never be demanded from callers.
        required_vars = root_vars & all_vars
        optional_vars = all_vars - required_vars

        return all_vars, required_vars, optional_vars

    def get_optional_variables(self) -> List[str]:
        """Return optional variables (expected by the template but not required)."""
        return list(self.optional_vars)

    def get_required_variables(self) -> List[str]:
        """Return required variables (printed outside conditionals)."""
        return list(self.required_vars)

    def render_template(self, variables: Dict[str, Any]) -> str:
        """Render template with variables. Validates required variables exist and aren't None.

        The first list-valued variable (required variables first, then optional ones,
        each group in alphabetical order) is serialized as ``|name| : <json>`` and the
        mandate switches to list-processing instructions. Missing optional variables
        default to None; ``num_records`` defaults to 1.

        Args:
            variables: Values for the template variables; not modified.

        Returns:
            str: The rendered prompt.

        Raises:
            ValueError: If a required variable is missing or None
        """
        # Sorted so the reported variable does not depend on set iteration order.
        for var in sorted(self.required_vars):
            if var == "num_records":
                continue  # num_records has default value of 1, so it is not required
            if var not in variables:
                raise ValueError(f"Required variable '{var}' is missing")
            if variables[var] is None:
                raise ValueError(f"Required variable '{var}' cannot be None")

        # Work on a copy so the caller's dict is left untouched
        render_vars = variables.copy()

        is_list_input = False
        list_input_variable = None
        input_list_length = None

        # Required variables take priority; sorting makes the pick deterministic
        # (iteration order of a set of str changes between interpreter runs).
        for var in sorted(self.required_vars) + sorted(self.optional_vars):
            if var in render_vars and isinstance(render_vars[var], list):
                original_list = render_vars[var]
                is_list_input = True
                list_input_variable = var
                input_list_length = len(original_list)
                # Prefix the serialized list with its name so the mandate can refer to it
                render_vars[var] = f"|{var}| : {json.dumps(original_list)}"
                break  # Stop after finding the first list

        # Add default None for optional variables (except for num_records which gets special treatment)
        for var in self.optional_vars:
            if var != "num_records" and var not in render_vars:
                render_vars[var] = None

        render_vars["is_list_input"] = is_list_input
        render_vars["list_input_variable"] = list_input_variable
        render_vars["input_list_length"] = input_list_length

        # Default num_records to 1 when not specified
        render_vars["num_records"] = render_vars.get("num_records", 1)

        return self._template.render(**render_vars)

    def construct_messages(self, variables: Dict[str, Any]) -> List[Dict[str, str]]:
        """Return the rendered template as a single chat message with role "user"."""
        rendered_template = self.render_template(variables)
        return [{"role": "user", "content": rendered_template}]

    def get_printable_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format constructed messages (dicts with "role" and "content") with visual separators."""
        lines = ["\n" + "=" * 80, "📝 CONSTRUCTED MESSAGES:", "=" * 80]
        for msg in messages:
            lines.append(f"\nRole: {msg['role']}\nContent:\n{msg['content']}")
        lines.extend(["\n" + "=" * 80, "End of prompt", "=" * 80 + "\n"])
        return "\n".join(lines)

    def print_template(self, variables: Dict[str, Any]) -> None:
        """Render the template and print it formatted as constructed messages."""
        # get_printable_messages expects message dicts, not the raw rendered string.
        messages = self.construct_messages(variables)
        print(self.get_printable_messages(messages))


@functools.lru_cache(maxsize=32)
def get_prompt(prompt_name: str):
    """Look up a complete preset prompt template by name.

    Complete prompts need no additional template content from the caller.
    Successful lookups are memoized (up to 32 names) by ``functools.lru_cache``.

    Args:
        prompt_name: Key into ``COMPLETE_PROMPTS``.

    Returns:
        The ``COMPLETE_PROMPTS`` entry stored under ``prompt_name``.

    Raises:
        ValueError: If ``prompt_name`` is not a known complete prompt; the message
            lists the available names.
    """
    if prompt_name in COMPLETE_PROMPTS:
        return COMPLETE_PROMPTS[prompt_name]

    options = ", ".join(COMPLETE_PROMPTS.keys())
    raise ValueError(f"Unknown complete prompt: '{prompt_name}'. Available options: {options}")


@functools.lru_cache(maxsize=32)
def get_partial_prompt(prompt_name: str, template_str: str) -> PromptManager:
    """Wrap user template content with a named partial prompt's header and footer.

    Because of ``functools.lru_cache``, repeated calls with the same arguments
    return the same ``PromptManager`` instance.

    Args:
        prompt_name: Key into ``PARTIAL_PROMPTS``; its entry may provide
            ``"header"`` and ``"footer"`` text (each defaults to "").
        template_str: The caller's template body.

    Returns:
        PromptManager: Manager built from header, ``template_str`` and footer.

    Raises:
        ValueError: If ``prompt_name`` is not a known partial prompt; the message
            lists the available names.
    """
    if prompt_name not in PARTIAL_PROMPTS:
        options = ", ".join(PARTIAL_PROMPTS.keys())
        raise ValueError(f"Unknown partial prompt: '{prompt_name}'. Available options: {options}")

    pieces = PARTIAL_PROMPTS[prompt_name]
    return PromptManager(template_str, pieces.get("header", ""), pieces.get("footer", ""))