# /home/bk_anupam/code/LLM_agents/RAG_BOT/agent/graph_builder.py import functools import os import sys from typing import Literal from langchain_chroma import Chroma from langchain_google_genai import ChatGoogleGenerativeAI from langgraph.graph import StateGraph, END from langgraph.prebuilt import ToolNode, tools_condition from sentence_transformers import CrossEncoder # Add the project root to the Python path project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) sys.path.insert(0, project_root) from RAG_BOT.config import Config from RAG_BOT.logger import logger from RAG_BOT.context_retriever_tool import create_context_retriever_tool from RAG_BOT.agent.state import AgentState from RAG_BOT.agent.agent_node import agent_node from RAG_BOT.agent.retrieval_nodes import rerank_context_node from RAG_BOT.agent.evaluation_nodes import evaluate_context_node, reframe_query_node # --- Conditional Edge Logic --- def decide_next_step(state: AgentState) -> Literal["reframe_query", "agent_final_answer", "__end__"]: """ Determines the next node based on evaluation result and retry status. """ logger.info("--- Deciding Next Step ---") evaluation = state.get('evaluation_result') retry_attempted = state.get('retry_attempted', False) logger.info(f"Evaluation: {evaluation}, Retry Attempted: {retry_attempted}") if evaluation == "sufficient": logger.info("Decision: Context sufficient, proceed to final answer generation.") return "agent_final_answer" # Route to agent node for final answer elif not retry_attempted: logger.info("Decision: Context insufficient, attempt retry.") return "reframe_query" # Route to reframe node else: logger.info("Decision: Context insufficient after retry, proceed to 'cannot find' message.") return "agent_final_answer" # Route to agent node for "cannot find" message # --- Graph Builder --- def build_agent(vectordb: Chroma, model_name: str = Config.LLM_MODEL_NAME) -> StateGraph: """Builds the multi-node LangGraph agent.""" llm = ChatGoogleGenerativeAI(model=model_name, temperature=Config.TEMPERATURE) logger.info(f"LLM model '{model_name}' initialized with temperature {Config.TEMPERATURE}.") # --- Reranker Model Initialization --- reranker_model = None # Initialize as None try: reranker_model_name = Config.RERANKER_MODEL_NAME logger.info(f"Loading reranker model: {reranker_model_name}") reranker_model = CrossEncoder(reranker_model_name) logger.info("Reranker model loaded successfully.") except Exception as e: logger.error(f"Failed to load reranker model '{Config.RERANKER_MODEL_NAME}': {e}", exc_info=True) # The graph will proceed, but rerank_context_node will skip reranking # --- Tool Preparation --- # Use INITIAL_RETRIEVAL_K for the retriever tool that feeds the reranker ctx_retriever_tool_instance = create_context_retriever_tool( vectordb=vectordb, k=Config.INITIAL_RETRIEVAL_K, # Use the larger K for initial retrieval search_type=Config.SEARCH_TYPE ) logger.info(f"Context retriever tool created with k={Config.INITIAL_RETRIEVAL_K}, search_type='{Config.SEARCH_TYPE}'.") available_tools = [ctx_retriever_tool_instance] # --- LLM Binding (for initial decision in agent_node) --- llm_with_tools = llm.bind_tools(available_tools) logger.info("LLM bound with tools successfully.") # Create ToolNode specifically for context retrieval retrieve_context_node = ToolNode(tools=[ctx_retriever_tool_instance]) # --- Bind LLM and Reranker to Nodes --- agent_node_runnable = functools.partial( agent_node, llm=llm, llm_with_tools=llm_with_tools ) # Bind the loaded reranker model (or None if loading failed) rerank_context_node_runnable = functools.partial(rerank_context_node, reranker_model=reranker_model) evaluate_context_node_runnable = functools.partial(evaluate_context_node, llm=llm) reframe_query_node_runnable = functools.partial(reframe_query_node, llm=llm) # --- Define the Graph --- builder = StateGraph(AgentState) # --- Add Nodes --- builder.add_node("agent_initial", agent_node_runnable) # Handles initial query & first decision builder.add_node("retrieve_context", retrieve_context_node) builder.add_node("rerank_context", rerank_context_node_runnable) # Add the new reranker node builder.add_node("evaluate_context", evaluate_context_node_runnable) builder.add_node("reframe_query", reframe_query_node_runnable) # Use a distinct node name for final answer generation step builder.add_node("agent_final_answer", agent_node_runnable) # --- Define Edges --- builder.set_entry_point("agent_initial") # Decide whether to retrieve or answer directly from the start builder.add_conditional_edges( "agent_initial", tools_condition, # Checks if the AIMessage from agent_initial has tool_calls { "tools": "retrieve_context", # If tool call exists, go retrieve "__end__": "agent_final_answer", # If no tool call, go directly to final answer generation }, ) # --- Main RAG loop with Reranking --- builder.add_edge("retrieve_context", "rerank_context") # Retrieve -> Rerank builder.add_edge("rerank_context", "evaluate_context") # Rerank -> Evaluate # Conditional logic after evaluation remains the same builder.add_conditional_edges( "evaluate_context", decide_next_step, # Use the dedicated decision function based on evaluation of reranked context { "reframe_query": "reframe_query", "agent_final_answer": "agent_final_answer", # Route to final answer generation } ) # Loop back to retrieve after reframing builder.add_edge("reframe_query", "retrieve_context") # Final answer generation leads to end builder.add_edge("agent_final_answer", END) # Compile the graph graph = builder.compile() # # Optional: Save graph visualization # try: # graph.get_graph().draw_mermaid_png(output_file_path="rag_agent_graph.png") # logger.info("Saved graph visualization to rag_agent_graph.png") # except Exception as e: # logger.warning(f"Could not save graph visualization: {e}") logger.info("LangGraph agent compiled successfully...") return graph