VictorTomas09 committed
Commit 728106c · verified · 1 Parent(s): 8b017be

Update app.py

Files changed (1)
  1. app.py +101 -24
app.py CHANGED
@@ -3,6 +3,7 @@
 
 import os
 import pickle
+import argparse
 import faiss
 import numpy as np
 import torch
@@ -15,6 +16,7 @@ from transformers import (
     AutoModelForSeq2SeqLM,
     pipeline as hf_pipeline,
 )
+import evaluate
 
 # ── 1. Configuration ──
 DATA_DIR = os.path.join(os.getcwd(), "data")
@@ -30,7 +32,6 @@ MAX_CTX_WORDS = int(os.getenv("MAX_CTX_WORDS", 200))
 DEVICE = 0 if torch.cuda.is_available() else -1
 os.makedirs(DATA_DIR, exist_ok=True)
 
-print(f"MODEL={MODEL_NAME}, EMBEDDER={EMBEDDER_MODEL}, DEVICE={'GPU' if DEVICE==0 else 'CPU'}")
 
 # ── 2. Helpers ──
 def make_context_snippets(contexts, max_words=MAX_CTX_WORDS):
@@ -53,15 +54,15 @@ def chunk_text(text, max_tokens, stride=None):
         start += stride
     return chunks
 
+
 # ── 3. Load & preprocess passages ──
 def load_passages():
-    # 3.1 load raw corpora
     wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
     squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
     trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")
 
-    wiki_passages = wiki_ds["passage"]
-    squad_passages = [ex["context"] for ex in squad_ds]
+    wiki_passages = wiki_ds["passage"]
+    squad_passages = [ex["context"] for ex in squad_ds]
     trivia_passages = []
     for ex in trivia_ds:
         for fld in ("wiki_context", "search_context"):
@@ -69,12 +70,10 @@ def load_passages():
             if txt:
                 trivia_passages.append(txt)
 
-    # dedupe
     all_passages = list(dict.fromkeys(wiki_passages + squad_passages + trivia_passages))
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    max_tokens = tokenizer.model_max_length
 
-    # chunk long passages
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    max_tokens = tokenizer.model_max_length
     chunks = []
     for p in all_passages:
         toks = tokenizer.tokenize(p)
@@ -88,6 +87,7 @@
         pickle.dump(chunks, f)
     return chunks
 
+
 # ── 4. Build or load FAISS ──
 def load_faiss_index(passages):
     embedder = SentenceTransformer(EMBEDDER_MODEL)
@@ -116,6 +116,7 @@
 
     return embedder, reranker, index
 
+
 # ── 5. Initialize RAG components ──
 def setup_rag():
     if os.path.exists(PCTX_PATH):
@@ -141,8 +142,9 @@
 
     return passages, embedder, reranker, index, qa_pipe
 
+
 # ── 6. Retrieval & generation ──
-def retrieve(question, passages, embedder, index, k=20, rerank_k=5):
+def retrieve(question, passages, embedder, reranker, index, k=20, rerank_k=5):
     q_emb = embedder.encode([question], convert_to_numpy=True)
     distances, idxs = index.search(q_emb, k)
 
@@ -166,7 +168,7 @@ def generate(question, contexts, qa_pipe):
     return qa_pipe(prompt)[0]["generated_text"].strip()
 
 def retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe):
-    contexts, dists = retrieve(question, passages, embedder, index)
+    contexts, dists = retrieve(question, passages, embedder, reranker, index)
     if not contexts or dists[0] > DIST_THRESHOLD:
         return "Sorry, I don't know.", []
     return generate(question, contexts, qa_pipe), contexts
@@ -181,24 +183,99 @@ def answer_and_contexts(question, passages, embedder, reranker, index, qa_pipe):
     ]
     return ans, "\n\n---\n\n".join(snippets)
 
-# ── 7. Gradio app ──
+
+# ── 7. Evaluation routines ──
+def retrieval_recall(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
+    hits = 0
+    for ex in dataset.select(range(num_samples)):
+        question = ex["question"]
+        gold_answers = ex["answers"]["text"]
+
+        if rerank_k:
+            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
+        else:
+            q_emb = embedder.encode([question], convert_to_numpy=True)
+            distances, idxs = index.search(q_emb, k)
+            ctxs = [passages[i] for i in idxs[0]]
+
+        if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers):
+            hits += 1
+
+    recall = hits / num_samples
+    print(f"Retrieval Recall@{k} (rerank_k={rerank_k}): {recall:.3f} ({hits}/{num_samples})")
+    return recall
+
+def retrieval_recall_answerable(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
+    hits, total = 0, 0
+    for ex in dataset.select(range(num_samples)):
+        gold = ex["answers"]["text"]
+        if not gold:
+            continue
+        total += 1
+        question = ex["question"]
+
+        if rerank_k:
+            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
+        else:
+            q_emb = embedder.encode([question], convert_to_numpy=True)
+            distances, idxs = index.search(q_emb, k)
+            ctxs = [passages[i] for i in idxs[0]]
+
+        if any(any(ans in ctx for ctx in ctxs) for ans in gold):
+            hits += 1
+
+    recall = hits / total if total > 0 else 0.0
+    print(f"Retrieval Recall@{k} on answerable only (rerank_k={rerank_k}): {recall:.3f} ({hits}/{total})")
+    return recall
+
+def qa_eval_answerable(dataset, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100):
+    squad_metric = evaluate.load("squad")
+    preds, refs = [], []
+    for ex in dataset.select(range(num_samples)):
+        gold = ex["answers"]["text"]
+        if not gold:
+            continue
+        qid = ex["id"]
+        answer, _ = retrieve_and_answer(ex["question"], passages, embedder, reranker, index, qa_pipe)
+        preds.append({"id": qid, "prediction_text": answer})
+        refs.append({"id": qid, "answers": ex["answers"]})
+
+    results = squad_metric.compute(predictions=preds, references=refs)
+    print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
+    return results
+
+
+# ── 8. Main entry ──
 def main():
     passages, embedder, reranker, index, qa_pipe = setup_rag()
 
-    demo = gr.Interface(
-        fn=lambda q: answer_and_contexts(q, passages, embedder, reranker, index, qa_pipe),
-        inputs=gr.Textbox(lines=1, placeholder="Ask me anything…", label="Question"),
-        outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Contexts")],
-        title="🔍 RAG QA Demo",
-        description="Retrieval-Augmented QA with threshold and context preview",
-        examples=[
-            "When was Abraham Lincoln inaugurated?",
-            "What is the capital of France?",
-            "Who wrote '1984'?"
-        ],
-        allow_flagging="never",
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--eval", action="store_true",
+        help="Run retrieval/QA evaluations on SQuAD instead of launching the demo"
     )
-    demo.launch()
+    args = parser.parse_args()
+
+    if args.eval:
+        squad = load_dataset("rajpurkar/squad_v2", split="validation")
+        retrieval_recall(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
+        retrieval_recall_answerable(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
+        qa_eval_answerable(squad, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100)
+    else:
+        demo = gr.Interface(
+            fn=lambda q: answer_and_contexts(q, passages, embedder, reranker, index, qa_pipe),
+            inputs=gr.Textbox(lines=1, placeholder="Ask me anything…", label="Question"),
+            outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Contexts")],
+            title="🔍 RAG QA Demo",
+            description="Retrieval-Augmented QA with threshold and context preview",
+            examples=[
+                "When was Abraham Lincoln inaugurated?",
+                "What is the capital of France?",
+                "Who wrote '1984'?"
+            ],
+            allow_flagging="never",
+        )
+        demo.launch()
 
 if __name__ == "__main__":
     main()