""" |
|
AnswerGenerator: orchestrates retrieval, re-ranking, and answer generation. |
|
|
|
This module contains: |
|
- Retriever: Hybrid BM25 + dense retrieval over parsed chunks |
|
- Reranker: Cross-encoder based re-ranking of candidate chunks |
|
- AnswerGenerator: ties together retrieval, re-ranking, and LLM generation |
|
|
|
Each component is modular and can be swapped or extended (e.g., add HyDE retriever). |
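
Example (a minimal usage sketch; ``chunks`` is assumed to be a list of dicts
that each carry a 'narration' key holding the chunk text):

    generator = AnswerGenerator()
    answer, sources = generator.answer(chunks, 'What were the key findings?')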
|
""" |
|
import os
from typing import List, Dict, Any, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from src import sanitize_html
from src.utils import LLMClient, logger
from src.retriever import Retriever, RetrieverConfig


class RerankerConfig:
    # 'BAAI/bge-reranker-v2-Gemma' is not loadable via
    # AutoModelForSequenceClassification (the gemma reranker is decoder-based),
    # so default to the sequence-classification reranker from the same family.
    MODEL_NAME = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-v2-m3')
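    # The default can be overridden via the environment, for example
    # (any HF cross-encoder with a sequence-classification head should work):
    #   export RERANKER_MODEL=BAAI/bge-reranker-base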
|
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class Reranker:
    """
    Cross-encoder re-ranker using a transformer-based sequence-classification model.
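
    Usage sketch (assumes each candidate dict carries a 'narration' field
    with the text to score against the query):

        reranker = Reranker(RerankerConfig())
        top = reranker.rerank('payment terms', candidates, top_k=5)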
|
""" |
|
    def __init__(self, config: RerankerConfig):
        self.device = config.DEVICE  # keep the device on the instance rather than re-reading the class attribute
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
            self.model = AutoModelForSequenceClassification.from_pretrained(config.MODEL_NAME)
            self.model.to(self.device)
            self.model.eval()  # inference only: disables dropout
        except Exception as e:
            logger.error(f'Failed to load reranker model: {e}')
            raise

    def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int) -> List[Dict[str, Any]]:
        """Score each candidate and return the top_k sorted by relevance."""
        if not candidates:
            logger.warning('No candidates provided to rerank.')
            return []
        try:
            # A cross-encoder scores each (query, passage) pair jointly, so
            # the query is repeated once per candidate.
            inputs = self.tokenizer(
                [query] * len(candidates),
                [c.get('narration', '') for c in candidates],
                padding=True,
                truncation=True,
                return_tensors='pt'
            ).to(self.device)
            with torch.no_grad():
                logits = self.model(**inputs).logits.squeeze(-1)
            # Sigmoid maps raw logits to (0, 1) relevance scores.
            scores = torch.sigmoid(logits).cpu().numpy()
            paired = list(zip(candidates, scores))
            ranked = sorted(paired, key=lambda x: x[1], reverse=True)
            return [c for c, _ in ranked[:top_k]]
        except Exception as e:
            logger.error(f'Reranking failed: {e}')
            # Degrade gracefully: fall back to the retriever's original ordering.
            return candidates[:top_k]


class AnswerGenerator:
    """
    Main interface: given parsed chunks and a question, returns the answer and supporting chunks.
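
    Expected chunk shape (an assumption inferred from how chunks are consumed
    in this module; any extra keys are passed through untouched):

        {'narration': 'text of the parsed chunk', ...}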
|
""" |
|
    def __init__(self):
        self.ret_config = RetrieverConfig()
        self.rerank_config = RerankerConfig()
        # Load the cross-encoder once here rather than on every question;
        # unlike the retriever, it does not depend on the caller's chunks.
        self.reranker = Reranker(self.rerank_config)

    def answer(self, chunks: List[Dict[str, Any]], question: str) -> Tuple[str, List[Dict[str, Any]]]:
        """Retrieve, re-rank, and generate; returns (answer, supporting chunks)."""
        logger.info(f'Answering question: {question}')
        question = sanitize_html(question)
        try:
            # The retriever is rebuilt per call because it indexes the
            # caller-supplied chunks.
            retriever = Retriever(chunks, self.ret_config)
            candidates = retriever.retrieve(question)
            top_chunks = self.reranker.rerank(question, candidates, top_k=5)
            context = "\n\n".join(f"- {c.get('narration', '')}" for c in top_chunks)
            prompt = (
                "You are a knowledgeable assistant. "
                "Use the following extracted document snippets to answer the question."
                f"\n\nContext:\n{context}"
                f"\n\nQuestion: {question}\nAnswer:"
            )
            answer = LLMClient.generate(prompt)
            return answer, top_chunks
        except Exception as e:
            logger.error(f'Failed to answer question: {e}')
            return "Failed to generate answer due to error.", []