#!/usr/bin/env python
# coding: utf-8
import os
import pickle
import faiss
import numpy as np
import torch
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline as hf_pipeline,
)
# ── 1. Configuration ──
DATA_DIR = os.path.join(os.getcwd(), "data")
INDEX_PATH = os.path.join(DATA_DIR, "faiss_index.faiss")
EMB_PATH = os.path.join(DATA_DIR, "embeddings.npy")
PCTX_PATH = os.path.join(DATA_DIR, "passages.pkl")
MODEL_NAME = os.getenv("MODEL_NAME", "google/flan-t5-small")
EMBEDDER_MODEL = os.getenv("EMBEDDER_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
# The index stores inner products of unit vectors, i.e. cosine similarities
# in [-1, 1], so this acts as a similarity floor, not an L2 distance cap.
# 0.3 is an assumed default; tune it for your corpus.
DIST_THRESHOLD = float(os.getenv("DIST_THRESHOLD", 0.3))
MAX_CTX_WORDS = int(os.getenv("MAX_CTX_WORDS", 200))
DEVICE = 0 if torch.cuda.is_available() else -1
os.makedirs(DATA_DIR, exist_ok=True)
print(f"Using MODEL_NAME={MODEL_NAME}, EMBEDDER_MODEL={EMBEDDER_MODEL}, device={'GPU' if DEVICE == 0 else 'CPU'}")
# ── 2. Helpers ──
def make_context_snippets(contexts, max_words=MAX_CTX_WORDS):
    out = []
    for c in contexts:
        words = c.split()
        if len(words) > max_words:
            c = " ".join(words[:max_words]) + " ... [truncated]"
        out.append(c)
    return out
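# Example (hypothetical input): with a 5-word cap the snippet is cut and
# flagged:
#   make_context_snippets(["one two three four five six"], max_words=5)
#   -> ["one two three four five ... [truncated]"]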
def chunk_text(text, max_tokens, stride=None):
    words = text.split()
    if stride is None:
        stride = max_tokens // 4
    chunks, start = [], 0
    while start < len(words):
        end = start + max_tokens
        chunks.append(" ".join(words[start:end]))
        start += stride
    return chunks
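# Example (hypothetical input): with max_tokens=4 the stride defaults to
# 4 // 4 = 1, so consecutive chunks overlap heavily:
#   chunk_text("a b c d e f", max_tokens=4)
#   -> ["a b c d", "b c d e", "c d e f", "d e f", "e f", "f"]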
# ── 3. Load & preprocess passages ──
def load_passages():
    # 3.1 load raw corpora
    wiki = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")["passage"]
    squad = load_dataset("rajpurkar/squad_v2", split="train[:100]")["context"]
    trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")
    trivia = []
    for ex in trivia_ds:
        # In the "rc" config the texts are nested: entity_pages.wiki_context
        # and search_results.search_context are lists of strings, so there is
        # no flat "wiki_context" field on the example itself.
        for txt in ex["entity_pages"].get("wiki_context", []):
            if txt:
                trivia.append(txt)
        for txt in ex["search_results"].get("search_context", []):
            if txt:
                trivia.append(txt)
    # dict.fromkeys deduplicates while preserving order
    all_passages = list(dict.fromkeys(wiki + squad + trivia))
    # 3.2 chunk long passages
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    max_tokens = tokenizer.model_max_length
    chunks = []
    for p in all_passages:
        toks = tokenizer.tokenize(p)
        if len(toks) > max_tokens:
            # chunk_text splits on words, not tokens, so this only
            # approximates the model's token limit; the generation pipeline
            # truncates anything that still overflows.
            chunks.extend(chunk_text(p, max_tokens))
        else:
            chunks.append(p)
    print(f"[load_passages] total chunks: {len(chunks)}")
    with open(PCTX_PATH, "wb") as f:
        pickle.dump(chunks, f)
    return chunks
# ── 4. Build or load FAISS ──
def load_faiss_index(passages):
    # sentence-transformers embedder + cross-encoder reranker
    embedder = SentenceTransformer(EMBEDDER_MODEL)
    reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH):
        print("Loading FAISS index & embeddings from disk …")
        index = faiss.read_index(INDEX_PATH)
        embeddings = np.load(EMB_PATH)
    else:
        print("Encoding passages & building FAISS index …")
        embeddings = embedder.encode(passages, show_progress_bar=True, convert_to_numpy=True, batch_size=32)
        # L2-normalize so that inner product == cosine similarity
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)
        faiss.write_index(index, INDEX_PATH)
        np.save(EMB_PATH, embeddings)
    return embedder, reranker, index
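# Note: the cache is keyed only on file paths, so artifacts in data/ go stale
# if EMBEDDER_MODEL or the corpus changes. A minimal guard (an assumption,
# not part of the original design) is to compare sizes after loading:
#   assert index.ntotal == len(passages), "stale cache: delete data/ and rerun"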
# ── 5. Set up RAG pipeline ──
def setup_rag():
    # 5.1 load or build index + embedder/reranker
    if os.path.exists(PCTX_PATH):
        with open(PCTX_PATH, "rb") as f:
            passages = pickle.load(f)
    else:
        passages = load_passages()
    embedder, reranker, index = load_faiss_index(passages)
    # 5.2 load generator model & HF pipeline
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    qa_pipe = hf_pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tok,
        device=DEVICE,
        truncation=True,
        max_length=512,
        num_beams=4,  # optional: enable beam search
        early_stopping=True,
    )
    return passages, embedder, reranker, index, qa_pipe
# ── 6. Retrieval + Generation ──
def retrieve(question, passages, embedder, reranker, index, k=20, rerank_k=5):
    q_emb = embedder.encode([question], convert_to_numpy=True)
    # normalize the query too, so the inner-product scores are cosines
    q_emb = q_emb / np.linalg.norm(q_emb, axis=1, keepdims=True)
    sims, idxs = index.search(q_emb, k)
    cands = [passages[i] for i in idxs[0]]
    # rescore the top-k candidates with the cross-encoder
    scores = reranker.predict([[question, c] for c in cands])
    top = np.argsort(scores)[-rerank_k:][::-1]
    final_ctxs = [cands[i] for i in top]
    final_sims = [sims[0][i] for i in top]
    return final_ctxs, final_sims
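# Example round trip (hypothetical question; illustrates shapes only):
#   ctxs, sims = retrieve("Who wrote '1984'?", passages, embedder, reranker, index)
#   len(ctxs) == 5, and sims holds the matching cosine scores; both are
#   ordered by the cross-encoder, not by raw FAISS score.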
def generate(question, contexts, qa_pipe):
    lines = [f"Context {i + 1}: {s}"
             for i, s in enumerate(make_context_snippets(contexts))]
    prompt = (
        "You are a helpful assistant. Use ONLY the following contexts to answer. "
        "If the answer is not contained, say 'Sorry, I don't know.'\n\n"
        + "\n".join(lines)
        + f"\n\nQuestion: {question}\nAnswer:"
    )
    return qa_pipe(prompt)[0]["generated_text"].strip()
def retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe):
    ctxs, sims = retrieve(question, passages, embedder, reranker, index)
    # abstain when even the best retrieval falls below the similarity floor
    if not ctxs or max(sims) < DIST_THRESHOLD:
        return "Sorry, I don't know.", []
    ans = generate(question, ctxs, qa_pipe)
    return ans, ctxs
def answer_and_contexts(question, passages, embedder, reranker, index, qa_pipe):
    ans, ctxs = retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe)
    if not ctxs:
        return ans, ""
    snippets = [
        f"Context {i + 1}: {s}"
        for i, s in enumerate(make_context_snippets(ctxs))
    ]
    return ans, "\n\n---\n\n".join(snippets)
# ── 7. Gradio app ──
def main():
    passages, embedder, reranker, index, qa_pipe = setup_rag()
    demo = gr.Interface(
        fn=lambda q: answer_and_contexts(q, passages, embedder, reranker, index, qa_pipe),
        inputs=gr.Textbox(lines=1, placeholder="Ask me anything…", label="Question"),
        outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Contexts")],
        title="🔍 RAG QA Demo",
        description="Retrieval-Augmented QA with threshold and context preview",
        examples=[
            "When was Abraham Lincoln inaugurated?",
            "What is the capital of France?",
            "Who wrote '1984'?"
        ],
    )
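    # To bind on all interfaces (e.g. inside a container), one option is
    # demo.launch(server_name="0.0.0.0", server_port=7860); the bare call
    # below uses Gradio's defaults.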
    demo.launch()
if __name__ == "__main__":
    main()