# Spaces: Runtime error
import unittest | |
from config import load_secrets | |
from agents.agent import Agent | |
import datetime | |
# Passed as ``debug`` to the Agent and its validation chain by the tests below.
DEBUG: bool = True
class TestAgentMethods(unittest.TestCase):
    """Tests for ``Agent`` built from the keys returned by ``load_secrets()``.

    NOTE(review): these tests construct a real ``Agent`` with an OpenAI key,
    so they presumably make live API calls and may vary between runs.
    """

    # Secrets that must be present (and not None) before any test runs.
    _REQUIRED_SECRETS = (
        "OPENAI_API_KEY",
        "GOOGLE_MAPS_API_KEY",
        "GOOGLE_PALM_API_KEY",
    )

    def assert_secrets(self, secrets_dict):
        """Fail unless every required API key is present and not ``None``.

        Uses unittest assertions rather than bare ``assert`` so the check is
        not stripped under ``python -O``. A missing key is reported as a test
        failure naming the key, instead of surfacing as a ``KeyError``.

        :param secrets_dict: mapping returned by ``load_secrets()``.
        :raises AssertionError: if any required key is missing or ``None``.
        """
        for key in self._REQUIRED_SECRETS:
            self.assertIsNotNone(secrets_dict.get(key), f"missing secret: {key}")

    def setUp(self):
        """Load secrets, validate them, and build the ``Agent`` under test."""
        self.debug = DEBUG
        secrets = load_secrets()
        self.assert_secrets(secrets)
        # Only the OpenAI key is handed to the agent; the Google keys are
        # checked above but not used by these tests.
        self.agent = Agent(
            openai_api_key=secrets["OPENAI_API_KEY"],
            debug=self.debug,
        )

    def _run_validation(self, validation_chain, query, format_instructions):
        """Run *validation_chain* on *query* and return its output as a dict."""
        result = validation_chain(
            {
                "query": query,
                "format_instructions": format_instructions,
            }
        )
        return result["validation_output"].dict()

    def test_validation_chain(self):
        """The validation chain accepts reasonable trips and rejects others."""
        validation_chain = self.agent._set_up_validation_chain(debug=self.debug)
        # Invariant across all cases, so compute it once.
        format_instructions = (
            self.agent.validation_prompt.parser.get_format_instructions()
        )
        cases = (
            ("fly to the moon", "no"),  # not a reasonable request
            ("1 day road trip from Chicago to Brazilia", "no"),  # not a reasonable request
            ("1 week road trip from Chicago to Mexico city", "yes"),  # a reasonable request
        )
        for query, expected in cases:
            # subTest reports every failing case instead of stopping at the first.
            with self.subTest(query=query):
                output = self._run_validation(
                    validation_chain, query, format_instructions
                )
                self.assertEqual(output["plan_is_valid"], expected)

    def test_generate_itinerary(self):
        """``generate_itinerary`` returns the expected keys and a suggestion."""
        user_details = {
            "start_location": "Berkeley, CA",
            "end_location": "Seattle, WA",
            "start_date": datetime.date(2023, 12, 10),
            "end_date": datetime.date(2023, 12, 15),
            "attractions": ["museums", "parks"],
            "budget": "1500-3000 USD",
            "transportation": "rental car, public Transport",
            "accommodation": "hotels",
            "schedule": "relaxed",
        }
        itinerary_result = self.agent.generate_itinerary(user_details)

        # Check the result shape explicitly rather than via unused lookups.
        for key in ("itinerary_suggestion", "list_of_places", "validation_dict"):
            self.assertIn(key, itinerary_result)

        itinerary_suggestion = itinerary_result["itinerary_suggestion"]
        list_of_places = itinerary_result["list_of_places"]
        if self.debug:
            print("\nItinerary Suggestion Returned:\n", itinerary_suggestion)
            print("\nList of Places Returned:\n", list_of_places)

        self.assertIsNotNone(itinerary_suggestion)
if __name__ == "__main__":
    # Running this file directly hands control to the unittest CLI runner,
    # which discovers and runs the TestCase classes in this module.
    unittest.main()