File size: 3,297 Bytes
47c93d7
d137f7e
 
 
47c93d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d137f7e
47c93d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d137f7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47c93d7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import unittest
from config import load_secrets 
from agents.agent import Agent  
import datetime

# When True, the Agent under test is constructed with debug=True and
# test_validation_chain is skipped (see the skipIf decorator below).
DEBUG = True


class TestAgentMethods(unittest.TestCase):
    """Tests that exercise ``agents.agent.Agent`` end to end.

    Each test gets a fresh ``Agent`` built in ``setUp`` from the secrets
    returned by ``config.load_secrets()``.
    """

    # Every key listed here must be present and non-None in the loaded secrets
    # before any test is allowed to run.
    _REQUIRED_SECRET_KEYS = (
        "OPENAI_API_KEY",
        "GOOGLE_MAPS_API_KEY",
        "GOOGLE_PALM_API_KEY",
    )

    def assert_secrets(self, secrets_dict):
        """Fail the test unless every required API key is present and not None.

        Uses unittest assertions instead of bare ``assert`` so the check still
        runs under ``python -O``, and so a missing key is reported as a test
        failure naming that key rather than as a bare ``KeyError``.

        Args:
            secrets_dict: Mapping of secret names to values, as returned by
                ``load_secrets()``.
        """
        for key in self._REQUIRED_SECRET_KEYS:
            self.assertIn(key, secrets_dict, f"missing secret: {key}")
            self.assertIsNotNone(secrets_dict[key], f"secret {key} is None")

    def setUp(self):
        """Load and validate secrets, then build the Agent under test."""
        self.debug = DEBUG
        secrets = load_secrets()
        self.assert_secrets(secrets)

        # Only the OpenAI key is handed to the Agent here, although
        # assert_secrets still requires all of _REQUIRED_SECRET_KEYS.
        self.agent = Agent(
            openai_api_key=secrets["OPENAI_API_KEY"],
            debug=self.debug,
        )

    @unittest.skipIf(DEBUG, "Skipping this test while debugging other tests")
    def test_validation_chain(self):
        """Check that the validation chain rejects unreasonable trips and accepts reasonable ones."""
        validation_chain = self.agent._set_up_validation_chain(debug=self.debug)
        # The format instructions do not depend on the query; build them once.
        format_instructions = (
            self.agent.validation_prompt.parser.get_format_instructions()
        )

        cases = (
            # not a reasonable request
            ("fly to the moon", "no"),
            # not a reasonable request
            ("1 day road trip from Chicago to Brazilia", "no"),
            # a reasonable request
            ("1 week road trip from Chicago to Mexico city", "yes"),
        )
        for query, expected in cases:
            # subTest reports every case instead of stopping at the first failure.
            with self.subTest(query=query):
                result = validation_chain(
                    {
                        "query": query,
                        "format_instructions": format_instructions,
                    }
                )
                output = result["validation_output"].dict()
                self.assertEqual(output["plan_is_valid"], expected)

    def test_generate_itinerary(self):
        """Generate an itinerary for a fixed Berkeley -> Seattle trip and sanity-check the result."""
        user_details = {
            "start_location": "Berkeley, CA",
            "end_location": "Seattle, WA",
            "start_date": datetime.date(2023, 12, 10),
            "end_date": datetime.date(2023, 12, 15),
            "attractions": ["museums", "parks"],
            "budget": "1500-3000 USD",
            "transportation": "rental car, public Transport",
            "accommodation": "hotels",
            "schedule": "relaxed"
        }

        itinerary_result = self.agent.generate_itinerary(user_details)

        # Report a missing key as a test failure rather than a KeyError.
        for key in ("itinerary_suggestion", "list_of_places", "validation_dict"):
            self.assertIn(key, itinerary_result)

        itinerary_suggestion = itinerary_result["itinerary_suggestion"]
        list_of_places = itinerary_result["list_of_places"]

        # Diagnostic output only while debugging, to keep normal runs quiet.
        if self.debug:
            print("\nItinerary Suggestion Returned:\n", itinerary_suggestion)
            print("\nList of Places Returned:\n", list_of_places)

        # Stronger checks depend on what `generate_itinerary` is contracted to
        # return; for now only require that a suggestion was produced.
        self.assertIsNotNone(itinerary_suggestion)


if __name__ == "__main__":
    unittest.main()