File size: 4,106 Bytes
24ae72d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9ccd0b
24ae72d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# /home/bk_anupam/code/LLM_agents/RAG_BOT/agent/retrieval_nodes.py
import os
import sys
from operator import itemgetter
from typing import List

from langchain_core.messages import ToolMessage
from sentence_transformers import CrossEncoder

# Add the project root to the Python path.
# The root is computed as the directory two levels above this file's directory
# and is inserted at the front of sys.path, ahead of any existing entries.
# NOTE(review): presumably this lets the absolute `RAG_BOT.*` imports below
# resolve when the module is run without the package installed - confirm.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.insert(0, project_root)

from RAG_BOT.config import Config
from RAG_BOT.logger import logger
from RAG_BOT.agent.state import AgentState


def rerank_context_node(state: AgentState, reranker_model: CrossEncoder):
    """
    Rerank the documents retrieved by the 'retrieve_context' tool for the current query.

    The candidate documents are read from the ``artifact`` of the last entry in
    ``state['messages']``, which must be a ToolMessage named 'retrieve_context'.
    Each document is scored against ``state['current_query']`` with the
    cross-encoder, and the highest-scoring ``Config.RERANK_TOP_N`` documents are
    joined with blank lines into a single context string.

    Args:
        state: Agent state; the 'current_query' and 'messages' keys are read.
        reranker_model: Cross-encoder whose ``predict`` scores
            [query, document] pairs. May be None, in which case no reranking
            is performed.

    Returns:
        A state update dict with a single 'context' key whose value is:
        - None if the last message is not a ToolMessage carrying an artifact,
          or is a ToolMessage from a tool other than 'retrieve_context';
        - "" if the artifact is not a list, or is an empty list;
        - all documents joined in their original order if ``reranker_model``
          is None or if reranking raises;
        - otherwise the top ``Config.RERANK_TOP_N`` documents, best first.
    """
    logger.info("--- Executing Rerank Context Node ---")
    current_query = state.get('current_query')
    logger.info(f"Current query for reranking: '{current_query}'")
    messages = state.get('messages')
    last_message = messages[-1] if messages else None
    logger.info(f"Last message type: {type(last_message)}")

    # The documents travel in the tool message's artifact, not in its text content.
    if not (isinstance(last_message, ToolMessage) and hasattr(last_message, 'artifact')):
        logger.warning(f"Last message is not a ToolMessage with an artifact (type: {type(last_message)}). Skipping reranking.")
        return {"context": None}

    # Only artifacts produced by the retrieval tool are reranked.
    if last_message.name != 'retrieve_context':
        logger.error(f"Last message is a ToolMessage, but not from 'retrieve_context' (name: {last_message.name}). Skipping reranking.")
        return {"context": None}

    retrieved_docs_artifact = last_message.artifact
    logger.info(f"Artifact received from tool '{last_message.name}': type {type(retrieved_docs_artifact)}")

    # Proceed only if we got a valid list artifact
    if not isinstance(retrieved_docs_artifact, list):
        logger.error("No valid document list artifact found to rerank.")
        return {"context": ""}

    # Without a reranker, fall back to the documents in retrieval order.
    # The join below requires the artifact entries to be strings.
    if reranker_model is None:
        logger.error("Reranker model not loaded. Concatenating original artifact docs.")
        return {"context": "\n\n".join(retrieved_docs_artifact)}

    if not retrieved_docs_artifact:
        logger.info("No documents in artifact to rerank.")
        return {"context": ""}

    logger.info(f"Reranking {len(retrieved_docs_artifact)} documents for query: '{current_query}'")
    # Prepare pairs for the cross-encoder
    pairs = [[current_query, doc] for doc in retrieved_docs_artifact]
    try:
        # Get scores from the cross-encoder
        scores = reranker_model.predict(pairs)
        # Pair each score with its document and order best-first. Sorting on the
        # score alone keeps equal-scored documents in their retrieval order.
        scored_docs = sorted(zip(scores, retrieved_docs_artifact), key=itemgetter(0), reverse=True)
        # Select top N documents based on config
        top_scored = scored_docs[:Config.RERANK_TOP_N]
        # Log the scores of the documents actually selected (i.e. after sorting),
        # not the first N raw scores in retrieval order.
        top_scores = [float(score) for score, _ in top_scored]
        logger.info(f"Reranking scores obtained (Top {Config.RERANK_TOP_N}): {top_scores}")
        reranked_docs = [doc for _, doc in top_scored]
        logger.info(f"Selected top {len(reranked_docs)} documents after reranking.")
        # Concatenate the final context
        final_context = "\n\n".join(reranked_docs)
        logger.info(f"Final reranked context snippet: {final_context[:400]}...")
        return {"context": final_context}
    except Exception as e:
        # Deliberate best-effort: any failure while reranking degrades to the
        # un-reranked documents instead of failing the node.
        logger.error(f"Error during reranking: {e}. Using original context without reranking", exc_info=True)
        # Fallback: use the original concatenated context
        final_context = "\n\n".join(retrieved_docs_artifact)
        return {"context": final_context}