""" | |
AnswerGenerator: orchestrates retrieval, re-ranking, and answer generation. | |
This module contains: | |
- Retriever: Hybrid BM25 + dense retrieval over parsed chunks | |
- Reranker: Cross-encoder based re-ranking of candidate chunks | |
- AnswerGenerator: ties together retrieval, re-ranking, and LLM generation | |
Each component is modular and can be swapped or extended (e.g., add HyDE retriever). | |
""" | |
import os
from typing import Any, Dict, List, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from src import sanitize_html
from src.utils import LLMClient, logger
from src.retriever import Retriever, RetrieverConfig

class RerankerConfig:
    # Default to a cross-encoder checkpoint that is compatible with
    # AutoModelForSequenceClassification; the bge-reranker-v2-gemma variant
    # is causal-LM based and does not ship a classification head.
    MODEL_NAME = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-v2-m3')
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

class Reranker:
    """
    Cross-encoder re-ranker using a transformer-based sequence classification model.
    """
    def __init__(self, config: RerankerConfig):
        self.device = config.DEVICE
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
            self.model = AutoModelForSequenceClassification.from_pretrained(config.MODEL_NAME)
            self.model.to(self.device)
            self.model.eval()  # inference only; disables dropout
        except Exception as e:
            logger.error(f'Failed to load reranker model: {e}')
            raise
    def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int) -> List[Dict[str, Any]]:
        """Score each candidate against the query and return the top_k sorted by relevance."""
        if not candidates:
            logger.warning('No candidates provided to rerank.')
            return []
        try:
            # Build (query, passage) pairs for cross-encoder scoring.
            inputs = self.tokenizer(
                [query] * len(candidates),
                [c.get('narration', '') for c in candidates],
                padding=True,
                truncation=True,
                return_tensors='pt'
            ).to(self.device)
            with torch.no_grad():
                logits = self.model(**inputs).logits.squeeze(-1)
            scores = torch.sigmoid(logits).cpu().numpy()
            ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
            return [c for c, _ in ranked[:top_k]]
        except Exception as e:
            logger.error(f'Reranking failed: {e}')
            # Fall back to the retriever's ordering rather than failing the request.
            return candidates[:top_k]
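
# Example usage (standalone re-ranking; `parsed_chunks` is assumed to be a
# list of dicts with a 'narration' field, as produced by the upstream parser):
# reranker = Reranker(RerankerConfig())
# top5 = reranker.rerank("What was the Q2 revenue?", parsed_chunks, top_k=5)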

class AnswerGenerator:
    """
    Main interface: given parsed chunks and a question, returns the answer and its supporting chunks.
    """
    def __init__(self):
        self.ret_config = RetrieverConfig()
        self.rerank_config = RerankerConfig()
        # Load the reranker once; reloading the model on every question would be expensive.
        self.reranker = Reranker(self.rerank_config)
    def answer(self, chunks: List[Dict[str, Any]], question: str) -> Tuple[str, List[Dict[str, Any]]]:
        logger.info(f'Answering question: {question}')
        question = sanitize_html(question)
        try:
            # The retriever is rebuilt per call because its index depends on the chunks.
            retriever = Retriever(chunks, self.ret_config)
            candidates = retriever.retrieve(question)
            top_chunks = self.reranker.rerank(question, candidates, top_k=5)
            context = "\n\n".join(f"- {c.get('narration', '')}" for c in top_chunks)
            prompt = (
                "You are a knowledgeable assistant. "
                "Use the following extracted document snippets to answer the question."
                f"\n\nContext:\n{context}"
                f"\n\nQuestion: {question}\nAnswer:"
            )
            answer = LLMClient.generate(prompt)
            return answer, top_chunks
        except Exception as e:
            logger.error(f'Failed to answer question: {e}')
            return "Failed to generate answer due to error.", []

# Example usage:
# generator = AnswerGenerator()
# ans, ctx = generator.answer(parsed_chunks, "What was the Q2 revenue?")
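#
# Inspect the supporting chunks returned alongside the answer (assumes each
# chunk dict carries a 'narration' field, as used throughout this module):
# for c in ctx:
#     print(c.get('narration', '')[:80])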