VictorTomas09 committed on
Commit 74d4f11 · verified · 1 Parent(s): 2e8a823

Add app.py for RAG QA demo


Converted the Jupyter notebook into a standalone Python script.
- Defines the embedder, FAISS index loading/creation, retrieval & generation functions.
- Builds the Gradio interface and launches it.
- Ready for deployment on Hugging Face Spaces.

Files changed (1)
  1. app.py +370 -0
app.py ADDED
#!/usr/bin/env python
# coding: utf-8

# # Retrieval-Augmented QA Demo
#
# This script builds a minimal RAG (Retrieval-Augmented Generation) pipeline with enhancements:
#
# - Slimmed & deduplicated corpora
# - Chunking long passages
# - Persistent FAISS index & embeddings
# - Similarity threshold on retrieval to avoid hallucinations
# - Context-length control
# - Polished Gradio interface with a separate contexts panel

# ## 1. Configuration & Imports
#
# We detect the device, print the settings, and support loading a saved index.

import os
import pickle
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers import AutoTokenizer as _AutoTokenizer
import gradio as gr
import evaluate


# Settings
data_dir = os.path.join(os.getcwd(), "data")
os.makedirs(data_dir, exist_ok=True)
INDEX_PATH = os.path.join(data_dir, "faiss_index.faiss")
EMB_PATH = os.path.join(data_dir, "embeddings.npy")
PCTX_PATH = os.path.join(data_dir, "passages.pkl")

MODEL_NAME = os.getenv("MODEL_NAME", "google/flan-t5-small")
EMBEDDER_MODEL = os.getenv("EMBEDDER_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
device = 0 if torch.cuda.is_available() else -1
print(f"Using model: {MODEL_NAME}, embedder: {EMBEDDER_MODEL}, device: {'GPU' if device == 0 else 'CPU'}")

# Minimum cosine similarity of the top retrieved passage required to attempt an answer
sim_threshold = 0.3  # tune as needed
# Max words per context snippet
max_context_words = 200

# ## Useful functions

def make_context_snippets(contexts, max_words=200):
    """Truncate each context to at most `max_words` words for display and prompting."""
    snippets = []
    for c in contexts:
        words = c.split()
        if len(words) > max_words:
            c = " ".join(words[:max_words]) + " ... [truncated]"
        snippets.append(c)
    return snippets

# ## 2. Load, Deduplicate & Chunk Corpora
#
# For this demo we sample small slices and remove duplicates. We also chunk any passage >512 tokens.

# tokenizer for chunking
chunk_tokenizer = _AutoTokenizer.from_pretrained(MODEL_NAME)
max_tokens = chunk_tokenizer.model_max_length

def chunk_text(text: str, max_tokens: int, stride: int | None = None) -> list[str]:
    """
    Split `text` into overlapping chunks of up to `max_tokens` words.
    By default uses 25% overlap (stride = max_tokens // 4).
    """
    words = text.split()
    if stride is None:
        stride = max_tokens // 4  # 25% overlap
    chunks = []
    start = 0
    while start < len(words):
        end = start + max_tokens
        chunk = " ".join(words[start:end])
        chunks.append(chunk)
        # advance by stride, not the full window, so consecutive chunks overlap
        start += stride
    return chunks

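# Illustrative sanity check (an addition for clarity, not part of the original notebook):
# verify that chunk_text produces bounded, overlapping windows. The synthetic 1000-word
# passage and the 200-word window are arbitrary example values.
_demo_words = [f"w{i}" for i in range(1000)]
_demo_chunks = chunk_text(" ".join(_demo_words), max_tokens=200)
assert all(len(c.split()) <= 200 for c in _demo_chunks)
print(f"chunk_text sanity check: {len(_demo_chunks)} chunks from a 1000-word passage")
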
# Load corpora
wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
wiki_passages = wiki_ds["passage"]

squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
squad_passages = [ex["context"] for ex in squad_ds]

trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")
trivia_passages = []
for ex in trivia_ds:
    for field in ("wiki_context", "search_context"):
        txt = ex.get(field) or ""
        if txt:
            trivia_passages.append(txt)

# Combine, dedupe, chunk
all_passages = wiki_passages + squad_passages + trivia_passages
unique_passages = list(dict.fromkeys(all_passages))
passages = []
for p in unique_passages:
    # count tokens without encoding to avoid warnings
    tokens = chunk_tokenizer.tokenize(p)
    if len(tokens) > max_tokens:
        passages.extend(chunk_text(p, max_tokens))
    else:
        passages.append(p)
print(f"Total passages after dedupe & chunk: {len(passages)}")

# Persist the raw passages list
with open(PCTX_PATH, "wb") as f:
    pickle.dump(passages, f)

# ## 3. Build or Load FAISS Index & Embeddings
#
# We save the embeddings & index to disk to skip slow re-encoding on later runs.

# ── Initialize embedder and reranker ──
embedder = SentenceTransformer(EMBEDDER_MODEL)
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# ── Load or (re)build FAISS index with cosine similarity ──
if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH):
    print("Loading saved index and embeddings…")
    index = faiss.read_index(INDEX_PATH)
    embeddings = np.load(EMB_PATH)
else:
    print("Encoding passages (with overlap)…")
    embeddings = embedder.encode(
        passages,
        show_progress_bar=True,
        convert_to_numpy=True,
        batch_size=32
    )
    # Normalize to unit length so that inner product = cosine similarity
    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

    # Build a FAISS index over inner-product (cosine) space
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)

    # Persist to disk for faster reload
    faiss.write_index(index, INDEX_PATH)
    np.save(EMB_PATH, embeddings)
print(f"Indexed {index.ntotal} vectors.")

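# Illustrative sanity check (an addition, not from the original notebook): a passage
# re-encoded and normalized with the same model should retrieve itself from the index
# with a cosine score close to 1.0.
_probe = embedder.encode([passages[0]], convert_to_numpy=True)
_probe = _probe / np.linalg.norm(_probe, axis=1, keepdims=True)
_scores, _ids = index.search(_probe, 1)
print(f"Self-retrieval check: id={_ids[0][0]}, cosine score={_scores[0][0]:.3f}")
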
# ## 4. Load QA Model & Pipeline

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
qa_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
    early_stopping=True
)
print("QA pipeline ready.")

# ## 5. Retrieval + Generation Functions
#
# We bail out early if the best retrieval similarity falls below the threshold, to avoid hallucination.

def retrieve(question: str, k: int = 20, rerank_k: int = 5):
    # 1) dense-search the top k (normalize the query so inner product = cosine similarity)
    q_emb = embedder.encode([question], convert_to_numpy=True)
    q_emb = q_emb / np.linalg.norm(q_emb, axis=1, keepdims=True)
    scores, indices = index.search(q_emb, k)

    # 2) pull out those k contexts
    candidates = [passages[i] for i in indices[0]]

    # 3) score (question, context) pairs with the cross-encoder
    pairs = [[question, ctx] for ctx in candidates]
    ce_scores = reranker.predict(pairs)

    # 4) keep the top rerank_k contexts, ordered by cross-encoder score
    top_idxs = np.argsort(ce_scores)[-rerank_k:][::-1]
    final_ctxs = [candidates[i] for i in top_idxs]
    final_scores = [scores[0][i] for i in top_idxs]

    return final_ctxs, final_scores

def generate(question: str, contexts: list) -> str:
    """
    Build a RAG prompt from the retrieved contexts and generate
    an answer using the HF text2text pipeline.
    """
    # 1) Turn each context into a truncated snippet
    snippet_lines = [
        f"Context {i+1}: {s}"
        for i, s in enumerate(make_context_snippets(contexts, max_context_words))
    ]

    # 2) Build the full prompt
    prompt = (
        "You are a helpful assistant. Use ONLY the following contexts to answer. "
        "If the answer is not contained, say 'Sorry, I don't know.'\n\n"
        + "\n".join(snippet_lines)
        + f"\n\nQuestion: {question}\nAnswer:"
    )

    # 3) Call the pipeline (it handles tokenization, generation and decoding)
    result = qa_pipeline(prompt, truncation=True, max_new_tokens=200)[0]["generated_text"]
    return result.strip()

def retrieve_and_answer(question, k=20):
    contexts, scores = retrieve(question, k=k)
    if not contexts or scores[0] < sim_threshold:
        return "Sorry, I don't know.", []

    ans = generate(question, contexts)
    return ans, contexts

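# Illustrative usage (an addition for clarity, not in the original notebook): the question
# below is an arbitrary example; questions whose best retrieval similarity falls under
# sim_threshold return the "Sorry, I don't know." refusal instead of a hallucinated answer.
_demo_answer, _demo_ctxs = retrieve_and_answer("Who was Abraham Lincoln?")
print(f"Demo answer: {_demo_answer!r} (retrieved {len(_demo_ctxs)} contexts)")
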
import random

print("Some sample passages:\n")
for p in random.sample(passages, 5):
    print(p, "\n" + "-" * 80 + "\n")

# ## 6. Gradio Demo Interface
#
# Separate panels for the answer and the retrieved contexts.

def answer_and_contexts(question: str):
    """
    Full end-to-end: retrieve, threshold-check, generate the answer,
    and return both the answer and a formatted string of contexts.
    """
    answer, contexts = retrieve_and_answer(question)

    # If no valid contexts, just return the apology
    if not contexts:
        return answer, ""

    # Otherwise format each snippet for display
    ctx_snippets = [
        f"Context {i+1}: {s}"
        for i, s in enumerate(make_context_snippets(contexts, max_context_words))
    ]

    return answer, "\n\n---\n\n".join(ctx_snippets)


iface = gr.Interface(
    fn=answer_and_contexts,
    inputs=gr.Textbox(lines=1, placeholder="Enter your question here...", label="Question"),
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Textbox(label="Retrieved Contexts")
    ],
    title="🔍 RAG QA Demo",
    description="Retrieval-Augmented QA with a similarity threshold and context preview"
)

iface.launch()

# # Test the Model

# load SQuAD v2 (we only need the validation split)
squad = load_dataset("rajpurkar/squad_v2", split="validation")

# load the SQuAD metric; unanswerable questions are mapped to empty-string references below
squad_metric = evaluate.load("squad")

def retrieval_recall(dataset, k=20, num_samples=100):
    hits = 0
    for ex in dataset.select(range(num_samples)):
        question = ex["question"]
        gold_answers = ex["answers"]["text"]  # list, empty if unanswerable

        # get the top-k contexts
        ctxs, _ = retrieve(question, k=k, rerank_k=k)  # or use a smaller rerank_k
        # check if any gold answer appears in any context
        if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers):
            hits += 1

    recall = hits / num_samples
    print(f"Retrieval Recall@{k}: {recall:.3f}")
    return recall

# ## Only answerable Questions

def retrieval_recall_answerable(dataset, k=20, num_samples=100):
    hits = 0
    total = 0
    for ex in dataset.select(range(num_samples)):
        if not ex["answers"]["text"]:
            continue  # skip unanswerable
        total += 1
        ctxs, _ = retrieve(ex["question"], k=k, rerank_k=k)
        if any(any(ans in ctx for ctx in ctxs) for ans in ex["answers"]["text"]):
            hits += 1
    recall = hits / total
    print(f"Retrieval Recall@{k} on answerable only: {recall:.3f} ({hits}/{total})")
    return recall

def qa_eval_all(dataset, num_samples=100, k=20):
    preds, refs = [], []
    for ex in dataset.select(range(num_samples)):
        qid = ex["id"]
        gold = ex["answers"]
        # ensure the metric has something to iterate over
        if not gold["text"]:
            gold = {"text": [""], "answer_start": [0]}

        ans, _ = retrieve_and_answer(ex["question"], k=k)
        # for metric purposes, treat our refusal as an empty string
        pred_text = "" if ans.strip().lower().startswith("sorry") else ans

        preds.append({"id": qid, "prediction_text": pred_text})
        refs.append({"id": qid, "answers": gold})

    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Full QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results

def qa_eval_answerable(dataset, num_samples=100, k=20):
    preds, refs = [], []
    for ex in dataset.select(range(num_samples)):
        if not ex["answers"]["text"]:
            continue  # skip unanswerable
        qid = ex["id"]
        ans, _ = retrieve_and_answer(ex["question"], k=k)

        preds.append({"id": qid, "prediction_text": ans})
        refs.append({"id": qid, "answers": ex["answers"]})

    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results

retrieval_recall(squad, k=2, num_samples=100)
retrieval_recall_answerable(squad, k=2, num_samples=100)
qa_eval_all(squad, num_samples=100, k=2)
qa_eval_answerable(squad, num_samples=100, k=2)