Spaces:
Runtime error
Runtime error
File size: 3,297 Bytes
47c93d7 d137f7e 47c93d7 d137f7e 47c93d7 d137f7e 47c93d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import unittest
from config import load_secrets
from agents.agent import Agent
import datetime
# When True, ``test_validation_chain`` is skipped (see its ``skipIf``) and the
# Agent under test is constructed with ``debug=True``.
DEBUG = True


class TestAgentMethods(unittest.TestCase):
    """Tests for ``agents.agent.Agent``.

    ``setUp`` loads API keys via ``config.load_secrets`` and builds a fresh
    ``Agent`` for every test. These tests presumably hit the real model/API
    backends (NOTE(review): confirm), so valid credentials are required.
    """

    # Secrets that must be present and non-None before any test runs.
    REQUIRED_SECRETS = (
        "OPENAI_API_KEY",
        "GOOGLE_MAPS_API_KEY",
        "GOOGLE_PALM_API_KEY",
    )

    def assert_secrets(self, secrets_dict):
        """Fail the test unless every key in ``REQUIRED_SECRETS`` is set.

        Uses ``unittest`` assertions rather than bare ``assert`` so the check
        still runs under ``python -O``. It also reports a readable test failure
        (instead of a ``KeyError``) when a key is missing entirely.

        Args:
            secrets_dict: mapping returned by ``load_secrets()``.
        """
        for key in self.REQUIRED_SECRETS:
            self.assertIn(key, secrets_dict, f"missing secret {key!r}")
            self.assertIsNotNone(secrets_dict[key], f"secret {key!r} is None")

    def setUp(self):
        """Load and validate secrets, then build the Agent under test."""
        self.debug = DEBUG
        secrets = load_secrets()
        self.assert_secrets(secrets)
        # Only the OpenAI key is passed to Agent here; the other secrets are
        # still required to be present by assert_secrets above.
        self.agent = Agent(
            openai_api_key=secrets["OPENAI_API_KEY"],
            debug=self.debug,
        )

    def _run_validation(self, validation_chain, query):
        """Run *query* through *validation_chain* and return its output as a dict."""
        format_instructions = (
            self.agent.validation_prompt.parser.get_format_instructions()
        )
        result = validation_chain(
            {
                "query": query,
                "format_instructions": format_instructions,
            }
        )
        return result["validation_output"].dict()

    @unittest.skipIf(DEBUG, "Skipping this test while debugging other tests")
    def test_validation_chain(self):
        """The validation chain rejects unreasonable trips and accepts a sane one."""
        validation_chain = self.agent._set_up_validation_chain(debug=self.debug)
        cases = (
            # (query, expected "plan_is_valid")
            ("fly to the moon", "no"),  # not a reasonable request
            ("1 day road trip from Chicago to Brazilia", "no"),  # not reasonable
            ("1 week road trip from Chicago to Mexico city", "yes"),  # reasonable
        )
        for query, expected in cases:
            # subTest reports every failing query instead of stopping at the first.
            with self.subTest(query=query):
                output = self._run_validation(validation_chain, query)
                self.assertEqual(output["plan_is_valid"], expected)

    def test_generate_itinerary(self):
        """``generate_itinerary`` returns the expected keys and a non-None suggestion."""
        user_details = {
            "start_location": "Berkeley, CA",
            "end_location": "Seattle, WA",
            "start_date": datetime.date(2023, 12, 10),
            "end_date": datetime.date(2023, 12, 15),
            "attractions": ["museums", "parks"],
            "budget": "1500-3000 USD",
            "transportation": "rental car, public Transport",
            "accommodation": "hotels",
            "schedule": "relaxed",
        }

        itinerary_result = self.agent.generate_itinerary(user_details)

        # Check the result shape explicitly so a missing key shows up as a test
        # failure rather than an uncaught KeyError.
        for key in ("itinerary_suggestion", "list_of_places", "validation_dict"):
            self.assertIn(key, itinerary_result)

        itinerary_suggestion = itinerary_result["itinerary_suggestion"]
        list_of_places = itinerary_result["list_of_places"]

        print("\nItinerary Suggestion Returned:\n", itinerary_suggestion)
        print("\nList of Places Returned:\n", list_of_places)

        self.assertIsNotNone(itinerary_suggestion)
# Run the suite when this file is executed directly (``python <file>``).
if __name__ == "__main__":
    unittest.main()