John-Jiang's picture
init commit
5301c48
import pytest
from starfish.llm.prompt import PromptManager
from starfish.llm.prompt.prompt_loader import get_partial_prompt, get_prompt
"""Tests for the PromptManager class and related functionality in starfish.llm.prompt.
This test suite covers:
1. Variable identification (required vs optional variables in templates)
2. Template rendering with different variable combinations
3. Error handling for missing or invalid variables
4. Special features like list input detection, num_records handling
5. Utility methods like get_prompt and get_partial_prompt
Note: Although the PromptManager appends a MANDATE_INSTRUCTION that contains
schema_instruction and other variables, it does not treat schema_instruction
as a required variable. The test cases follow the actual implementation behavior
which identifies only variables used outside of conditional blocks in the
original user template as "required".
"""
# Utility functions for test setup
def get_expected_mandate_vars(required=False):
"""Get the list of variables added by MANDATE_INSTRUCTION.
Args:
required: If True, return only variables that would be classified as required.
If False (default), return only variables that would be classified as optional.
Returns:
Set of variable names.
"""
all_mandate_vars = {"is_list_input", "list_input_variable", "input_list_length", "schema_instruction", "num_records"}
# In practice, none of the mandate variables are treated as required
if required:
return set()
else:
return all_mandate_vars
# Fixtures
@pytest.fixture
def basic_template():
"""Simple template with basic variables for testing."""
return "Hello, {{ name }}! Your age is {{ age }}."
@pytest.fixture
def simple_prompt_manager(basic_template):
"""Basic PromptManager instance with a simple template."""
return PromptManager(basic_template)
@pytest.fixture
def standard_variables():
"""Common variables used in multiple tests."""
return {"name": "Alice", "age": 30, "schema_instruction": "Test schema"}
class TestPromptManager:
"""Test cases for the PromptManager class."""
# ---------------------------------------------------------------------------
# Tests for variable identification
# ---------------------------------------------------------------------------
def test_basic_required_variables(self, simple_prompt_manager, standard_variables):
"""Test identifying required variables in a basic template."""
# Note: MANDATE_INSTRUCTION is automatically appended by PromptManager
manager = simple_prompt_manager
# Basic variables from template
expected_template_required = {"name", "age"}
expected_template_optional = set()
# Get expected variables including those from MANDATE_INSTRUCTION
expected_required = expected_template_required.union(get_expected_mandate_vars(required=True))
expected_optional = expected_template_optional.union(get_expected_mandate_vars(required=False))
expected_all = expected_required.union(expected_optional)
# Check variable identification
assert set(manager.get_all_variables()) == expected_all
assert set(manager.get_required_variables()) == expected_required
assert set(manager.get_optional_variables()) == expected_optional
# Test rendering - requires schema_instruction now
result = manager.render_template(standard_variables)
assert "Hello, Alice! Your age is 30." in result
# Check that part of the non-list mandate instruction is present
assert "You are asked to generate exactly 1 records" in result
assert "Test schema" in result
@pytest.mark.parametrize(
"template,template_required,template_optional,test_name",
[
# Basic conditional test
(
"""
Hello, {{ name }}!
{% if show_age %}
Your age is {{ age }}.
{% endif %}
""",
{"name"},
{"show_age", "age"},
"basic_conditional",
),
# Nested conditional test
(
"""
Hello, {{ name }}!
{% if show_details %}
{% if show_age %}
Your age is {{ age }}.
{% endif %}
{% if show_location %}
Your location is {{ location }}.
{% endif %}
{% endif %}
""",
{"name"},
{"show_details", "show_age", "age", "show_location", "location"},
"nested_conditional",
),
# Mixed conditional test
(
"""
Hello, {{ name }}!
{% if show_details %}
Your details: {{ details }}
{% endif %}
Always show: {{ details }}
""",
{"name", "details"},
{"show_details"},
"mixed_variables",
),
# Conditional in conditional test
(
"""
Hello, {{ name }}!
{% if show_details %}
{% if details %}
Your details: {{ details }}
{% endif %}
{% endif %}
""",
{"name"},
{"show_details", "details"},
"conditional_in_conditional",
),
],
)
def test_conditional_variable_analysis(self, template, template_required, template_optional, test_name):
"""Test variable analysis with various conditional structures."""
manager = PromptManager(template)
# Get the variables
all_vars = set(manager.get_all_variables())
required_vars = set(manager.get_required_variables())
optional_vars = set(manager.get_optional_variables())
# Get expected variables including those from MANDATE_INSTRUCTION
expected_required = template_required.union(get_expected_mandate_vars(required=True))
expected_optional = template_optional.union(get_expected_mandate_vars(required=False))
expected_all = expected_required.union(expected_optional)
# Check variable identification
assert all_vars == expected_all, f"Failed for {test_name}: all variables"
assert required_vars == expected_required, f"Failed for {test_name}: required variables"
assert optional_vars == expected_optional, f"Failed for {test_name}: optional variables"
# Basic rendering test
if test_name == "basic_conditional":
# Test showing age
result = manager.render_template({"name": "Bob", "show_age": True, "age": 25, "schema_instruction": "Age schema"}).strip()
assert "Hello, Bob!" in result
assert "Your age is 25." in result
# Test hiding age
result = manager.render_template({"name": "Charlie", "show_age": False, "schema_instruction": "No age schema"}).strip()
assert "Hello, Charlie!" in result
assert "Your age is" not in result
def test_complex_templates(self):
"""Test more complex template structures."""
template = """
{% if condition1 %}
{{ var1 }} is shown in condition1
{% if nested_condition %}
{{ var2 }} is in nested condition
{{ var3 }} is also in nested condition
{% endif %}
{% elif condition2 %}
{{ var1 }} is shown in condition2
{{ var4 }} is only in condition2
{% else %}
{{ var1 }} is shown in else block
{{ var5 }} is only in else block
{% endif %}
{{ var1 }} appears outside all conditions
{% if standalone_condition %}
{{ var6 }} is in a different conditional
{% endif %}
"""
manager = PromptManager(template)
# Template variables
template_required = {"var1"}
template_optional = {"condition1", "nested_condition", "var2", "var3", "condition2", "var4", "var5", "standalone_condition", "var6"}
# Get expected variables including those from MANDATE_INSTRUCTION
expected_req = template_required.union(get_expected_mandate_vars(required=True))
expected_opt = template_optional.union(get_expected_mandate_vars(required=False))
expected_all = expected_req.union(expected_opt)
# Check variable identification
all_vars = set(manager.get_all_variables())
req_vars = set(manager.get_required_variables())
opt_vars = set(manager.get_optional_variables())
assert all_vars == expected_all
assert req_vars == expected_req
assert opt_vars == expected_opt
# Add a basic render test for completeness
result = manager.render_template(
{"var1": "Value1", "schema_instruction": "Complex Schema", "condition1": False, "condition2": False, "standalone_condition": False}
).strip()
assert "Value1 is shown in else block" in result
assert "Value1 appears outside all conditions" in result
assert "is in a different conditional" not in result
assert "You are asked to generate exactly 1 records" in result
assert "Complex Schema" in result
# ---------------------------------------------------------------------------
# Tests for error handling
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"missing_var,variables,error_message",
[
# Missing required variables
("color", {"name": "Helen", "schema_instruction": "Schema"}, "Required variable 'color' is missing"),
("name", {"color": "Blue", "schema_instruction": "Schema"}, "Required variable 'name' is missing"),
# None values for required variables
("color", {"name": "Ivan", "color": None, "schema_instruction": "Schema"}, "Required variable 'color' cannot be None"),
("name", {"name": None, "color": "Green", "schema_instruction": "Schema"}, "Required variable 'name' cannot be None"),
],
)
def test_error_handling(self, missing_var, variables, error_message):
"""Test error handling for missing or None required variables."""
template = "Hello, {{ name }}! Your favorite color is {{ color }}."
manager = PromptManager(template)
with pytest.raises(ValueError) as exc_info:
manager.render_template(variables)
assert error_message in str(exc_info.value)
# ---------------------------------------------------------------------------
# Tests for template rendering
# ---------------------------------------------------------------------------
def test_list_input_rendering(self):
"""Test rendering when a list is provided as input."""
template = "Processing items: {{ items_to_process }}"
manager = PromptManager(template)
items = ["apple", "banana"]
result = manager.render_template({"items_to_process": items, "schema_instruction": "List schema"})
# Check original template part
# Note: The list itself is replaced by a reference and JSON dump
assert 'Processing items: |items_to_process| : ["apple", "banana"]' in result
# Check mandate instruction part for lists
assert "You are provided with a list named |items_to_process|" in result
assert "contains exactly 2 elements." in result
assert "Generate and return a JSON array containing exactly 2 results" in result
assert "List schema" in result
# Check mandate instruction part for non-lists is NOT present
assert "You are asked to generate exactly" not in result
def test_num_records_rendering(self):
"""Test default and custom num_records rendering."""
template = "Generate data for {{ topic }}."
manager = PromptManager(template)
# Test default num_records = 1
result_default = manager.render_template({"topic": "Weather", "schema_instruction": "Weather schema"})
assert "You are asked to generate exactly 1 records" in result_default
# Test custom num_records = 5
result_custom = manager.render_template({"topic": "Cities", "schema_instruction": "City schema", "num_records": 5})
assert "You are asked to generate exactly 5 records" in result_custom
def test_header_footer_rendering(self):
"""Test rendering with header and footer."""
template = "This is the main content."
header = "Header Info"
footer = "Footer Info"
manager = PromptManager(template, header=header, footer=footer)
result = manager.render_template({"schema_instruction": "Header/Footer schema"})
# Check that header, template, footer are all present
assert header in result
assert template in result
assert footer in result
# Ensure correct order - header comes before template, template before footer
assert result.index(header) < result.index(template)
assert result.index(template) < result.index(footer)
# Verify schema instruction is included
assert "Header/Footer schema" in result
# ---------------------------------------------------------------------------
# Tests for utility methods
# ---------------------------------------------------------------------------
def test_from_string_constructor(self):
"""Test the from_string class method."""
template = "Test template: {{ var }}"
manager = PromptManager.from_string(template)
assert isinstance(manager, PromptManager)
assert set(manager.get_required_variables()) == {"var"} # schema_instruction is not treated as required
def test_construct_messages(self):
"""Test the construct_messages method format."""
template = "User query: {{ query }}"
manager = PromptManager(template)
variables = {"query": "How does this work?", "schema_instruction": "Query schema"}
messages = manager.construct_messages(variables)
assert isinstance(messages, list)
assert len(messages) == 1
assert isinstance(messages[0], dict)
assert messages[0]["role"] == "user"
assert "User query: How does this work?" in messages[0]["content"]
assert "Query schema" in messages[0]["content"] # Mandate part
def test_get_printable_messages(self):
"""Test the get_printable_messages formatting."""
manager = PromptManager("") # Empty template, just mandate
messages = [{"role": "user", "content": "Test content line 1\nTest content line 2"}, {"role": "assistant", "content": "Assistant response"}]
formatted_string = manager.get_printable_messages(messages)
assert "========" in formatted_string
assert "CONSTRUCTED MESSAGES:" in formatted_string
assert "Role: user" in formatted_string
assert "Content:\nTest content line 1\nTest content line 2" in formatted_string
assert "Role: assistant" in formatted_string
assert "Content:\nAssistant response" in formatted_string
assert "End of prompt" in formatted_string
# Test the utility functions outside the class
def test_get_prompt():
"""Test the get_prompt utility function."""
from starfish.llm.prompt.prompt_template import COMPLETE_PROMPTS
# Get a key from COMPLETE_PROMPTS
prompt_name = next(iter(COMPLETE_PROMPTS.keys()))
# Test retrieving a valid prompt
prompt = get_prompt(prompt_name)
assert prompt is not None
# Test cache works (call again)
prompt_again = get_prompt(prompt_name)
assert prompt is prompt_again # Should be the same object (cached)
# Test invalid prompt name
with pytest.raises(ValueError) as exc_info:
get_prompt("nonexistent_prompt_name")
assert "Unknown complete prompt" in str(exc_info.value)
assert prompt_name in str(exc_info.value) # Should list available options
def test_get_partial_prompt():
"""Test the get_partial_prompt utility function."""
from starfish.llm.prompt.prompt_template import PARTIAL_PROMPTS
# Get a key from PARTIAL_PROMPTS
prompt_name = next(iter(PARTIAL_PROMPTS.keys()))
# Test retrieving a valid partial prompt
template_str = "Custom template: {{ var }}"
prompt_manager = get_partial_prompt(prompt_name, template_str)
assert isinstance(prompt_manager, PromptManager)
assert "var" in prompt_manager.get_all_variables()
# Test invalid prompt name
with pytest.raises(ValueError) as exc_info:
get_partial_prompt("nonexistent_prompt_name", template_str)
assert "Unknown partial prompt" in str(exc_info.value)
assert prompt_name in str(exc_info.value) # Should list available options