""" | |
AnswerGenerator: orchestrates retrieval, re-ranking, and answer generation. | |
This module contains: | |
- Retriever: Hybrid BM25 + dense retrieval over parsed chunks | |
- Reranker: Cross-encoder based re-ranking of candidate chunks | |
- AnswerGenerator: ties together retrieval, re-ranking, and LLM generation | |
Each component is modular and can be swapped or extended (e.g., add HyDE retriever). | |
""" | |
import os
from typing import Any, Dict, List, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from src.utils import LLMClient, logger
from src.retriever import Retriever, RetrieverConfig
class RerankerConfig:
    # Default to a cross-encoder checkpoint that loads through
    # AutoModelForSequenceClassification; the previous default,
    # 'BAAI/bge-reranker-v2-Gemma', is a causal-LM-style reranker and
    # would fail to load via this class.
    MODEL_NAME = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-v2-m3')
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
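
# The default can be overridden via the environment without code changes, e.g.
#   export RERANKER_MODEL=cross-encoder/ms-marco-MiniLM-L-6-v2
# (an illustrative alternative; any sequence-classification reranker
# checkpoint on the Hugging Face Hub should work here).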
class Reranker:
    """
    Cross-encoder re-ranker using a transformer-based sequence classification model.
    """
    def __init__(self, config: RerankerConfig):
        try:
            self.device = config.DEVICE
            self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
            self.model = AutoModelForSequenceClassification.from_pretrained(config.MODEL_NAME)
            self.model.to(self.device)
            self.model.eval()  # inference only; disables dropout
        except Exception as e:
            logger.error(f'Failed to load reranker model: {e}')
            raise
    def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int) -> List[Dict[str, Any]]:
        """Score each candidate against the query and return the top_k by relevance."""
        if not candidates:
            logger.warning('No candidates provided to rerank.')
            return []
        try:
            # Encode each (query, candidate text) pair for the cross-encoder.
            inputs = self.tokenizer(
                [query] * len(candidates),
                [c.get('narration', '') for c in candidates],
                padding=True,
                truncation=True,
                return_tensors='pt'
            ).to(self.device)
            with torch.no_grad():
                out = self.model(**inputs)
            logits = out.logits
            if logits.ndim == 2 and logits.shape[1] == 1:
                logits = logits.squeeze(-1)  # only squeeze the (batch, 1) case
            probs = torch.sigmoid(logits).cpu().numpy().flatten()  # flatten guarantees a 1-D array
            paired = [(c, float(probs[idx])) for idx, c in enumerate(candidates)]
            ranked = sorted(paired, key=lambda x: x[1], reverse=True)
            return [c for c, _ in ranked[:top_k]]
        except Exception as e:
            logger.error(f'Reranking failed: {e}')
            return candidates[:top_k]  # degrade gracefully: fall back to retrieval order
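
# A minimal usage sketch for the reranker in isolation, assuming nothing beyond
# what this module defines. The query and candidate dicts are hypothetical;
# only the 'narration' key matters, since that is the field rerank() scores.
def _demo_rerank() -> None:
    reranker = Reranker(RerankerConfig)
    candidates = [
        {'narration': 'The Eiffel Tower is located in Paris, France.'},
        {'narration': 'Gradient descent iteratively minimizes a loss function.'},
    ]
    top = reranker.rerank('Where is the Eiffel Tower?', candidates, top_k=1)
    print(top[0]['narration'])  # expected: the Eiffel Tower snippet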
class AnswerGenerator:
    """
    Main interface: initializes the Retriever and Reranker once, then
    answers multiple questions without reloading models each time.
    """
    def __init__(self, chunks: List[Dict[str, Any]]):
        self.chunks = chunks
        self.retriever = Retriever(chunks, RetrieverConfig)
        self.reranker = Reranker(RerankerConfig)
        # Keep half of the retrieved candidates after re-ranking.
        self.top_k = RetrieverConfig.TOP_K // 2
    def answer(
        self, question: str
    ) -> Tuple[str, List[Dict[str, Any]]]:
        candidates = self.retriever.retrieve(question)
        top_chunks = self.reranker.rerank(question, candidates, self.top_k)
        context = "\n\n".join(f"- {c['narration']}" for c in top_chunks)
        prompt = (
            "You are a knowledgeable assistant. Use the following snippets to answer.\n\n"
            "Context information is below:\n"
            "------------------------------------\n"
            f"{context}\n"
            "------------------------------------\n"
            "Given the context information above, think step by step and answer\n"
            "the query concisely. If you don't have enough information,\n"
            "just say \"I don't know.\"\n"
            f"\nQuestion: {question}\n"
            "Answer:"
        )
        answer = LLMClient.generate(prompt)
        return answer, top_chunks
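
# --- Optional extension sketch: HyDE-style retrieval -------------------------
# The module docstring notes the retriever can be swapped or extended with a
# HyDE retriever. The wrapper below is a hypothetical sketch of that idea, not
# part of the original pipeline: it asks the LLM for a hypothetical answer
# passage, then retrieves with that passage instead of the raw question. It
# assumes LLMClient.generate(prompt) returns a plain string and that
# Retriever.retrieve() accepts arbitrary query text.
class HyDERetriever:
    """Wrap a Retriever so queries go through a hypothetical-answer step."""

    def __init__(self, retriever: Retriever):
        self.retriever = retriever

    def retrieve(self, question: str) -> List[Dict[str, Any]]:
        # Generate a short passage that plausibly answers the question ...
        hypothetical = LLMClient.generate(
            f'Write a short passage that plausibly answers: {question}'
        )
        # ... and use it as the retrieval query; hypothetical answers tend to
        # sit closer to relevant documents in embedding space than questions do.
        return self.retriever.retrieve(hypothetical)


if __name__ == '__main__':
    # Minimal smoke test (sketch). The chunk schema below is hypothetical and
    # provides only the 'narration' field, the one field this module reads;
    # the real pipeline builds chunks upstream from parsed documents.
    sample_chunks = [
        {'narration': 'BM25 ranks documents by term frequency and inverse document frequency.'},
        {'narration': 'Cross-encoders jointly encode query and document for precise relevance scores.'},
        {'narration': 'Dense retrievers embed queries and documents into a shared vector space.'},
    ]
    generator = AnswerGenerator(sample_chunks)
    answer, sources = generator.answer('How does BM25 rank documents?')
    print(answer)
    print(f'Supporting chunks: {len(sources)}')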