"""
AnswerGenerator: orchestrates retrieval, re-ranking, and answer generation.

This module contains:
- Reranker: Cross-encoder based re-ranking of candidate chunks
- AnswerGenerator: ties together retrieval, re-ranking, and LLM generation

The Retriever (hybrid BM25 + dense retrieval over parsed chunks) is imported
from ``src.retriever`` rather than defined here.

Each component is modular and can be swapped or extended (e.g., add HyDE
retriever).
"""

import os
from typing import List, Dict, Any, Tuple

from src import RerankerConfig, logger
from src.utils import LLMClient
from src.retriever import Retriever, RetrieverConfig


class Reranker:
    """
    Cross-encoder re-ranker using a transformer-based sequence classification
    model.
    """

    def __init__(self, config: RerankerConfig):
        """Load the tokenizer and model named by ``config`` onto its device.

        Args:
            config: Object exposing ``MODEL_NAME`` and ``DEVICE``.

        Raises:
            Exception: Any failure while importing or loading the model is
                logged and re-raised unchanged.
        """
        # Remember the device the model lives on so rerank() sends its inputs
        # to the same place, even when a non-default config is supplied.
        self.device = config.DEVICE
        try:
            # Imported lazily: a missing transformers/torch install surfaces
            # here, inside the logged try block, not at module import time.
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
            import torch  # noqa: F401

            self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
            self.model = AutoModelForSequenceClassification.from_pretrained(
                config.MODEL_NAME
            )
            self.model.to(self.device)
        except Exception as e:
            logger.error(f'Failed to load reranker model: {e}')
            raise

    def rerank(
        self, query: str, candidates: List[Dict[str, Any]], top_k: int
    ) -> List[Dict[str, Any]]:
        """Score each candidate and return top_k sorted by relevance.

        Args:
            query: The user question paired with every candidate.
            candidates: Chunk dicts; the text under ``'narration'`` is scored
                (a missing key is scored as an empty string).
            top_k: Maximum number of candidates to return.

        Returns:
            Up to ``top_k`` candidates, highest score first. Returns ``[]``
            when ``candidates`` is empty. If scoring raises, the error is
            logged and the first ``top_k`` candidates are returned in their
            original order (best-effort fallback).
        """
        if not candidates:
            logger.warning('No candidates provided to rerank.')
            return []
        try:
            import torch

            passages = [c.get('narration', '') for c in candidates]
            inputs = self.tokenizer(
                [query] * len(candidates),
                passages,
                padding=True,
                truncation=True,
                return_tensors='pt',
            ).to(self.device)
            with torch.no_grad():
                logits = self.model(**inputs).logits
            if logits.ndim == 2 and logits.shape[1] == 1:
                # Only squeeze if it's (batch, 1).
                logits = logits.squeeze(-1)
            # NOTE(review): assumes one logit per (query, passage) pair. A
            # model with several labels would yield (batch, k) logits and the
            # flattened indexing below would not line up with candidates —
            # confirm the configured model has a single output.
            probs = torch.sigmoid(logits).cpu().numpy().flatten()
            paired = [(c, float(probs[idx])) for idx, c in enumerate(candidates)]
            ranked = sorted(paired, key=lambda item: item[1], reverse=True)
            return [c for c, _ in ranked[:top_k]]
        except Exception as e:
            # Deliberate best-effort: fall back to retrieval order rather
            # than failing the whole answer.
            logger.error(f'Reranking failed: {e}')
            return candidates[:top_k]


class AnswerGenerator:
    """
    Main interface: initializes Retriever + Reranker once, then answers
    multiple questions without re-loading models each time.
    """

    def __init__(self, chunks: List[Dict[str, Any]]):
        """Build the retriever and reranker over ``chunks``.

        Args:
            chunks: Chunk dicts handed to the Retriever.
        """
        self.chunks = chunks
        self.retriever = Retriever(chunks, RetrieverConfig)
        self.reranker = Reranker(RerankerConfig)
        # Keep half of the retrieved candidates after re-ranking, but never
        # zero: TOP_K == 1 would otherwise produce an always-empty context.
        self.top_k = max(1, RetrieverConfig.TOP_K // 2)

    def answer(
        self, question: str
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """Retrieve, re-rank, and ask the LLM to answer ``question``.

        Args:
            question: The user query.

        Returns:
            A tuple of the value returned by ``LLMClient.generate`` and the
            re-ranked chunks that were placed in the prompt.
        """
        candidates = self.retriever.retrieve(question)
        top_chunks = self.reranker.rerank(question, candidates, self.top_k)
        # Use .get like the reranker does, so a chunk without 'narration'
        # does not raise KeyError after it has already been scored.
        context = "\n\n".join(f"- {c.get('narration', '')}" for c in top_chunks)
        # The separators sit on their own lines so they are not glued to the
        # first/last snippet or to the instruction text that follows.
        prompt = (
            "You are a knowledgeable assistant. Use the following snippets to answer."
            f"\n\nContext information is below: \n"
            '------------------------------------\n'
            f"{context}\n"
            '------------------------------------\n'
            "Given the context information above I want you \n"
            "to think step by step to answer the query in a crisp \n"
            "manner, incase you don't have enough information, \n"
            "just say I don't know!. \n\n"
            f"\n\nQuestion: {question} \n"
            "Answer:"
        )
        answer = LLMClient.generate(prompt)
        return answer, top_chunks