#!/usr/bin/env python
# coding: utf-8
# # Retrieval-Augmented QA Demo
#
# This notebook builds a minimal RAG (Retrieval-Augmented Generation) pipeline with enhancements:
#
# - Slimmed & deduplicated corpora
# - Chunking long passages
# - Persistent FAISS index & embeddings
# - Similarity threshold to avoid hallucinations
# - Context-length control
# - Polished Gradio interface with separate contexts panel
# ## 1. Configuration & Imports
#
# We detect the device, print settings, and support loading a saved index.
# In[2]:
import os
import pickle
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers import AutoTokenizer as _AutoTokenizer
import gradio as gr
import evaluate
# Settings
data_dir = os.path.join(os.getcwd(), "data")
os.makedirs(data_dir, exist_ok=True)
INDEX_PATH = os.path.join(data_dir, "faiss_index.faiss")
EMB_PATH = os.path.join(data_dir, "embeddings.npy")
PCTX_PATH = os.path.join(data_dir, "passages.pkl")
MODEL_NAME = os.getenv("MODEL_NAME", "google/flan-t5-small")
EMBEDDER_MODEL = os.getenv("EMBEDDER_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
device = 0 if torch.cuda.is_available() else -1
print(f"Using model: {MODEL_NAME}, embedder: {EMBEDDER_MODEL}, device: {'GPU' if device==0 else 'CPU'}")
# Minimum acceptable cosine similarity between the query and the top passage
# (the index below stores normalized vectors, so inner product == cosine similarity)
sim_threshold = 0.3  # tune as needed
# Max words per context snippet
max_context_words = 200
# ## Useful functions
def make_context_snippets(contexts, max_words=200):
    snippets = []
    for c in contexts:
        words = c.split()
        if len(words) > max_words:
            c = " ".join(words[:max_words]) + " ... [truncated]"
        snippets.append(c)
    return snippets
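# Quick illustrative check of the snippet truncation (the text below is made up,
# not from the corpus): only the first `max_words` words survive, plus a marker.
_demo = " ".join(["lorem"] * 300)
print(make_context_snippets([_demo], max_words=10)[0])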
# ## 2. Load, Deduplicate & Chunk Corpora
#
# For this demo we sample small slices and remove duplicates. We also chunk any passage >512 tokens.
#
# tokenizer for chunking
chunk_tokenizer = _AutoTokenizer.from_pretrained(MODEL_NAME)
max_tokens = chunk_tokenizer.model_max_length
def chunk_text(text: str, max_tokens: int, stride: int | None = None) -> list[str]:
    """
    Split `text` into overlapping chunks of up to `max_tokens` words.
    By default uses 25% overlap (stride = max_tokens // 4).
    """
    words = text.split()
    if stride is None:
        stride = max_tokens // 4  # 25% overlap
    chunks = []
    start = 0
    while start < len(words):
        end = start + max_tokens
        chunks.append(" ".join(words[start:end]))
        if end >= len(words):
            break  # last window reached; avoid emitting redundant tail chunks
        # advance by stride, not the full window, so consecutive chunks overlap
        start += stride
    return chunks
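# Minimal sanity check of chunk_text on a made-up word sequence: with a window of
# 4 words and a stride of 2, consecutive chunks overlap by 50%.
_toy = " ".join(str(i) for i in range(10))
for _c in chunk_text(_toy, max_tokens=4, stride=2):
    print(_c)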
# Load corpora
wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
wiki_passages = wiki_ds["passage"]
squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
squad_passages = [ex["context"] for ex in squad_ds]
trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")
trivia_passages = []
for ex in trivia_ds:
    # wiki_context / search_context are nested under entity_pages / search_results
    for txt in list(ex["entity_pages"]["wiki_context"]) + list(ex["search_results"]["search_context"]):
        if txt:
            trivia_passages.append(txt)
# Combine, dedupe, chunk
all_passages = wiki_passages + squad_passages + trivia_passages
unique_passages = list(dict.fromkeys(all_passages))
passages = []
for p in unique_passages:
    # count tokens without encoding to avoid warnings
    tokens = chunk_tokenizer.tokenize(p)
    if len(tokens) > max_tokens:
        passages.extend(chunk_text(p, max_tokens))
    else:
        passages.append(p)
print(f"Total passages after dedupe & chunk: {len(passages)}")
# Persist raw passages list
with open(PCTX_PATH, "wb") as f:
    pickle.dump(passages, f)
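# Optional sanity check (assumes the pickle above was just written): the persisted
# passage list should round-trip with the same length.
with open(PCTX_PATH, "rb") as f:
    assert len(pickle.load(f)) == len(passages)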
# ## 3. Build or Load FAISS Index & Embeddings
#
# We save embeddings & index to disk to skip slow re-encoding.
# -- Initialize embedder and reranker --
embedder = SentenceTransformer(EMBEDDER_MODEL)
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
# -- Load or (re)build FAISS index with cosine similarity --
if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH):
    print("Loading saved index and embeddings...")
    index = faiss.read_index(INDEX_PATH)
    embeddings = np.load(EMB_PATH)
else:
    print("Encoding passages (with overlap)...")
    embeddings = embedder.encode(
        passages,
        show_progress_bar=True,
        convert_to_numpy=True,
        batch_size=32
    )
    # Normalize to unit length so that inner product == cosine similarity
    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Build a FAISS index over inner-product (cosine) space
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)
    # Persist to disk for faster reload
    faiss.write_index(index, INDEX_PATH)
    np.save(EMB_PATH, embeddings)
print(f"Indexed {index.ntotal} vectors.")
# ## 4. Load QA Model & Pipeline
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
qa_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
    early_stopping=True
)
print("QA pipeline ready.")
# ## 5. Retrieval + Generation Functions
#
# We bail out early if the top similarity falls below the threshold, to avoid hallucination.
def retrieve(question: str, k: int = 20, rerank_k: int = 5):
    # 1) dense-search top k (normalize the query so inner product == cosine similarity)
    q_emb = embedder.encode([question], convert_to_numpy=True)
    q_emb = q_emb / np.linalg.norm(q_emb, axis=1, keepdims=True)
    sims, indices = index.search(q_emb, k)
    # 2) pull out those k contexts
    candidates = [passages[i] for i in indices[0]]
    # 3) score with cross-encoder
    pairs = [[question, ctx] for ctx in candidates]
    scores = reranker.predict(pairs)
    # 4) pick top rerank_k
    top_idxs = np.argsort(scores)[-rerank_k:][::-1]
    final_ctxs = [candidates[i] for i in top_idxs]
    final_sims = [sims[0][i] for i in top_idxs]
    return final_ctxs, final_sims
def generate(question: str, contexts: list) -> str:
    """
    Build a RAG prompt from the retrieved contexts and generate
    an answer using the HF text2text pipeline.
    """
    # 1) Turn each context into a truncated snippet
    snippet_lines = [
        f"Context {i+1}: {s}"
        for i, s in enumerate(make_context_snippets(contexts, max_context_words))
    ]
    # 2) Build the full prompt
    prompt = (
        "You are a helpful assistant. Use ONLY the following contexts to answer. "
        "If the answer is not contained, say 'Sorry, I don't know.'\n\n"
        + "\n".join(snippet_lines)
        + f"\n\nQuestion: {question}\nAnswer:"
    )
    # 3) Call the pipeline (it handles tokenization + generation + decoding)
    result = qa_pipeline(prompt, truncation=True, max_new_tokens=200)[0]["generated_text"]
    return result.strip()
def retrieve_and_answer(question, k=20):
    contexts, sims = retrieve(question, k=k)
    if not contexts or sims[0] < sim_threshold:
        return "Sorry, I don't know.", []
    ans = generate(question, contexts)
    return ans, contexts
import random
print("Some sample passages:\n")
for p in random.sample(passages, 5):
    print(p, "\n" + "-"*80 + "\n")
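# End-to-end example call (the question is illustrative; whether it is answered
# depends on what the small sampled corpora actually contain).
_ans, _ctxs = retrieve_and_answer("Who was the first President of the United States?")
print("Answer:", _ans)
print("Contexts used:", len(_ctxs))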
# ## 6. Gradio Demo Interface
#
# Separate panels for answer and contexts.
def answer_and_contexts(question: str):
    """
    Full end-to-end: retrieve, threshold-check, generate answer,
    and return both the answer and a formatted string of contexts.
    """
    answer, contexts = retrieve_and_answer(question)
    # If no valid contexts, just return the apology
    if not contexts:
        return answer, ""
    # Otherwise format each snippet for display
    ctx_snippets = [
        f"Context {i+1}: {s}"
        for i, s in enumerate(make_context_snippets(contexts, max_context_words))
    ]
    return answer, "\n\n---\n\n".join(ctx_snippets)
iface = gr.Interface(
    fn=answer_and_contexts,
    inputs=gr.Textbox(lines=1, placeholder="Enter your question here...", label="Question"),
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Textbox(label="Retrieved Contexts")
    ],
    title="RAG QA Demo",
    description="Retrieval-Augmented QA with similarity threshold and context preview"
)
iface.launch()
# # Test the Model
# load SQuAD v2 (we only need the validation split)
squad = load_dataset("rajpurkar/squad_v2", split="validation")
# load the SQuAD v1 metric; unanswerable questions are mapped to empty-string answers below
squad_metric = evaluate.load("squad")
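# The metric expects aligned prediction/reference records keyed by example id.
# Toy example (values are made up) showing the required shapes:
_toy_preds = [{"id": "q1", "prediction_text": "Paris"}]
_toy_refs = [{"id": "q1", "answers": {"text": ["Paris"], "answer_start": [0]}}]
print(squad_metric.compute(predictions=_toy_preds, references=_toy_refs))  # EM/F1 of 100.0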
def retrieval_recall(dataset, k=20, num_samples=100):
    hits = 0
    for ex in dataset.select(range(num_samples)):
        question = ex["question"]
        gold_answers = ex["answers"]["text"]  # list, empty if unanswerable
        # get your top-k contexts
        ctxs, _ = retrieve(question, k=k, rerank_k=k)  # or rerank_k smaller
        # check if any gold answer appears in any context
        if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers):
            hits += 1
    recall = hits / num_samples
    print(f"Retrieval Recall@{k}: {recall:.3f}")
    return recall
# ## Only answerable Questions
def retrieval_recall_answerable(dataset, k=20, num_samples=100):
    hits = 0
    total = 0
    for ex in dataset.select(range(num_samples)):
        if not ex["answers"]["text"]:
            continue  # skip unanswerable
        total += 1
        ctxs, _ = retrieve(ex["question"], k=k, rerank_k=k)
        if any(any(ans in ctx for ctx in ctxs) for ans in ex["answers"]["text"]):
            hits += 1
    recall = hits / total
    print(f"Retrieval Recall@{k} on answerable only: {recall:.3f} ({hits}/{total})")
    return recall
def qa_eval_all(dataset, num_samples=100, k=20):
    preds, refs = [], []
    for ex in dataset.select(range(num_samples)):
        qid = ex["id"]
        gold = ex["answers"]
        # ensure metric has something to iterate over
        if not gold["text"]:
            gold = {"text": [""], "answer_start": [0]}
        ans, _ = retrieve_and_answer(ex["question"], k=k)
        # for metric purposes, treat our refusal as empty string
        pred_text = "" if ans.strip().lower().startswith("sorry") else ans
        preds.append({"id": qid, "prediction_text": pred_text})
        refs.append({"id": qid, "answers": gold})
    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Full QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results
def qa_eval_answerable(dataset, num_samples=100, k=20):
    preds, refs = [], []
    for ex in dataset.select(range(num_samples)):
        if not ex["answers"]["text"]:
            continue  # skip unanswerable
        qid = ex["id"]
        ans, _ = retrieve_and_answer(ex["question"], k=k)
        preds.append({"id": qid, "prediction_text": ans})
        refs.append({"id": qid, "answers": ex["answers"]})
    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results
retrieval_recall(squad, k=2, num_samples=100)
retrieval_recall_answerable(squad, k=2, num_samples=100)
qa_eval_all(squad, num_samples=100, k=2)
qa_eval_answerable(squad, num_samples=100, k=2)