"""
AnswerGenerator: orchestrates retrieval, re-ranking, and answer generation.
This module contains:
- Retriever: Hybrid BM25 + dense retrieval over parsed chunks
- Reranker: Cross-encoder based re-ranking of candidate chunks
- AnswerGenerator: ties together retrieval, re-ranking, and LLM generation
Each component is modular and can be swapped or extended (e.g., add HyDE retriever).
"""
import os
from typing import List, Dict, Any, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from src.utils import LLMClient, logger
from src.retriever import Retriever, RetrieverConfig


class RerankerConfig:
    # Default to bge-reranker-v2-m3: a cross-encoder that loads with
    # AutoModelForSequenceClassification. (The original default,
    # 'BAAI/bge-reranker-v2-Gemma', mis-cases the repo name, and
    # bge-reranker-v2-gemma is an LLM-based reranker that does not ship a
    # trained sequence-classification head.)
    MODEL_NAME = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-v2-m3')
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class Reranker:
    """
    Cross-encoder re-ranker using a transformer-based sequence classification model.
    """
    def __init__(self, config: RerankerConfig):
        self.device = config.DEVICE
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
            self.model = AutoModelForSequenceClassification.from_pretrained(config.MODEL_NAME)
            self.model.to(self.device)
            self.model.eval()  # inference only: disable dropout
        except Exception as e:
            logger.error(f'Failed to load reranker model: {e}')
            raise

    def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int) -> List[Dict[str, Any]]:
        """Score each candidate against the query and return the top_k most relevant."""
        if not candidates:
            logger.warning('No candidates provided to rerank.')
            return []
        try:
            # Cross-encoder scoring: each (query, chunk) pair is encoded jointly.
            inputs = self.tokenizer(
                [query] * len(candidates),
                [c.get('narration', '') for c in candidates],
                padding=True,
                truncation=True,
                return_tensors='pt'
            ).to(self.device)
            with torch.no_grad():
                logits = self.model(**inputs).logits
            if logits.ndim == 2 and logits.shape[1] == 1:
                logits = logits.squeeze(-1)  # (batch, 1) -> (batch,)
            probs = torch.sigmoid(logits).cpu().numpy().flatten()  # guarantee a 1-D array
            paired = [(c, float(probs[idx])) for idx, c in enumerate(candidates)]
            ranked = sorted(paired, key=lambda x: x[1], reverse=True)
            return [c for c, _ in ranked[:top_k]]
        except Exception as e:
            logger.error(f'Reranking failed: {e}')
            return candidates[:top_k]  # fall back to retrieval order
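
# Illustrative use of Reranker on its own (hypothetical query and chunks;
# 'narration' is the chunk-text key the rest of this module relies on):
#   reranker = Reranker(RerankerConfig)
#   top = reranker.rerank(
#       'What is the refund policy?',
#       [{'narration': 'Refunds are issued within 30 days.'},
#        {'narration': 'Shipping takes 5-7 business days.'}],
#       top_k=1,
#   )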


class AnswerGenerator:
    """
    Main interface: initializes Retriever + Reranker once, then
    answers multiple questions without re-loading models each time.
    """
    def __init__(self, chunks: List[Dict[str, Any]]):
        self.chunks = chunks
        self.retriever = Retriever(chunks, RetrieverConfig)
        self.reranker = Reranker(RerankerConfig)
        # Retrieve broadly, then keep the better half after re-ranking.
        self.top_k = RetrieverConfig.TOP_K // 2

    def answer(self, question: str) -> Tuple[str, List[Dict[str, Any]]]:
        """Retrieve, re-rank, and generate; returns (answer, supporting chunks)."""
        candidates = self.retriever.retrieve(question)
        top_chunks = self.reranker.rerank(question, candidates, self.top_k)
        context = "\n\n".join(f"- {c['narration']}" for c in top_chunks)
        prompt = (
            "You are a knowledgeable assistant. Use the following snippets to answer.\n\n"
            "Context information is below:\n"
            "------------------------------------\n"
            f"{context}\n"
            "------------------------------------\n"
            "Given the context information above, think step by step and answer\n"
            "the query in a crisp manner. In case you don't have enough\n"
            "information, just say \"I don't know!\".\n"
            f"\nQuestion: {question}\n"
            "Answer:"
        )
        answer = LLMClient.generate(prompt)
        return answer, top_chunks
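

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the pipeline). It assumes
# chunks follow the shape the classes above expect: dicts with a 'narration'
# key, as produced upstream. The chunk texts below are made up.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    sample_chunks = [
        {'narration': 'The invoice total for March 2024 was $4,200.'},
        {'narration': 'Payment terms are net 30 from the invoice date.'},
    ]
    generator = AnswerGenerator(sample_chunks)  # loads retriever + reranker once
    answer, sources = generator.answer('What are the payment terms?')
    logger.info(f'Answer: {answer}')
    logger.info(f'Backed by {len(sources)} chunk(s)')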