"""
AnswerGenerator: orchestrates retrieval, re-ranking, and answer generation.

This module contains:
 - Retriever: Hybrid BM25 + dense retrieval over parsed chunks
 - Reranker: Cross-encoder based re-ranking of candidate chunks
 - AnswerGenerator: ties together retrieval, re-ranking, and LLM generation

Each component is modular and can be swapped or extended (e.g., add HyDE retriever).
"""
from typing import List, Dict, Any, Tuple

from src import RerankerConfig, logger
from src.utils import LLMClient
from src.retriever import Retriever, RetrieverConfig

class Reranker:
    """
    Cross-encoder re-ranker using a transformer-based sequence classification model.
    """
    def __init__(self, config: RerankerConfig):
        self.config = config
        try:
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
            self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
            self.model = AutoModelForSequenceClassification.from_pretrained(config.MODEL_NAME)
            self.model.to(config.DEVICE)
            self.model.eval()  # inference only: disables dropout, etc.
        except Exception as e:
            logger.error(f'Failed to load reranker model: {e}')
            raise

    def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int) -> List[Dict[str, Any]]:
        """Score each candidate against the query and return the top_k sorted by relevance."""
        if not candidates:
            logger.warning('No candidates provided to rerank.')
            return []
        try:
            import torch
            inputs = self.tokenizer(
                [query] * len(candidates),
                [c.get('narration', '') for c in candidates],
                padding=True,
                truncation=True,
                return_tensors='pt'
            ).to(self.config.DEVICE)  # use the instance config, not the class attribute
            with torch.no_grad():
                out = self.model(**inputs)

            logits = out.logits
            if logits.ndim == 2:
                if logits.shape[1] == 1:
                    logits = logits.squeeze(-1)  # single-logit head: (batch, 1) -> (batch,)
                else:
                    # multi-class head: assume the last class is "relevant";
                    # flattening here would misalign scores with candidates
                    logits = logits[:, -1]

            probs = torch.sigmoid(logits).cpu().numpy()
            paired = [(c, float(probs[idx])) for idx, c in enumerate(candidates)]
            ranked = sorted(paired, key=lambda x: x[1], reverse=True)
            return [c for c, _ in ranked[:top_k]]
        except Exception as e:
            logger.error(f'Reranking failed: {e}')
            return candidates[:top_k]
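
# Standalone use of Reranker (an illustrative sketch; assumes RerankerConfig
# provides MODEL_NAME and DEVICE, e.g. a cross-encoder checkpoint such as
# 'cross-encoder/ms-marco-MiniLM-L-6-v2' running on 'cpu'):
#
#   reranker = Reranker(RerankerConfig)
#   top = reranker.rerank('refund policy',
#                         [{'narration': '30-day refunds.'},
#                          {'narration': 'Shipping rates.'}],
#                         top_k=1)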


class AnswerGenerator:
    """
    Main interface: initializes Retriever + Reranker once, then
    answers multiple questions without re-loading models each time.
    """
    def __init__(self, chunks: List[Dict[str, Any]]):
        self.chunks = chunks
        self.retriever = Retriever(chunks, RetrieverConfig)
        self.reranker = Reranker(RerankerConfig)
        self.top_k = max(1, RetrieverConfig.TOP_K // 2)  # keep at least one chunk after re-ranking

    def answer(self, question: str) -> Tuple[str, List[Dict[str, Any]]]:
        candidates = self.retriever.retrieve(question)
        top_chunks = self.reranker.rerank(question, candidates, self.top_k)
        context = "\n\n".join(f"- {c.get('narration', '')}" for c in top_chunks)
        prompt = (
            "You are a knowledgeable assistant. Use the following snippets to answer.\n\n"
            "Context information is below:\n"
            "------------------------------------\n"
            f"{context}\n"
            "------------------------------------\n"
            "Given the context information above, think step by step and\n"
            "answer the query concisely. If there is not enough information,\n"
            "just say \"I don't know!\".\n\n"
            f"Question: {question}\n"
            "Answer:"
        )
        answer = LLMClient.generate(prompt)
        return answer, top_chunks
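

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the pipeline). It assumes parsed
    # chunks are dicts with a 'narration' text field, the field Reranker.rerank
    # and AnswerGenerator.answer read above; the sample data is hypothetical.
    sample_chunks = [
        {'narration': 'Refunds are accepted within 30 days of delivery.'},
        {'narration': 'Standard shipping takes 3-5 business days.'},
    ]
    generator = AnswerGenerator(sample_chunks)
    answer, sources = generator.answer('How long is the refund window?')
    print(answer)
    for chunk in sources:
        print('-', chunk.get('narration', ''))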