Spaces:
Running
Running
import pytest | |
from starfish.llm.prompt import PromptManager | |
from starfish.llm.prompt.prompt_loader import get_partial_prompt, get_prompt | |
"""Tests for the PromptManager class and related functionality in starfish.llm.prompt. | |
This test suite covers: | |
1. Variable identification (required vs optional variables in templates) | |
2. Template rendering with different variable combinations | |
3. Error handling for missing or invalid variables | |
4. Special features like list input detection, num_records handling | |
5. Utility methods like get_prompt and get_partial_prompt | |
Note: Although the PromptManager appends a MANDATE_INSTRUCTION that contains | |
schema_instruction and other variables, it does not treat schema_instruction | |
as a required variable. The test cases follow the actual implementation behavior | |
which identifies only variables used outside of conditional blocks in the | |
original user template as "required". | |
""" | |
# Utility functions for test setup | |
def get_expected_mandate_vars(required=False):
    """Return the variable names that MANDATE_INSTRUCTION contributes to a template.

    PromptManager appends MANDATE_INSTRUCTION to the user template, and none of
    the variables it introduces are classified as required.

    Args:
        required: When True, return the mandate variables expected to be
            classified as required (always empty). When False (default),
            return the mandate variables expected to be classified as optional.

    Returns:
        A freshly built set of variable names (safe for callers to mutate).
    """
    if required:
        # In practice, none of the mandate variables are treated as required.
        return set()
    return {"is_list_input", "list_input_variable", "input_list_length", "schema_instruction", "num_records"}
# Fixtures | |
@pytest.fixture
def basic_template():
    """Return a simple template that uses the variables ``name`` and ``age``."""
    return "Hello, {{ name }}! Your age is {{ age }}."


@pytest.fixture
def simple_prompt_manager(basic_template):
    """Return a PromptManager built from the ``basic_template`` fixture."""
    return PromptManager(basic_template)


@pytest.fixture
def standard_variables():
    """Return variables satisfying ``basic_template``, plus a schema instruction."""
    return {"name": "Alice", "age": 30, "schema_instruction": "Test schema"}
class TestPromptManager:
    """Test cases for the PromptManager class."""

    # ---------------------------------------------------------------------------
    # Tests for variable identification
    # ---------------------------------------------------------------------------

    def test_basic_required_variables(self, simple_prompt_manager, standard_variables):
        """Test identifying required variables in a basic template."""
        # Note: MANDATE_INSTRUCTION is automatically appended by PromptManager
        manager = simple_prompt_manager
        # Basic variables from template
        expected_template_required = {"name", "age"}
        expected_template_optional = set()
        # Get expected variables including those from MANDATE_INSTRUCTION
        expected_required = expected_template_required.union(get_expected_mandate_vars(required=True))
        expected_optional = expected_template_optional.union(get_expected_mandate_vars(required=False))
        expected_all = expected_required.union(expected_optional)
        # Check variable identification
        assert set(manager.get_all_variables()) == expected_all
        assert set(manager.get_required_variables()) == expected_required
        assert set(manager.get_optional_variables()) == expected_optional
        # Test rendering. schema_instruction is optional (see the assertions above),
        # but it is supplied so its value can be checked in the mandate instruction.
        result = manager.render_template(standard_variables)
        assert "Hello, Alice! Your age is 30." in result
        # Check that part of the non-list mandate instruction is present
        assert "You are asked to generate exactly 1 records" in result
        assert "Test schema" in result

    # Each case: (template, required template vars, optional template vars, case name).
    # Condition variables and variables used only inside a branch are optional;
    # variables also used outside every conditional are required (the same
    # classification test_complex_templates asserts).
    @pytest.mark.parametrize(
        "template, template_required, template_optional, test_name",
        [
            (
                "Hello, {{ name }}!{% if show_age %} Your age is {{ age }}.{% endif %}",
                {"name"},
                {"show_age", "age"},
                "basic_conditional",
            ),
            (
                "{% if flag %}{{ a }}{% else %}{{ b }}{% endif %}",
                set(),
                {"flag", "a", "b"},
                "if_else_only",
            ),
            (
                "{{ shared }} {% if flag %}{{ shared }} {{ extra }}{% endif %}",
                {"shared"},
                {"flag", "extra"},
                "variable_inside_and_outside",
            ),
        ],
    )
    def test_conditional_variable_analysis(self, template, template_required, template_optional, test_name):
        """Test variable analysis with various conditional structures."""
        manager = PromptManager(template)
        # Get the variables
        all_vars = set(manager.get_all_variables())
        required_vars = set(manager.get_required_variables())
        optional_vars = set(manager.get_optional_variables())
        # Get expected variables including those from MANDATE_INSTRUCTION
        expected_required = template_required.union(get_expected_mandate_vars(required=True))
        expected_optional = template_optional.union(get_expected_mandate_vars(required=False))
        expected_all = expected_required.union(expected_optional)
        # Check variable identification
        assert all_vars == expected_all, f"Failed for {test_name}: all variables"
        assert required_vars == expected_required, f"Failed for {test_name}: required variables"
        assert optional_vars == expected_optional, f"Failed for {test_name}: optional variables"
        # Basic rendering test
        if test_name == "basic_conditional":
            # Test showing age
            result = manager.render_template({"name": "Bob", "show_age": True, "age": 25, "schema_instruction": "Age schema"}).strip()
            assert "Hello, Bob!" in result
            assert "Your age is 25." in result
            # Test hiding age; age is optional, so it may be omitted entirely
            result = manager.render_template({"name": "Charlie", "show_age": False, "schema_instruction": "No age schema"}).strip()
            assert "Hello, Charlie!" in result
            assert "Your age is" not in result

    def test_complex_templates(self):
        """Test more complex template structures."""
        template = """
        {% if condition1 %}
        {{ var1 }} is shown in condition1
        {% if nested_condition %}
        {{ var2 }} is in nested condition
        {{ var3 }} is also in nested condition
        {% endif %}
        {% elif condition2 %}
        {{ var1 }} is shown in condition2
        {{ var4 }} is only in condition2
        {% else %}
        {{ var1 }} is shown in else block
        {{ var5 }} is only in else block
        {% endif %}
        {{ var1 }} appears outside all conditions
        {% if standalone_condition %}
        {{ var6 }} is in a different conditional
        {% endif %}
        """
        manager = PromptManager(template)
        # Template variables
        template_required = {"var1"}
        template_optional = {"condition1", "nested_condition", "var2", "var3", "condition2", "var4", "var5", "standalone_condition", "var6"}
        # Get expected variables including those from MANDATE_INSTRUCTION
        expected_req = template_required.union(get_expected_mandate_vars(required=True))
        expected_opt = template_optional.union(get_expected_mandate_vars(required=False))
        expected_all = expected_req.union(expected_opt)
        # Check variable identification
        all_vars = set(manager.get_all_variables())
        req_vars = set(manager.get_required_variables())
        opt_vars = set(manager.get_optional_variables())
        assert all_vars == expected_all
        assert req_vars == expected_req
        assert opt_vars == expected_opt
        # Add a basic render test for completeness
        result = manager.render_template(
            {"var1": "Value1", "schema_instruction": "Complex Schema", "condition1": False, "condition2": False, "standalone_condition": False}
        ).strip()
        assert "Value1 is shown in else block" in result
        assert "Value1 appears outside all conditions" in result
        assert "is in a different conditional" not in result
        assert "You are asked to generate exactly 1 records" in result
        assert "Complex Schema" in result

    # ---------------------------------------------------------------------------
    # Tests for error handling
    # ---------------------------------------------------------------------------

    # Each case: (offending variable, variables passed to render, expected message fragment).
    # NOTE(review): the expected fragment is the offending variable's name, which assumes
    # PromptManager's ValueError message names it — confirm against render_template.
    @pytest.mark.parametrize(
        "missing_var, variables, error_message",
        [
            ("color", {"name": "Alice"}, "color"),
            ("name", {"color": "blue"}, "name"),
            ("color", {"name": "Alice", "color": None}, "color"),
        ],
    )
    def test_error_handling(self, missing_var, variables, error_message):
        """Test error handling for missing or None required variables."""
        template = "Hello, {{ name }}! Your favorite color is {{ color }}."
        manager = PromptManager(template)
        with pytest.raises(ValueError) as exc_info:
            manager.render_template(variables)
        assert error_message in str(exc_info.value), f"Expected error for {missing_var!r}"

    # ---------------------------------------------------------------------------
    # Tests for template rendering
    # ---------------------------------------------------------------------------

    def test_list_input_rendering(self):
        """Test rendering when a list is provided as input."""
        template = "Processing items: {{ items_to_process }}"
        manager = PromptManager(template)
        items = ["apple", "banana"]
        result = manager.render_template({"items_to_process": items, "schema_instruction": "List schema"})
        # Check original template part
        # Note: The list itself is replaced by a reference and JSON dump
        assert 'Processing items: |items_to_process| : ["apple", "banana"]' in result
        # Check mandate instruction part for lists
        assert "You are provided with a list named |items_to_process|" in result
        assert "contains exactly 2 elements." in result
        assert "Generate and return a JSON array containing exactly 2 results" in result
        assert "List schema" in result
        # Check mandate instruction part for non-lists is NOT present
        assert "You are asked to generate exactly" not in result

    def test_num_records_rendering(self):
        """Test default and custom num_records rendering."""
        template = "Generate data for {{ topic }}."
        manager = PromptManager(template)
        # Test default num_records = 1
        result_default = manager.render_template({"topic": "Weather", "schema_instruction": "Weather schema"})
        assert "You are asked to generate exactly 1 records" in result_default
        # Test custom num_records = 5
        result_custom = manager.render_template({"topic": "Cities", "schema_instruction": "City schema", "num_records": 5})
        assert "You are asked to generate exactly 5 records" in result_custom

    def test_header_footer_rendering(self):
        """Test rendering with header and footer."""
        template = "This is the main content."
        header = "Header Info"
        footer = "Footer Info"
        manager = PromptManager(template, header=header, footer=footer)
        result = manager.render_template({"schema_instruction": "Header/Footer schema"})
        # Check that header, template, footer are all present
        assert header in result
        assert template in result
        assert footer in result
        # Ensure correct order - header comes before template, template before footer
        assert result.index(header) < result.index(template)
        assert result.index(template) < result.index(footer)
        # Verify schema instruction is included
        assert "Header/Footer schema" in result

    # ---------------------------------------------------------------------------
    # Tests for utility methods
    # ---------------------------------------------------------------------------

    def test_from_string_constructor(self):
        """Test the from_string class method."""
        template = "Test template: {{ var }}"
        manager = PromptManager.from_string(template)
        assert isinstance(manager, PromptManager)
        assert set(manager.get_required_variables()) == {"var"}  # schema_instruction is not treated as required

    def test_construct_messages(self):
        """Test the construct_messages method format."""
        template = "User query: {{ query }}"
        manager = PromptManager(template)
        variables = {"query": "How does this work?", "schema_instruction": "Query schema"}
        messages = manager.construct_messages(variables)
        assert isinstance(messages, list)
        assert len(messages) == 1
        assert isinstance(messages[0], dict)
        assert messages[0]["role"] == "user"
        assert "User query: How does this work?" in messages[0]["content"]
        assert "Query schema" in messages[0]["content"]  # Mandate part

    def test_get_printable_messages(self):
        """Test the get_printable_messages formatting."""
        manager = PromptManager("")  # Empty template, just mandate
        messages = [{"role": "user", "content": "Test content line 1\nTest content line 2"}, {"role": "assistant", "content": "Assistant response"}]
        formatted_string = manager.get_printable_messages(messages)
        assert "========" in formatted_string
        assert "CONSTRUCTED MESSAGES:" in formatted_string
        assert "Role: user" in formatted_string
        assert "Content:\nTest content line 1\nTest content line 2" in formatted_string
        assert "Role: assistant" in formatted_string
        assert "Content:\nAssistant response" in formatted_string
        assert "End of prompt" in formatted_string
# Test the utility functions outside the class | |
def test_get_prompt():
    """Check that get_prompt returns a cached prompt and rejects unknown names."""
    from starfish.llm.prompt.prompt_template import COMPLETE_PROMPTS

    # Any registered complete prompt will do; use the first key.
    known_name = next(iter(COMPLETE_PROMPTS))

    first_lookup = get_prompt(known_name)
    assert first_lookup is not None

    # A repeated lookup must hand back the very same (cached) object.
    second_lookup = get_prompt(known_name)
    assert first_lookup is second_lookup

    # An unknown name raises, and the message lists the available options.
    with pytest.raises(ValueError) as raised:
        get_prompt("nonexistent_prompt_name")
    message = str(raised.value)
    assert "Unknown complete prompt" in message
    assert known_name in message
def test_get_partial_prompt():
    """Check that get_partial_prompt builds a PromptManager and rejects unknown names."""
    from starfish.llm.prompt.prompt_template import PARTIAL_PROMPTS

    # Any registered partial prompt will do; use the first key.
    known_name = next(iter(PARTIAL_PROMPTS))
    user_template = "Custom template: {{ var }}"

    # A valid name yields a PromptManager that still sees the user's variable.
    built = get_partial_prompt(known_name, user_template)
    assert isinstance(built, PromptManager)
    assert "var" in built.get_all_variables()

    # An unknown name raises, and the message lists the available options.
    with pytest.raises(ValueError) as raised:
        get_partial_prompt("nonexistent_prompt_name", user_template)
    message = str(raised.value)
    assert "Unknown partial prompt" in message
    assert known_name in message