Spaces:
Paused
Paused
File size: 4,138 Bytes
ad33df7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import pytest
from kotaemon.llms import PromptTemplate
def test_prompt_template_creation():
    """Construction keeps the raw template text and exposes its placeholders."""
    # A template without any replacement field is stored verbatim.
    plain_text = "This is a template"
    assert PromptTemplate(plain_text).template == plain_text

    # A template with named fields also reports exactly those names.
    greeting_text = "Hello, {name}! Today is {day}."
    greeting = PromptTemplate(greeting_text)
    assert greeting.template == greeting_text
    assert greeting.placeholders == {"name", "day"}
def test_prompt_template_creation_invalid_placeholder():
    """An invalid placeholder raises or warns depending on ``ignore_invalid``."""
    bad_template = "Hello, {name}! Today is {0day}."

    # Strict mode: construction must fail with ValueError.
    with pytest.raises(ValueError):
        PromptTemplate(bad_template, ignore_invalid=False)

    # Lenient mode: construction must emit a warning naming the placeholder.
    expected_warning = "Ignore invalid placeholder: 0day."
    with pytest.warns(UserWarning, match=expected_warning):
        PromptTemplate(bad_template, ignore_invalid=True)
def test_prompt_template_addition():
    """Adding two templates joins their texts with a newline."""
    # Plain text on both sides.
    combined = PromptTemplate("Hello, ") + PromptTemplate("world!")
    assert combined.template == "Hello, \nworld!"

    # Replacement fields appear unchanged in the combined text.
    left = PromptTemplate("Hello, {name}!")
    right = PromptTemplate("Today is {day}.")
    assert (left + right).template == "Hello, {name}!\nToday is {day}."
def test_prompt_template_extract_placeholders():
    """The ``placeholders`` attribute holds every field name in the template."""
    template = PromptTemplate("Hello, {name}! Today is {day}.")
    assert template.placeholders == {"name", "day"}
def test_prompt_template_populate():
    """``populate`` substitutes every placeholder with the supplied value."""
    greeting = PromptTemplate("Hello, {name}! Today is {day}.")
    filled = greeting.populate(name="John", day="Monday")
    assert filled == "Hello, John! Today is Monday."
def test_prompt_template_check_missing_kwargs():
    """Omitting a placeholder value raises ``ValueError``.

    Both ``check_missing_kwargs`` and ``populate`` are expected to reject
    the call when no value for ``day`` is supplied.
    """
    greeting = PromptTemplate("Hello, {name}! Today is {day}.")
    incomplete = {"name": "John"}

    with pytest.raises(ValueError):
        greeting.check_missing_kwargs(**incomplete)

    with pytest.raises(ValueError):
        greeting.populate(**incomplete)
def test_prompt_template_check_redundant_kwargs():
    """Supplying a key the template does not use triggers a ``UserWarning``.

    The same warning is expected from ``check_redundant_kwargs``,
    ``partial_populate`` and ``populate``.
    """
    greeting = PromptTemplate("Hello, {name}! Today is {day}.")
    values = {"name": "John", "day": "Monday", "age": "30"}  # "age" is unused
    warning_text = "Keys provided but not in template: age"

    # Exercise the three entry points in the same order, one warning each.
    for method in (
        greeting.check_redundant_kwargs,
        greeting.partial_populate,
        greeting.populate,
    ):
        with pytest.warns(UserWarning, match=warning_text):
            method(**values)
def test_prompt_template_populate_complex_template():
    """``populate`` agrees with ``str.format`` on format specs and conversions."""
    # The template mixes format specs (``.2f``, ``.1%``, ``#.0g``) with an
    # ``!a`` conversion followed by an alignment spec.
    spec_heavy = (
        "a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}"
    )
    values = {"a": 1, "b": "two", "c": 3, "d": 4, "e": "á"}

    assert PromptTemplate(spec_heavy).populate(**values) == spec_heavy.format(**values)
def test_prompt_template_partial_populate():
    """``partial_populate`` fills supplied fields and leaves the rest intact."""
    spec_heavy = (
        "a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}"
    )
    # No value is given for ``c``, so its field must survive verbatim.
    values = {"a": 1, "b": "two", "d": 4, "e": "á"}

    partially_filled = PromptTemplate(spec_heavy).partial_populate(**values)

    assert (
        partially_filled
        == "a = 1.00, b = two, c = {c:.1%}, d = 4., ascii of á = '\\xe1'"
    )
|