# PDF_Recogni/src/qa.py
# Last commit 33f4e34 (Abhinav Gavireddi): "fix: removed redis to store embeddings in memory"
"""
AnswerGenerator: orchestrates retrieval, re-ranking, and answer generation.
This module contains:
- Reranker: cross-encoder based re-ranking of candidate chunks
- AnswerGenerator: ties together retrieval, re-ranking, and LLM generation
Retrieval itself is delegated to Retriever (imported from src.retriever), a hybrid
BM25 + dense retriever over parsed chunks.
Each component is modular and can be swapped or extended (e.g., by adding a HyDE retriever).
"""
import os
import json
import numpy as np
import redis
from typing import List, Dict, Any, Tuple
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from src import sanitize_html
from src.utils import LLMClient, logger
from src.retriever import Retriever, RetrieverConfig
class RerankerConfig:
    """Settings for the cross-encoder re-ranker.

    Both attributes are evaluated once, when the class body runs (module import):

    MODEL_NAME -- Hugging Face model id; overridable via the RERANKER_MODEL
        environment variable.
    DEVICE -- 'cuda' when torch reports an available GPU, otherwise 'cpu'.
    """

    # Reranker loads this checkpoint with AutoModelForSequenceClassification,
    # so the default must ship a trained sequence-classification (cross-encoder)
    # head. The previous default, 'BAAI/bge-reranker-v2-gemma', is an LLM-based
    # reranker intended for causal-LM scoring; loading it as a classifier gives a
    # randomly initialised score head and therefore meaningless rankings.
    MODEL_NAME = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-v2-m3')
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class Reranker:
    """
    Cross-encoder re-ranker using a transformer-based sequence classification model.

    Each (query, candidate narration) pair is scored jointly by the model, and
    candidates are returned in descending score order.
    """

    def __init__(self, config: 'RerankerConfig'):
        """Load the tokenizer and model named by ``config.MODEL_NAME`` onto ``config.DEVICE``.

        Raises:
            Exception: any error from loading; it is logged and re-raised.
        """
        try:
            # Remember the device the model is placed on so rerank() moves its
            # inputs to the same device, even when this config differs from the
            # RerankerConfig class defaults.
            self.device = config.DEVICE
            self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
            self.model = AutoModelForSequenceClassification.from_pretrained(config.MODEL_NAME)
            self.model.to(self.device)
            self.model.eval()  # inference only: make sure dropout and similar layers are off
        except Exception as e:
            logger.error(f'Failed to load reranker model: {e}')
            raise

    def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int) -> List[Dict[str, Any]]:
        """Score each candidate against ``query`` and return the ``top_k`` most relevant.

        Args:
            query: the question text.
            candidates: candidate chunks; the text scored is each chunk's
                ``'narration'`` value ('' when the key is missing).
            top_k: maximum number of candidates to return.

        Returns:
            Up to ``top_k`` candidates sorted by descending relevance score (ties
            keep their input order). Returns ``[]`` when ``candidates`` is empty.
            If scoring raises, the error is logged and the first ``top_k``
            candidates are returned in their original order.
        """
        if not candidates:
            logger.warning('No candidates provided to rerank.')
            return []
        try:
            inputs = self.tokenizer(
                [query] * len(candidates),
                [c.get('narration', '') for c in candidates],
                padding=True,
                truncation=True,
                return_tensors='pt',
            ).to(self.device)
            with torch.no_grad():
                logits = self.model(**inputs).logits.squeeze(-1)
            # Sigmoid maps logits into (0, 1); it is monotonic, so it does not
            # change the ordering. .float() lets half/bfloat16 checkpoints
            # convert to numpy.
            scores = torch.sigmoid(logits).float().cpu().numpy()
            ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
            return [c for c, _ in ranked[:top_k]]
        except Exception as e:
            logger.error(f'Reranking failed: {e}')
            return candidates[:top_k]
class AnswerGenerator:
    """
    Main interface: given parsed chunks and a question, returns answer and supporting chunks.

    A new Retriever is built on every call because it is constructed from that
    call's chunks. The Reranker (and its model weights) is created on first use
    and reused by later calls.
    """

    def __init__(self):
        self.ret_config = RetrieverConfig()
        self.rerank_config = RerankerConfig()
        # Loading the cross-encoder is expensive, so it is created lazily once
        # (see _get_reranker) instead of on every question.
        self._reranker = None

    def _get_reranker(self) -> 'Reranker':
        """Return the cached Reranker, constructing it on first use."""
        if self._reranker is None:
            self._reranker = Reranker(self.rerank_config)
        return self._reranker

    def answer(
        self,
        chunks: List[Dict[str, Any]],
        question: str,
        top_k: int = 5,
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """Answer ``question`` using the most relevant of ``chunks`` as LLM context.

        Args:
            chunks: parsed document chunks; their ``'narration'`` text is what
                gets reranked and placed in the prompt.
            question: user question; passed through ``sanitize_html`` before use
                (the log line before that records the raw value).
            top_k: maximum number of reranked chunks to include as context.

        Returns:
            ``(answer, top_chunks)``. On any exception the error is logged and
            ``("Failed to generate answer due to error.", [])`` is returned.
        """
        logger.info('Answering question', question=question)
        question = sanitize_html(question)
        try:
            retriever = Retriever(chunks, self.ret_config)
            candidates = retriever.retrieve(question)
            top_chunks = self._get_reranker().rerank(question, candidates, top_k=top_k)
            context = "\n\n".join(f"- {c.get('narration', '')}" for c in top_chunks)
            prompt = (
                f"You are a knowledgeable assistant. "
                f"Use the following extracted document snippets to answer the question."
                f"\n\nContext:\n{context}"
                f"\n\nQuestion: {question}\nAnswer:"
            )
            answer = LLMClient.generate(prompt)
            return answer, top_chunks
        except Exception as e:
            logger.error(f'Failed to answer question: {e}')
            return "Failed to generate answer due to error.", []
# Example usage:
# generator = AnswerGenerator()
# ans, ctx = generator.answer(parsed_chunks, "What was the Q2 revenue?")