Spaces:
Building
Building
File size: 6,491 Bytes
24ae72d b9ccd0b 24ae72d b9ccd0b 24ae72d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# /home/bk_anupam/code/LLM_agents/RAG_BOT/agent/graph_builder.py
import functools
import os
import sys
from typing import Literal
from langchain_chroma import Chroma
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode, tools_condition
from sentence_transformers import CrossEncoder
# Add the project root to the Python path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.insert(0, project_root)
from RAG_BOT.config import Config
from RAG_BOT.logger import logger
from RAG_BOT.context_retriever_tool import create_context_retriever_tool
from RAG_BOT.agent.state import AgentState
from RAG_BOT.agent.agent_node import agent_node
from RAG_BOT.agent.retrieval_nodes import rerank_context_node
from RAG_BOT.agent.evaluation_nodes import evaluate_context_node, reframe_query_node
# --- Conditional Edge Logic ---
def decide_next_step(state: AgentState) -> Literal["reframe_query", "agent_final_answer"]:
    """Choose the node that follows context evaluation.

    Reads two keys from ``state``:
      * ``evaluation_result`` - compared against the string "sufficient".
      * ``retry_attempted``   - treated as False when the key is absent.

    Args:
        state: Current agent state; only ``.get`` lookups are performed on it.

    Returns:
        "agent_final_answer" when the evaluation is "sufficient", or when it is
        not sufficient but a retry has already been attempted;
        "reframe_query" when the evaluation is not sufficient and no retry has
        been attempted yet.

    The return annotation lists exactly these two values: no branch ever
    returns "__end__", so it is not advertised as a possible route.
    """
    logger.info("--- Deciding Next Step ---")
    evaluation = state.get('evaluation_result')
    retry_attempted = state.get('retry_attempted', False)
    logger.info(f"Evaluation: {evaluation}, Retry Attempted: {retry_attempted}")

    if evaluation == "sufficient":
        logger.info("Decision: Context sufficient, proceed to final answer generation.")
        return "agent_final_answer"

    if not retry_attempted:
        # Any non-"sufficient" value (including a missing one) gets one retry.
        logger.info("Decision: Context insufficient, attempt retry.")
        return "reframe_query"

    # Retry already used: hand over to the answer node regardless, which per
    # the log message below is expected to produce the 'cannot find' reply.
    logger.info("Decision: Context insufficient after retry, proceed to 'cannot find' message.")
    return "agent_final_answer"
# --- Graph Builder ---
def build_agent(vectordb: Chroma, model_name: str = Config.LLM_MODEL_NAME) -> StateGraph:
    """Assemble and compile the retrieve -> rerank -> evaluate agent graph.

    Wiring produced by this function:

        agent_initial --(tool call)----> retrieve_context
        agent_initial --(no tool call)-> agent_final_answer
        retrieve_context -> rerank_context -> evaluate_context
        evaluate_context --(decide_next_step)--> reframe_query | agent_final_answer
        reframe_query -> retrieve_context
        agent_final_answer -> END

    Args:
        vectordb: Chroma store passed to the context retriever tool.
        model_name: Model identifier given to ChatGoogleGenerativeAI.

    Returns:
        The object returned by ``compile()`` on the assembled graph.
        NOTE(review): the signature says ``StateGraph`` but the value returned
        is the result of ``compile()`` - confirm the annotation.
    """
    chat_model = ChatGoogleGenerativeAI(model=model_name, temperature=Config.TEMPERATURE)
    logger.info(f"LLM model '{model_name}' initialized with temperature {Config.TEMPERATURE}.")

    # Reranker: loading is best-effort. A failure is logged and leaves the
    # model as None; the graph is still built (the rerank node receives None
    # and, per the original author's note, is expected to skip reranking).
    cross_encoder = None
    try:
        cross_encoder_name = Config.RERANKER_MODEL_NAME
        logger.info(f"Loading reranker model: {cross_encoder_name}")
        cross_encoder = CrossEncoder(cross_encoder_name)
        logger.info("Reranker model loaded successfully.")
    except Exception as exc:
        logger.error(f"Failed to load reranker model '{Config.RERANKER_MODEL_NAME}': {exc}", exc_info=True)

    # Retriever tool: uses INITIAL_RETRIEVAL_K because its output feeds the
    # reranker stage that follows retrieval in the graph.
    retriever_tool = create_context_retriever_tool(
        vectordb=vectordb,
        k=Config.INITIAL_RETRIEVAL_K,
        search_type=Config.SEARCH_TYPE,
    )
    logger.info(f"Context retriever tool created with k={Config.INITIAL_RETRIEVAL_K}, search_type='{Config.SEARCH_TYPE}'.")

    # Tool-aware LLM handed to agent_node alongside the plain one.
    tool_aware_model = chat_model.bind_tools([retriever_tool])
    logger.info("LLM bound with tools successfully.")

    # Dedicated tool-execution node for context retrieval.
    retrieval_step = ToolNode(tools=[retriever_tool])

    # Pre-bind each node function's dependencies so the graph can call it
    # with the state alone.
    agent_step = functools.partial(agent_node, llm=chat_model, llm_with_tools=tool_aware_model)
    rerank_step = functools.partial(rerank_context_node, reranker_model=cross_encoder)
    evaluate_step = functools.partial(evaluate_context_node, llm=chat_model)
    reframe_step = functools.partial(reframe_query_node, llm=chat_model)

    workflow = StateGraph(AgentState)

    # Node registration. The same agent callable is registered under two
    # names so the entry step and the final-answer step are distinct nodes.
    node_table = (
        ("agent_initial", agent_step),
        ("retrieve_context", retrieval_step),
        ("rerank_context", rerank_step),
        ("evaluate_context", evaluate_step),
        ("reframe_query", reframe_step),
        ("agent_final_answer", agent_step),
    )
    for node_name, node_callable in node_table:
        workflow.add_node(node_name, node_callable)

    workflow.set_entry_point("agent_initial")

    # First decision: tools_condition inspects agent_initial's output for
    # tool calls. "tools" goes to retrieval; its "__end__" outcome is remapped
    # to the final-answer node instead of terminating the graph.
    initial_routes = {
        "tools": "retrieve_context",
        "__end__": "agent_final_answer",
    }
    workflow.add_conditional_edges("agent_initial", tools_condition, initial_routes)

    # Linear part of the RAG loop: retrieve -> rerank -> evaluate.
    for src, dst in (
        ("retrieve_context", "rerank_context"),
        ("rerank_context", "evaluate_context"),
    ):
        workflow.add_edge(src, dst)

    # After evaluation, decide_next_step chooses between one reframed retry
    # and final answer generation.
    post_evaluation_routes = {
        "reframe_query": "reframe_query",
        "agent_final_answer": "agent_final_answer",
    }
    workflow.add_conditional_edges("evaluate_context", decide_next_step, post_evaluation_routes)

    # A reframed query re-enters retrieval; the final answer ends the run.
    workflow.add_edge("reframe_query", "retrieve_context")
    workflow.add_edge("agent_final_answer", END)

    compiled = workflow.compile()
    # Optional (left disabled by the original author): export a diagram via
    # compiled.get_graph().draw_mermaid_png(output_file_path="rag_agent_graph.png")
    logger.info("LangGraph agent compiled successfully...")
    return compiled
|