VictorTomas09 committed · Commit e40d8f8 · verified · 1 Parent(s): d12dbf4

Update app.py

Files changed (1):
  1. app.py +147 -313
app.py CHANGED
@@ -1,365 +1,199 @@
-
- # # Retrieval-Augmented QA Demo
- #
- # This notebook builds a minimal RAG (Retrieval-Augmented Generation) pipeline with enhancements:
- #
- # - Slimmed & deduplicated corpora
- # - Chunking long passages
- # - Persistent FAISS index & embeddings
- # - Distance threshold to avoid hallucinations
- # - Context-length control
- # - Polished Gradio interface with separate contexts panel
-
- # ## 1. Configuration & Imports
- #
- # We detect device, print settings, and support loading saved index.

  import os
  import pickle
- from datasets import load_dataset
- from sentence_transformers import SentenceTransformer, CrossEncoder
  import faiss
  import numpy as np
  import torch
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
- from transformers import AutoTokenizer as _AutoTokenizer
  import gradio as gr
- import evaluate

- # Settings
- data_dir = os.path.join(os.getcwd(), "data")
- os.makedirs(data_dir, exist_ok=True)
- INDEX_PATH = os.path.join(data_dir, "faiss_index.faiss")
- EMB_PATH = os.path.join(data_dir, "embeddings.npy")
- PCTX_PATH = os.path.join(data_dir, "passages.pkl")

  MODEL_NAME = os.getenv("MODEL_NAME", "google/flan-t5-small")
  EMBEDDER_MODEL = os.getenv("EMBEDDER_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
- device = 0 if torch.cuda.is_available() else -1
- print(f"Using model: {MODEL_NAME}, embedder: {EMBEDDER_MODEL}, device: {'GPU' if device==0 else 'CPU'}")

- # Threshold for maximum acceptable L2 distance
- dist_threshold = 1.0  # tune as needed
- # Max words per context snippet
- max_context_words = 200

- # ## Useful functions

- def make_context_snippets(contexts, max_words=200):
-     snippets = []
      for c in contexts:
          words = c.split()
          if len(words) > max_words:
              c = " ".join(words[:max_words]) + " ... [truncated]"
-         snippets.append(c)
-     return snippets
-
-
- # ## 2. Load, Deduplicate & Chunk Corpora
- #
- # For this demo we sample small slices and remove duplicates. We also chunk any passage >512 tokens.
- #
-

- # tokenizer for chunking
- chunk_tokenizer = _AutoTokenizer.from_pretrained(MODEL_NAME)
- max_tokens = chunk_tokenizer.model_max_length
-
- def chunk_text(text: str, max_tokens: int, stride: int = None) -> list[str]:
-     """
-     Split `text` into overlapping chunks of up to max_tokens words.
-     By default uses 25% overlap (stride = max_tokens // 4).
-     """
      words = text.split()
      if stride is None:
-         stride = max_tokens // 4  # 25% overlap
-     chunks = []
-     start = 0
      while start < len(words):
          end = start + max_tokens
-         chunk = " ".join(words[start:end])
-         chunks.append(chunk)
-         # advance by stride, not full window
          start += stride
      return chunks


- # Load corpora
- wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
- wiki_passages = wiki_ds["passage"]
-
- squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
- squad_passages = [ex["context"] for ex in squad_ds]

- trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")
- trivia_passages = []
- for ex in trivia_ds:
-     for field in ("wiki_context", "search_context"):
-         txt = ex.get(field) or ""
-         if txt:
-             trivia_passages.append(txt)

- # Combine, dedupe, chunk
- all_passages = wiki_passages + squad_passages + trivia_passages
- unique_passages = list(dict.fromkeys(all_passages))
- passages = []
- for p in unique_passages:
-     # count tokens without encoding to avoid warnings
-     tokens = chunk_tokenizer.tokenize(p)
-     if len(tokens) > max_tokens:
-         passages.extend(chunk_text(p, max_tokens))
      else:
-         passages.append(p)
- print(f"Total passages after dedupe & chunk: {len(passages)}")
-
- # Persist raw passages list
- with open(PCTX_PATH, "wb") as f:
-     pickle.dump(passages, f)
-
-
- # ## 3. Build or Load FAISS Index & Embeddings
- #
- # We save embeddings & index to disk to skip slow re-encoding.
-

- # ── Initialize embedder and reranker ──
- from sentence_transformers import SentenceTransformer
- from torch import no_grad

- embedder = SentenceTransformer(EMBEDDER_MODEL)
- reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

- # ── Load or (re)build FAISS index with cosine similarity ──
- if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH):
-     print("Loading saved index and embeddings…")
-     index = faiss.read_index(INDEX_PATH)
-     embeddings = np.load(EMB_PATH)
- else:
-     print("Encoding passages (with overlap)…")
-     embeddings = embedder.encode(
-         passages,
-         show_progress_bar=True,
-         convert_to_numpy=True,
-         batch_size=32
      )
-     # Normalize to unit length so that inner-product = cosine sim
-     embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
-
-     # Build a FAISS index over inner-product (cosine) space
-     dim = embeddings.shape[1]
-     index = faiss.IndexFlatIP(dim)
-     index.add(embeddings)
-
-     # Persist to disk for faster reload
-     faiss.write_index(index, INDEX_PATH)
-     np.save(EMB_PATH, embeddings)
- print(f"Indexed {index.ntotal} vectors.")
-
-
- # ## 4. Load QA Model & Pipeline
-
-
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
- qa_pipeline = pipeline(
-     "text2text-generation",
-     model=model,
-     tokenizer=tokenizer,
-     device=device,
-     early_stopping=True
- )
- print("QA pipeline ready.")
-

- # ## 5. Retrieval + Generation Functions
- #
- # We bail out early if top distance > threshold to avoid hallucination.

-
- def retrieve(question: str, k: int = 20, rerank_k: int = 5):
-     # 1) dense-search top k
      q_emb = embedder.encode([question], convert_to_numpy=True)
-     distances, indices = index.search(q_emb, k)
-
-     # 2) pull out those k contexts
-     candidates = [passages[i] for i in indices[0]]
-
-     # 3) score with cross-encoder
-     pairs = [[question, ctx] for ctx in candidates]
-     scores = reranker.predict(pairs)
-
-     # 4) pick top rerank_k
-     top_idxs = np.argsort(scores)[-rerank_k:][::-1]
-     final_ctxs = [candidates[i] for i in top_idxs]
-     final_dist = [distances[0][i] for i in top_idxs]
-
-     return final_ctxs, final_dist


- def generate(question: str, contexts: list) -> str:
-     """
-     Build a RAG prompt from the retrieved contexts and generate
-     an answer using the HF text2text pipeline.
-     """
-     # 1) Turn each context into a truncated snippet
-     snippet_lines = [
-         f"Context {i+1}: {s}"
-         for i, s in enumerate(make_context_snippets(contexts, max_context_words))
-     ]
-
-     # 2) Build the full prompt
      prompt = (
          "You are a helpful assistant. Use ONLY the following contexts to answer. "
          "If the answer is not contained, say 'Sorry, I don't know.'\n\n"
-         + "\n".join(snippet_lines)
          + f"\n\nQuestion: {question}\nAnswer:"
      )

-     # 3) Call the pipeline (it handles tokenization + generation + decoding)
-     result = qa_pipeline(prompt, truncation=True, max_new_tokens=200)[0]["generated_text"]
-     return result.strip()
-
-
- def retrieve_and_answer(question, k=5):
-     contexts, distances = retrieve(question, k=20)
-     if not contexts or distances[0] > dist_threshold:
          return "Sorry, I don't know.", []
-
-     ans = generate(question, contexts)
-     return ans, contexts
-
-
- import random
-
- print("Some sample passages:\n")
- for p in random.sample(passages, 5):
-     print(p, "\n" + "-"*80 + "\n")
-
-
- # ## 6. Gradio Demo Interface
- #
- # Separate panels for answer and contexts.
-
- def answer_and_contexts(question: str):
-     """
-     Full end-to-end: retrieve, threshold-check, generate answer,
-     and return both the answer and a formatted string of contexts.
-     """
-     answer, contexts = retrieve_and_answer(question)
-
-     # If no valid contexts, just return the apology
-     if not contexts:
-         return answer, ""
-
-     # Otherwise format each snippet for display
-     ctx_snippets = [
-         f"Context {i+1}: {s}"
-         for i, s in enumerate(make_context_snippets(contexts, max_context_words))
      ]

-     return answer, "\n\n---\n\n".join(ctx_snippets)
-
-
-
- iface = gr.Interface(
-     fn=answer_and_contexts,
-     inputs=gr.Textbox(lines=1, placeholder="Enter your question here...", label="Question"),
-     outputs=[
-         gr.Textbox(label="Answer"),
-         gr.Textbox(label="Retrieved Contexts")
-     ],
-     title="🔍 RAG QA Demo",
-     description="Retrieval-Augmented QA with distance threshold and context preview"
- )
-
- iface.launch()
-
-
- # # Test the Model
-
- # load SQuAD v2 (we only need validation split)
- squad = load_dataset("rajpurkar/squad_v2", split="validation")
-
- # load the SQuAD metric (handles no-answer properly)
- squad_metric = evaluate.load("squad")
-
-
- def retrieval_recall(dataset, k=20, num_samples=100):
-     hits = 0
-     for ex in dataset.select(range(num_samples)):
-         question = ex["question"]
-         gold_answers = ex["answers"]["text"]  # list, empty if unanswerable
-
-         # get your top-k contexts
-         ctxs, _ = retrieve(question, k=k, rerank_k=k)  # or rerank_k smaller
-         # check if any gold answer appears in any context
-         if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers):
-             hits += 1
-
-     recall = hits / num_samples
-     print(f"Retrieval Recall@{k}: {recall:.3f}")
-     return recall
-
-
- # ## Only answerable Questions
-
-
- def retrieval_recall_answerable(dataset, k=20, num_samples=100):
-     hits = 0
-     total = 0
-     for ex in dataset.select(range(num_samples)):
-         if not ex["answers"]["text"]:
-             continue  # skip unanswerable
-         total += 1
-         ctxs, _ = retrieve(ex["question"], k=k, rerank_k=k)
-         if any(any(ans in ctx for ctx in ctxs) for ans in ex["answers"]["text"]):
-             hits += 1
-     recall = hits / total
-     print(f"Retrieval Recall@{k} on answerable only: {recall:.3f} ({hits}/{total})")
-     return recall
-
- def qa_eval_all(dataset, num_samples=100, k=20):
-     preds, refs = [], []
-     for ex in dataset.select(range(num_samples)):
-         qid = ex["id"]
-         gold = ex["answers"]
-         # ensure metric has something to iterate over
-         if not gold["text"]:
-             gold = {"text": [""], "answer_start": [0]}
-
-         ans, _ = retrieve_and_answer(ex["question"], k=k)
-         # for metric purposes, treat our refusal as empty string
-         pred_text = "" if ans.strip().lower().startswith("sorry") else ans
-
-         preds.append({"id": qid, "prediction_text": pred_text})
-         refs.append({"id": qid, "answers": gold})
-
-     results = squad_metric.compute(predictions=preds, references=refs)
-     print(f"Full QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
-     return results
-
- def qa_eval_answerable(dataset, num_samples=100, k=20):
-     preds, refs = [], []
-     for ex in dataset.select(range(num_samples)):
-         if not ex["answers"]["text"]:
-             continue  # skip unanswerable
-         qid = ex["id"]
-         ans, _ = retrieve_and_answer(ex["question"], k=k)
-
-         preds.append({"id": qid, "prediction_text": ans})
-         refs.append({"id": qid, "answers": ex["answers"]})
-
-     results = squad_metric.compute(predictions=preds, references=refs)
-     print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
-     return results
-
-
- retrieval_recall(squad, k=2, num_samples=100)
- retrieval_recall_answerable(squad, k=2, num_samples=100)
- qa_eval_all(squad, num_samples=100, k=2)
- qa_eval_answerable(squad, num_samples=100, k=2)
-

+ #!/usr/bin/env python
+ # coding: utf-8

  import os
  import pickle
  import faiss
  import numpy as np
  import torch
  import gradio as gr

+ from datasets import load_dataset
+ from sentence_transformers import SentenceTransformer, CrossEncoder
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForSeq2SeqLM,
+     pipeline as hf_pipeline,
+ )

+ # ── 1. Configuration ──
+ DATA_DIR = os.path.join(os.getcwd(), "data")
+ INDEX_PATH = os.path.join(DATA_DIR, "faiss_index.faiss")
+ EMB_PATH = os.path.join(DATA_DIR, "embeddings.npy")
+ PCTX_PATH = os.path.join(DATA_DIR, "passages.pkl")

  MODEL_NAME = os.getenv("MODEL_NAME", "google/flan-t5-small")
  EMBEDDER_MODEL = os.getenv("EMBEDDER_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
+ DIST_THRESHOLD = float(os.getenv("DIST_THRESHOLD", 1.0))
+ MAX_CTX_WORDS = int(os.getenv("MAX_CTX_WORDS", 200))

+ DEVICE = 0 if torch.cuda.is_available() else -1

+ os.makedirs(DATA_DIR, exist_ok=True)

+ print(f"Using MODEL_NAME={MODEL_NAME}, EMBEDDER_MODEL={EMBEDDER_MODEL}, device={'GPU' if DEVICE==0 else 'CPU'}")

+ # ── 2. Helpers ──
+ def make_context_snippets(contexts, max_words=MAX_CTX_WORDS):
+     out = []
      for c in contexts:
          words = c.split()
          if len(words) > max_words:
              c = " ".join(words[:max_words]) + " ... [truncated]"
+         out.append(c)
+     return out

+ def chunk_text(text, max_tokens, stride=None):
      words = text.split()
      if stride is None:
+         stride = max_tokens // 4
+     chunks, start = [], 0
      while start < len(words):
          end = start + max_tokens
+         chunks.append(" ".join(words[start:end]))
          start += stride
      return chunks

+ # ── 3. Load & preprocess passages ──
+ def load_passages():
+     # 3.1 load raw corpora
+     wiki = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")["passage"]
+     squad = load_dataset("rajpurkar/squad_v2", split="train[:100]")["context"]
+     trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")
+     trivia = []
+     for ex in trivia_ds:
+         for fld in ("wiki_context", "search_context"):
+             txt = ex.get(fld) or ""
+             if txt: trivia.append(txt)
+
+     all_passages = list(dict.fromkeys(wiki + squad + trivia))
+     # 3.2 chunk long passages
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     max_tokens = tokenizer.model_max_length

+     chunks = []
+     for p in all_passages:
+         toks = tokenizer.tokenize(p)
+         if len(toks) > max_tokens:
+             chunks.extend(chunk_text(p, max_tokens))
+         else:
+             chunks.append(p)
+
+     print(f"[load_passages] total chunks: {len(chunks)}")
+     with open(PCTX_PATH, "wb") as f:
+         pickle.dump(chunks, f)
+     return chunks
+ # ── 4. Build or load FAISS ──
+ def load_faiss_index(passages):
+     # sentence-transformers embedder + cross-encoder
+     embedder = SentenceTransformer(EMBEDDER_MODEL)
+     reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

+     if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH):
+         print("Loading FAISS index & embeddings from disk …")
+         index = faiss.read_index(INDEX_PATH)
+         embeddings = np.load(EMB_PATH)
      else:
+         print("Encoding passages & building FAISS index …")
+         embeddings = embedder.encode(passages, show_progress_bar=True, convert_to_numpy=True, batch_size=32)
+         embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

+         dim = embeddings.shape[1]
+         index = faiss.IndexFlatIP(dim)
+         index.add(embeddings)

+         faiss.write_index(index, INDEX_PATH)
+         np.save(EMB_PATH, embeddings)

+     return embedder, reranker, index

+ # ── 5. Set up RAG pipeline ──
+ def setup_rag():
+     # 5.1 load or build index + embedder/reranker
+     if os.path.exists(PCTX_PATH):
+         with open(PCTX_PATH, "rb") as f:
+             passages = pickle.load(f)
+     else:
+         passages = load_passages()
+
+     embedder, reranker, index = load_faiss_index(passages)
+
+     # 5.2 load generator model & HF pipeline
+     tok = AutoTokenizer.from_pretrained(MODEL_NAME)
+     model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
+     qa_pipe = hf_pipeline(
+         "text2text-generation",
+         model=model,
+         tokenizer=tok,
+         device=DEVICE,
+         truncation=True,
+         max_length=512,
+         num_beams=4,  # optional: enable beam search
+         early_stopping=True
      )

+     return passages, embedder, reranker, index, qa_pipe

+ # ── 6. Retrieval + Generation ──
+ def retrieve(question, passages, embedder, reranker, index, k=20, rerank_k=5):
      q_emb = embedder.encode([question], convert_to_numpy=True)
+     distances, idxs = index.search(q_emb, k)

+     cands = [passages[i] for i in idxs[0]]
+     scores = reranker.predict([[question, c] for c in cands])
+     top = np.argsort(scores)[-rerank_k:][::-1]

+     final_ctxs = [cands[i] for i in top]
+     final_dists = [distances[0][i] for i in top]
+     return final_ctxs, final_dists

+ def generate(question, contexts, qa_pipe):
+     lines = [f"Context {i+1}: {s}"
+              for i, s in enumerate(make_context_snippets(contexts))]
      prompt = (
          "You are a helpful assistant. Use ONLY the following contexts to answer. "
          "If the answer is not contained, say 'Sorry, I don't know.'\n\n"
+         + "\n".join(lines)
          + f"\n\nQuestion: {question}\nAnswer:"
      )
+     return qa_pipe(prompt)[0]["generated_text"].strip()

+ def retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe):
+     ctxs, dists = retrieve(question, passages, embedder, reranker, index)
+     if not ctxs or dists[0] > DIST_THRESHOLD:
          return "Sorry, I don't know.", []
+     ans = generate(question, ctxs, qa_pipe)
+     return ans, ctxs
+
+ def answer_and_contexts(question,
+                         passages, embedder, reranker, index, qa_pipe):
+     ans, ctxs = retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe)
+     if not ctxs:
+         return ans, ""
+     snippets = [
+         f"Context {i+1}: {s}"
+         for i, s in enumerate(make_context_snippets(ctxs))
      ]
+     return ans, "\n\n---\n\n".join(snippets)
+
+ # ── 7. Gradio app ──
+ def main():
+     passages, embedder, reranker, index, qa_pipe = setup_rag()
+
+     demo = gr.Interface(
+         fn=lambda q: answer_and_contexts(q, passages, embedder, reranker, index, qa_pipe),
+         inputs=gr.Textbox(lines=1, placeholder="Ask me anything…", label="Question"),
+         outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Contexts")],
+         title="🔍 RAG QA Demo",
+         description="Retrieval-Augmented QA with threshold and context preview",
+         examples=[
+             "When was Abraham Lincoln inaugurated?",
+             "What is the capital of France?",
+             "Who wrote '1984'?"
+         ]
+     )
+     demo.launch()

+ if __name__ == "__main__":
+     main()
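
For reference, a minimal sketch of how the refactored entry points could be exercised outside the Gradio UI. This snippet is illustrative only and is not part of the commit; it assumes the updated app.py is importable from the working directory and that downloading the models and datasets on first use is acceptable.

# Hypothetical driver script (not in the repository).
# setup_rag() builds or loads the passage store, FAISS index, embedder, reranker,
# and flan-t5 pipeline; retrieve_and_answer() then runs dense retrieval,
# cross-encoder reranking, and generation for a single question.
from app import setup_rag, retrieve_and_answer

if __name__ == "__main__":
    # Build (or load) everything once, then reuse it for each query.
    passages, embedder, reranker, index, qa_pipe = setup_rag()
    for question in ("When was Abraham Lincoln inaugurated?",
                     "What is the capital of France?"):
        answer, contexts = retrieve_and_answer(
            question, passages, embedder, reranker, index, qa_pipe
        )
        print(f"Q: {question}")
        print(f"A: {answer}  ({len(contexts)} supporting contexts)\n")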