# /home/bk_anupam/code/LLM_agents/RAG_BOT/tests/integration/test_integration.py
import os
import sys
import shutil
import unittest
from typing import Optional
from langchain_core.messages import HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI

# Add the project root to the Python path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.insert(0, project_root)

from RAG_BOT.vector_store import VectorStore
# Agent graph construction and state definition
from RAG_BOT.agent.graph_builder import build_agent
from RAG_BOT.agent.state import AgentState
from RAG_BOT.logger import logger
from RAG_BOT.config import Config
from RAG_BOT import utils

class TestIntegration(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """Setup method that is called once before all tests in the class."""
        cls.config = Config()
        cls.delete_existing_test_vector_store()
        logger.info("Deleted existing test vector store.")
        cls.test_vector_store = cls.setup_test_environment()
        cls.vectordb = cls.test_vector_store.get_vectordb()
        # Build agent once for the class
        cls.agent = build_agent(vectordb=cls.vectordb, model_name=cls.config.LLM_MODEL_NAME)


    @classmethod
    def tearDownClass(cls):
        """Teardown method that is called once after all tests in the class."""
        pass # Keep the vector store for post-test inspection; delete here if cleanup is preferred

    @classmethod
    def delete_existing_test_vector_store(cls):
        current_dir = os.path.dirname(os.path.abspath(__file__))
        test_vector_store_dir = os.path.join(current_dir, "..", "test_vector_store")
        if os.path.exists(test_vector_store_dir):
            shutil.rmtree(test_vector_store_dir)
            logger.info(f"Deleted test vector store at: {test_vector_store_dir}")


    @classmethod
    def setup_test_environment(cls):
        current_dir = os.path.dirname(os.path.abspath(__file__))
        pdf_dir = os.path.join(current_dir, "..", "data")
        test_vector_store_dir = os.path.join(current_dir, "..", "test_vector_store")
        os.makedirs(test_vector_store_dir, exist_ok=True) # Ensure dir exists
        logger.info(f"Setting up test vector store in: {test_vector_store_dir}")
        # Create a test vector store and index sample PDFs
        test_vector_store = VectorStore(persist_directory=test_vector_store_dir)
        pdf_files = [
            os.path.join(pdf_dir, f)
            for f in os.listdir(pdf_dir)
            if f.endswith(".pdf")
        ]
        if not pdf_files:
            logger.warning(f"No PDF files found in {pdf_dir} for indexing.")
            return test_vector_store # Return empty store if no PDFs

        for pdf_file in pdf_files:
            logger.info(f"Indexing test file: {pdf_file}")
            test_vector_store.build_index(pdf_file, semantic_chunk=cls.config.SEMANTIC_CHUNKING)
        logger.info("Test vector store setup complete.")
        return test_vector_store

    def _run_agent(self, query: str) -> AgentState:
        """Helper method to run the agent with a query."""
        initial_state = AgentState(messages=[HumanMessage(content=query)])
        # Add recursion limit for safety
        final_state = self.agent.invoke(initial_state, {"recursion_limit": 15})
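        # A compiled LangGraph graph returns a plain dict snapshot of the state
        # rather than an AgentState instance, hence the isinstance check below.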
        self.assertIsInstance(final_state, dict)
        self.assertIn("messages", final_state)
        return final_state


    def test_indexing_documents(self):
        """Verify that documents were indexed in the test vector store."""
        # Skip if vectordb wasn't created properly
        if not hasattr(self, 'vectordb') or self.vectordb is None:
            self.skipTest("VectorDB instance not available.")
        try:
            documents_dict = self.vectordb.get(limit=1) # Fetch just one to confirm collection exists
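            # The dict-of-lists shape ({"ids": [...], ...}) asserted below matches
            # Chroma's collection.get() return format, which VectorStore is assumed to wrap.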
            # Check if the collection is empty or exists
            self.assertIsNotNone(documents_dict, "VectorDB get() returned None.")
            # Check if 'ids' list exists and is not empty
            self.assertIn("ids", documents_dict)
            self.assertIsInstance(documents_dict["ids"], list)
            # We only check if *any* document was indexed, as exact count depends on chunking
            self.assertGreater(len(documents_dict["ids"]), 0, "No documents were indexed.")
        except Exception as e:
            # Catch potential errors if the collection doesn't exist yet
            self.fail(f"Failed to get documents from VectorDB: {e}")


    def evaluate_response_with_llm(self, query: str, context: Optional[str], response: str) -> bool:
        """Uses an LLM to judge the quality of the agent's response."""
        judge_llm = ChatGoogleGenerativeAI(model=Config.JUDGE_LLM_MODEL_NAME, temperature=0.0)
        judge_prompt_template = Config.get_judge_prompt_template()
        # The judge prompt expects the raw response string, which includes the JSON structure
        judge_prompt = judge_prompt_template.format(
            query=query,
            context=context if context else "N/A",
            response=response # Pass the raw response string
        )
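        # The template is assumed to instruct the judge to reply with a bare
        # PASS/FAIL verdict, which is why the reply is upper-cased and compared
        # against 'PASS' below.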
        try:
            evaluation = judge_llm.invoke([HumanMessage(content=judge_prompt)]).content.strip().upper()
            logger.info(f"LLM Judge Evaluation for query '{query[:50]}...': {evaluation}")
            return evaluation == 'PASS'
        except Exception as e:
            logger.error(f"LLM Judge call failed: {e}")
            return False # Fail the test if judge fails


    def test_agent_with_retrieval(self):
        """Tests the agent's ability to retrieve context and answer in JSON."""
        # Query without JSON instruction
        query = "What is the title of the murli from 1969-01-23?"
        final_state = self._run_agent(query)
        messages = final_state["messages"]
        self.assertGreater(len(messages), 1)

        # Check that the tool was called at least once
        tool_called = any(
            isinstance(msg, AIMessage) and msg.tool_calls and
            any(tc.get("name") == "retrieve_context" for tc in msg.tool_calls)
            for msg in messages
        )
        self.assertTrue(tool_called, "The 'retrieve_context' tool was not called as expected.")

        # Check the final answer format and content
        final_answer_message = messages[-1]
        self.assertEqual(final_answer_message.type, "ai")
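        # parse_json_answer is expected to return a dict like {"answer": "..."}
        # (or None on unparseable output); this shape is inferred from its usage here.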
        json_result = utils.parse_json_answer(final_answer_message.content)
        self.assertIsNotNone(json_result, f"Final answer is not valid JSON: {final_answer_message.content}")
        self.assertIn("answer", json_result)
        # Make comparison case-insensitive and check for substring
        self.assertIn("the ashes are to remind you of the stage", json_result["answer"].lower())


    def test_agent_without_retrieval(self):
        """Tests the agent's ability to answer a general question without retrieval, in JSON."""
        # Query without JSON instruction
        query = "What is the purpose of life?"
        final_state = self._run_agent(query)
        messages = final_state["messages"]
        self.assertGreater(len(messages), 1)

        # Ensure no tool call was made
        tool_called = any(
            isinstance(msg, AIMessage) and msg.tool_calls and
            any(tc.get("name") == "retrieve_context" for tc in msg.tool_calls)
            for msg in messages
        )
        self.assertFalse(tool_called, "The 'retrieve_context' tool was called unexpectedly.")

        # Check the final answer format and content
        final_answer_message = messages[-1]
        self.assertEqual(final_answer_message.type, "ai")
        json_result = utils.parse_json_answer(final_answer_message.content)
        self.assertIsNotNone(json_result, f"Final answer is not valid JSON: {final_answer_message.content}")
        self.assertIn("answer", json_result)
        # Check that a 'cannot find' style refusal is not present in the answer
        answer_lower = json_result["answer"].lower()
        self.assertNotIn("cannot be found", answer_lower,
                         f"Agent returned 'cannot be found' unexpectedly: {json_result['answer']}")
        self.assertNotIn("cannot find", answer_lower,
                         f"Agent returned 'cannot find' unexpectedly: {json_result['answer']}")


    def test_agent_insufficient_context(self):
        """Test agent response (in JSON) when no relevant context is found."""
        # Query without JSON instruction
        query = "Can you summarize the murli from 1950-01-18?"
        final_state = self._run_agent(query)
        messages = final_state["messages"]
        self.assertGreater(len(messages), 1)

        # --- Behavioral Assertions ---
        # 1. Check if retry was attempted (assuming the first retrieval yields nothing relevant)        
        self.assertTrue(final_state.get("retry_attempted"),
                         "Agent state should indicate retry_attempted was True if initial retrieval failed")

        # 2. Check that the tool was called (at least once)
        tool_call_count = sum(
            1 for msg in messages
            if isinstance(msg, AIMessage) and msg.tool_calls and
            any(tc.get("name") == "retrieve_context" for tc in msg.tool_calls)
        )
        self.assertGreaterEqual(tool_call_count, 1, "The 'retrieve_context' tool was not called.")

        # 3. Check the final answer format and content
        final_answer_message = messages[-1]
        self.assertEqual(final_answer_message.type, "ai")
        json_result = utils.parse_json_answer(final_answer_message.content)
        self.assertIsNotNone(json_result, f"Final 'cannot find' answer is not valid JSON: {final_answer_message.content}")
        self.assertIn("answer", json_result)
        self.assertTrue(
            "cannot be found" in json_result["answer"].lower() or "cannot find" in json_result["answer"].lower(),
            f"Agent did not return a 'cannot find' message within the JSON answer: {json_result['answer']}"
        )

        # 4. Check state reflects insufficient evaluation (if retry occurred) or final decision path        
        if final_state.get("retry_attempted"):
            self.assertEqual(final_state.get("evaluation_result"), "insufficient",
                             "Agent state should indicate evaluation_result was insufficient after retry")


    def test_agent_retry_logic_reframing(self):
        """Test agent retry logic (reframing) and final JSON output."""
        # Query without JSON instruction - date likely not in test data
        query = "Can you summarize the murli from 1970-01-18?"
        final_state = self._run_agent(query)
        messages = final_state["messages"]
        self.assertGreater(len(messages), 1)

        # Check that at least one tool call was made
        tool_calls = [
            msg for msg in messages
            if isinstance(msg, AIMessage) and msg.tool_calls and
            any(tc.get("name") == "retrieve_context" for tc in msg.tool_calls)
        ]
        self.assertGreaterEqual(len(tool_calls), 1, "No tool call was made during retry logic.")
        # Check that the retry logic was invoked        
        self.assertTrue(final_state.get("retry_attempted"), "Agent state should indicate retry_attempted was True")

        # Check the final answer format (should be JSON, likely a 'cannot find' message)
        final_answer_message = messages[-1]
        self.assertEqual(final_answer_message.type, "ai")
        json_result = utils.parse_json_answer(final_answer_message.content)
        self.assertIsNotNone(json_result, f"Final answer after retry is not valid JSON: {final_answer_message.content}")
        self.assertIn("answer", json_result)
        # Content could be a summary if found after retry, or 'cannot find'
        self.assertIsInstance(json_result["answer"], str)


    def test_summarization_for_a_date(self):
        """Test agent's ability to summarize a murli for a specific date in JSON."""
        # Query without JSON instruction
        query = "Can you summarize the murli from 1969-01-23?"
        final_state = self._run_agent(query)

        # --- Explicitly check context presence in final state ---
        self.assertIn("context", final_state, "The 'context' key is missing from the final agent state.")
        context = final_state.get("context")
        # Context could be None if retrieval failed, but the final answer should reflect that.
        # If context *is* present, it should be a string.
        if context is not None:
            self.assertIsInstance(context, str, "Context field in the final state is not a string.")
            # Optional: Check if context is not empty if retrieval was expected to succeed
            # self.assertTrue(len(context.strip()) > 0, "Context retrieved from final state appears to be empty.")

        # Evaluate the response using the LLM judge
        final_answer_content = final_state["messages"][-1].content
        evaluation_result = self.evaluate_response_with_llm(query, context, final_answer_content)
        json_result = utils.parse_json_answer(final_answer_content)
        response_answer = json_result.get("answer", "")
        self.assertTrue(evaluation_result, f"LLM Judge evaluation failed for query '{query}'. Response: {response_answer}")


if __name__ == "__main__":
    unittest.main()
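
# A minimal sketch of how this suite is typically run, assuming the project-root
# layout implied by the path handling above:
#   python -m unittest RAG_BOT.tests.integration.test_integration -v
# Note: these tests exercise live Gemini models through ChatGoogleGenerativeAI,
# so a valid GOOGLE_API_KEY is expected in the environment.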