#!/usr/bin/env python
# coding: utf-8
# # Retrieval-Augmented QA Demo
#
# This notebook builds a minimal RAG (Retrieval-Augmented Generation) pipeline with a few enhancements:
#
# - Slimmed & deduplicated corpora
# - Chunking of long passages
# - Persistent FAISS index & embeddings
# - Similarity threshold to avoid hallucinations
# - Context-length control
# - Polished Gradio interface with a separate contexts panel
# ## 1. Configuration & Imports
#
# We detect the device, print the settings, and support loading a saved index.
# In[2]:
import os
import pickle
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import gradio as gr
import evaluate
# Settings
data_dir = os.path.join(os.getcwd(), "data")
os.makedirs(data_dir, exist_ok=True)
INDEX_PATH = os.path.join(data_dir, "faiss_index.faiss")
EMB_PATH = os.path.join(data_dir, "embeddings.npy")
PCTX_PATH = os.path.join(data_dir, "passages.pkl")
MODEL_NAME = os.getenv("MODEL_NAME", "google/flan-t5-small")
EMBEDDER_MODEL = os.getenv("EMBEDDER_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
device = 0 if torch.cuda.is_available() else -1
print(f"Using model: {MODEL_NAME}, embedder: {EMBEDDER_MODEL}, device: {'GPU' if device == 0 else 'CPU'}")
# Minimum cosine similarity the top retrieved passage must reach; below this we refuse to answer
sim_threshold = 0.3  # tune as needed
# Max words per context snippet
max_context_words = 200
# ## Useful functions
def make_context_snippets(contexts, max_words=200):
    snippets = []
    for c in contexts:
        words = c.split()
        if len(words) > max_words:
            c = " ".join(words[:max_words]) + " ... [truncated]"
        snippets.append(c)
    return snippets
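# Quick sanity check of the snippet helper (illustrative only: the 12-word
# passage and the 5-word limit below are made up for demonstration):
_demo = ["one two three four five six seven eight nine ten eleven twelve"]
print(make_context_snippets(_demo, max_words=5))
# expected: ['one two three four five ... [truncated]']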
# ## 2. Load, Deduplicate & Chunk Corpora
#
# For this demo we sample small slices and remove duplicates. We also chunk any passage longer than the tokenizer's maximum length (512 tokens for flan-t5-small).
#
# tokenizer for chunking
chunk_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
max_tokens = chunk_tokenizer.model_max_length
def chunk_text(text: str, max_tokens: int, stride: int = None) -> list[str]:
    """
    Split `text` into overlapping chunks of up to `max_tokens` words
    (word count is used as a cheap proxy for token count).
    By default the window advances by max_tokens // 4 words per step,
    so consecutive chunks overlap heavily (75% of the window).
    """
    words = text.split()
    if stride is None:
        stride = max_tokens // 4
    chunks = []
    start = 0
    while start < len(words):
        end = start + max_tokens
        chunks.append(" ".join(words[start:end]))
        # advance by stride, not by the full window
        start += stride
    return chunks
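# Tiny illustration of the chunker (hypothetical numbers, not the real corpus):
# with a 10-word window and the default stride of 10 // 4 = 2, a 12-word text
# yields several heavily overlapping chunks.
_text = " ".join(str(i) for i in range(12))
print(len(chunk_text(_text, max_tokens=10)))   # number of overlapping chunks
print(chunk_text(_text, max_tokens=10)[0])     # '0 1 2 3 4 5 6 7 8 9'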
# Load corpora
wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
wiki_passages = wiki_ds["passage"]
squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
squad_passages = [ex["context"] for ex in squad_ds]
trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")
trivia_passages = []
for ex in trivia_ds:
    # In the "rc" config the contexts are nested: entity_pages["wiki_context"]
    # and search_results["search_context"] are lists of strings (possibly empty)
    for txt in list(ex["entity_pages"].get("wiki_context", [])) + list(ex["search_results"].get("search_context", [])):
        if txt:
            trivia_passages.append(txt)
# Combine, dedupe, chunk
all_passages = wiki_passages + squad_passages + trivia_passages
unique_passages = list(dict.fromkeys(all_passages))
passages = []
for p in unique_passages:
    # count tokens without encoding to ids, which avoids length warnings
    tokens = chunk_tokenizer.tokenize(p)
    if len(tokens) > max_tokens:
        passages.extend(chunk_text(p, max_tokens))
    else:
        passages.append(p)
print(f"Total passages after dedupe & chunk: {len(passages)}")
# Persist raw passages list
with open(PCTX_PATH, "wb") as f:
    pickle.dump(passages, f)
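# On a later run the corpus rebuild above could be skipped by reloading the
# pickle instead. Minimal sketch (here it simply reloads what was just written,
# so it is a no-op):
if os.path.exists(PCTX_PATH):
    with open(PCTX_PATH, "rb") as f:
        passages = pickle.load(f)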
# ## 3. Build or Load FAISS Index & Embeddings
#
# We save the embeddings & index to disk to skip slow re-encoding.
# -- Initialize embedder and reranker --
embedder = SentenceTransformer(EMBEDDER_MODEL)
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
# -- Load or (re)build FAISS index with cosine similarity --
if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH):
    print("Loading saved index and embeddings...")
    index = faiss.read_index(INDEX_PATH)
    embeddings = np.load(EMB_PATH)
else:
    print("Encoding passages (with overlap)...")
    embeddings = embedder.encode(
        passages,
        show_progress_bar=True,
        convert_to_numpy=True,
        batch_size=32
    )
    # Normalize to unit length so that inner product = cosine similarity
    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Build a FAISS index over inner-product (cosine) space
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)
    # Persist to disk for faster reload
    faiss.write_index(index, INDEX_PATH)
    np.save(EMB_PATH, embeddings)
print(f"Indexed {index.ntotal} vectors.")
# ## 4. Load QA Model & Pipeline
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
qa_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
    early_stopping=True
)
print("QA pipeline ready.")
# ## 5. Retrieval + Generation Functions
#
# We bail out early if the top similarity falls below the threshold, to avoid hallucination.
def retrieve(question: str, k: int = 20, rerank_k: int = 5):
    # 1) dense-search the top k passages (normalize the query so inner product = cosine)
    q_emb = embedder.encode([question], convert_to_numpy=True, normalize_embeddings=True)
    sims, indices = index.search(q_emb, k)
    # 2) pull out those k contexts
    candidates = [passages[i] for i in indices[0]]
    # 3) score (question, context) pairs with the cross-encoder
    pairs = [[question, ctx] for ctx in candidates]
    scores = reranker.predict(pairs)
    # 4) keep the top rerank_k contexts by reranker score
    top_idxs = np.argsort(scores)[-rerank_k:][::-1]
    final_ctxs = [candidates[i] for i in top_idxs]
    final_sims = [sims[0][i] for i in top_idxs]
    return final_ctxs, final_sims
def generate(question: str, contexts: list) -> str:
    """
    Build a RAG prompt from the retrieved contexts and generate
    an answer using the HF text2text pipeline.
    """
    # 1) Turn each context into a truncated snippet
    snippet_lines = [
        f"Context {i+1}: {s}"
        for i, s in enumerate(make_context_snippets(contexts, max_context_words))
    ]
    # 2) Build the full prompt
    prompt = (
        "You are a helpful assistant. Use ONLY the following contexts to answer. "
        "If the answer is not contained, say 'Sorry, I don't know.'\n\n"
        + "\n".join(snippet_lines)
        + f"\n\nQuestion: {question}\nAnswer:"
    )
    # 3) Call the pipeline (it handles tokenization + generation + decoding)
    result = qa_pipeline(prompt, truncation=True, max_new_tokens=200)[0]["generated_text"]
    return result.strip()
def retrieve_and_answer(question, k=5):
    # dense-retrieve 20 candidates, rerank down to k
    contexts, sims = retrieve(question, k=20, rerank_k=k)
    if not contexts or sims[0] < sim_threshold:
        return "Sorry, I don't know.", []
    ans = generate(question, contexts)
    return ans, contexts
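# Example end-to-end call (the question is arbitrary; the answer depends on
# whatever the sampled corpus happens to contain):
ans, ctxs = retrieve_and_answer("Who was the first President of the United States?")
print(ans)
print(f"{len(ctxs)} supporting contexts returned")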
import random
print("Some sample passages:\n")
for p in random.sample(passages, 5):
    print(p, "\n" + "-" * 80 + "\n")
# ## 6. Gradio Demo Interface
#
# Separate panels for answer and contexts.
def answer_and_contexts(question: str):
    """
    Full end-to-end: retrieve, threshold-check, generate answer,
    and return both the answer and a formatted string of contexts.
    """
    answer, contexts = retrieve_and_answer(question)
    # If no valid contexts, just return the apology
    if not contexts:
        return answer, ""
    # Otherwise format each snippet for display
    ctx_snippets = [
        f"Context {i+1}: {s}"
        for i, s in enumerate(make_context_snippets(contexts, max_context_words))
    ]
    return answer, "\n\n---\n\n".join(ctx_snippets)
iface = gr.Interface(
    fn=answer_and_contexts,
    inputs=gr.Textbox(lines=1, placeholder="Enter your question here...", label="Question"),
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Textbox(label="Retrieved Contexts")
    ],
    title="RAG QA Demo",
    description="Retrieval-Augmented QA with similarity threshold and context preview"
)
iface.launch()
# # Test the Model
# load SQuAD v2 (we only need the validation split)
squad = load_dataset("rajpurkar/squad_v2", split="validation")
# load the SQuAD metric; note that the plain "squad" metric does not model
# no-answer questions, so unanswerable examples are scored against an empty
# reference string below
squad_metric = evaluate.load("squad")
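# For reference, the "squad" metric expects ids plus prediction/reference text
# in this shape (toy example, not taken from the dataset):
_example_preds = [{"id": "abc", "prediction_text": "George Washington"}]
_example_refs = [{"id": "abc", "answers": {"text": ["George Washington"], "answer_start": [0]}}]
print(squad_metric.compute(predictions=_example_preds, references=_example_refs))
# -> {'exact_match': 100.0, 'f1': 100.0}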
def retrieval_recall(dataset, k=20, num_samples=100):
    hits = 0
    for ex in dataset.select(range(num_samples)):
        question = ex["question"]
        gold_answers = ex["answers"]["text"]  # list; empty if unanswerable
        # get the top-k contexts
        ctxs, _ = retrieve(question, k=k, rerank_k=k)  # or a smaller rerank_k
        # check if any gold answer appears in any context
        if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers):
            hits += 1
    recall = hits / num_samples
    print(f"Retrieval Recall@{k}: {recall:.3f}")
    return recall
# ## Only answerable questions
def retrieval_recall_answerable(dataset, k=20, num_samples=100):
    hits = 0
    total = 0
    for ex in dataset.select(range(num_samples)):
        if not ex["answers"]["text"]:
            continue  # skip unanswerable
        total += 1
        ctxs, _ = retrieve(ex["question"], k=k, rerank_k=k)
        if any(any(ans in ctx for ctx in ctxs) for ans in ex["answers"]["text"]):
            hits += 1
    recall = hits / total
    print(f"Retrieval Recall@{k} on answerable only: {recall:.3f} ({hits}/{total})")
    return recall
def qa_eval_all(dataset, num_samples=100, k=20):
    preds, refs = [], []
    for ex in dataset.select(range(num_samples)):
        qid = ex["id"]
        gold = ex["answers"]
        # ensure the metric has something to iterate over
        if not gold["text"]:
            gold = {"text": [""], "answer_start": [0]}
        ans, _ = retrieve_and_answer(ex["question"], k=k)
        # for metric purposes, treat our refusal as an empty string
        pred_text = "" if ans.strip().lower().startswith("sorry") else ans
        preds.append({"id": qid, "prediction_text": pred_text})
        refs.append({"id": qid, "answers": gold})
    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Full QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results
def qa_eval_answerable(dataset, num_samples=100, k=20):
    preds, refs = [], []
    for ex in dataset.select(range(num_samples)):
        if not ex["answers"]["text"]:
            continue  # skip unanswerable
        qid = ex["id"]
        ans, _ = retrieve_and_answer(ex["question"], k=k)
        preds.append({"id": qid, "prediction_text": ans})
        refs.append({"id": qid, "answers": ex["answers"]})
    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results
retrieval_recall(squad, k=2, num_samples=100)
retrieval_recall_answerable(squad, k=2, num_samples=100)
qa_eval_all(squad, num_samples=100, k=2)
qa_eval_answerable(squad, num_samples=100, k=2)