Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3 |
|
4 |
import os
|
5 |
import pickle
|
|
|
6 |
import faiss
|
7 |
import numpy as np
|
8 |
import torch
|
@@ -15,6 +16,7 @@ from transformers import (
|
|
15 |
AutoModelForSeq2SeqLM,
|
16 |
pipeline as hf_pipeline,
|
17 |
)
|
|
|
18 |
|
19 |
# ββ 1. Configuration ββ
|
20 |
DATA_DIR = os.path.join(os.getcwd(), "data")
|
@@ -30,7 +32,6 @@ MAX_CTX_WORDS = int(os.getenv("MAX_CTX_WORDS", 200))
|
|
30 |
DEVICE = 0 if torch.cuda.is_available() else -1
|
31 |
os.makedirs(DATA_DIR, exist_ok=True)
|
32 |
|
33 |
-
print(f"MODEL={MODEL_NAME}, EMBEDDER={EMBEDDER_MODEL}, DEVICE={'GPU' if DEVICE==0 else 'CPU'}")
|
34 |
|
35 |
# ββ 2. Helpers ββ
|
36 |
def make_context_snippets(contexts, max_words=MAX_CTX_WORDS):
|
@@ -53,15 +54,15 @@ def chunk_text(text, max_tokens, stride=None):
|
|
53 |
start += stride
|
54 |
return chunks
|
55 |
|
|
|
56 |
# ββ 3. Load & preprocess passages ββ
|
57 |
def load_passages():
|
58 |
-
# 3.1 load raw corpora
|
59 |
wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
|
60 |
squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
|
61 |
trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")
|
62 |
|
63 |
-
wiki_passages
|
64 |
-
squad_passages
|
65 |
trivia_passages = []
|
66 |
for ex in trivia_ds:
|
67 |
for fld in ("wiki_context", "search_context"):
|
@@ -69,12 +70,10 @@ def load_passages():
|
|
69 |
if txt:
|
70 |
trivia_passages.append(txt)
|
71 |
|
72 |
-
# dedupe
|
73 |
all_passages = list(dict.fromkeys(wiki_passages + squad_passages + trivia_passages))
|
|
|
|
|
74 |
|
75 |
-
# chunk long passages
|
76 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
77 |
-
max_tokens = tokenizer.model_max_length
|
78 |
chunks = []
|
79 |
for p in all_passages:
|
80 |
toks = tokenizer.tokenize(p)
|
@@ -88,6 +87,7 @@ def load_passages():
|
|
88 |
pickle.dump(chunks, f)
|
89 |
return chunks
|
90 |
|
|
|
91 |
# ββ 4. Build or load FAISS ββ
|
92 |
def load_faiss_index(passages):
|
93 |
embedder = SentenceTransformer(EMBEDDER_MODEL)
|
@@ -116,6 +116,7 @@ def load_faiss_index(passages):
|
|
116 |
|
117 |
return embedder, reranker, index
|
118 |
|
|
|
119 |
# ββ 5. Initialize RAG components ββ
|
120 |
def setup_rag():
|
121 |
if os.path.exists(PCTX_PATH):
|
@@ -141,8 +142,9 @@ def setup_rag():
|
|
141 |
|
142 |
return passages, embedder, reranker, index, qa_pipe
|
143 |
|
|
|
144 |
# ββ 6. Retrieval & generation ββ
|
145 |
-
def retrieve(question, passages, embedder, index, k=20, rerank_k=5):
|
146 |
q_emb = embedder.encode([question], convert_to_numpy=True)
|
147 |
distances, idxs = index.search(q_emb, k)
|
148 |
|
@@ -166,7 +168,7 @@ def generate(question, contexts, qa_pipe):
|
|
166 |
return qa_pipe(prompt)[0]["generated_text"].strip()
|
167 |
|
168 |
def retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe):
|
169 |
-
contexts, dists = retrieve(question, passages, embedder, index)
|
170 |
if not contexts or dists[0] > DIST_THRESHOLD:
|
171 |
return "Sorry, I don't know.", []
|
172 |
return generate(question, contexts, qa_pipe), contexts
|
@@ -181,24 +183,99 @@ def answer_and_contexts(question, passages, embedder, reranker, index, qa_pipe):
|
|
181 |
]
|
182 |
return ans, "\n\n---\n\n".join(snippets)
|
183 |
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
def main():
|
186 |
passages, embedder, reranker, index, qa_pipe = setup_rag()
|
187 |
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
title="π RAG QA Demo",
|
193 |
-
description="Retrieval-Augmented QA with threshold and context preview",
|
194 |
-
examples=[
|
195 |
-
"When was Abraham Lincoln inaugurated?",
|
196 |
-
"What is the capital of France?",
|
197 |
-
"Who wrote '1984'?"
|
198 |
-
],
|
199 |
-
allow_flagging="never",
|
200 |
)
|
201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
203 |
if __name__ == "__main__":
|
204 |
main()
|
|
|
3 |
|
4 |
import os
|
5 |
import pickle
|
6 |
+
import argparse
|
7 |
import faiss
|
8 |
import numpy as np
|
9 |
import torch
|
|
|
16 |
AutoModelForSeq2SeqLM,
|
17 |
pipeline as hf_pipeline,
|
18 |
)
|
19 |
+
import evaluate
|
20 |
|
21 |
# ββ 1. Configuration ββ
|
22 |
DATA_DIR = os.path.join(os.getcwd(), "data")
|
|
|
32 |
DEVICE = 0 if torch.cuda.is_available() else -1
|
33 |
os.makedirs(DATA_DIR, exist_ok=True)
|
34 |
|
|
|
35 |
|
36 |
# ββ 2. Helpers ββ
|
37 |
def make_context_snippets(contexts, max_words=MAX_CTX_WORDS):
|
|
|
54 |
start += stride
|
55 |
return chunks
|
56 |
|
57 |
+
|
58 |
# ββ 3. Load & preprocess passages ββ
|
59 |
def load_passages():
|
|
|
60 |
wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
|
61 |
squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
|
62 |
trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")
|
63 |
|
64 |
+
wiki_passages = wiki_ds["passage"]
|
65 |
+
squad_passages = [ex["context"] for ex in squad_ds]
|
66 |
trivia_passages = []
|
67 |
for ex in trivia_ds:
|
68 |
for fld in ("wiki_context", "search_context"):
|
|
|
70 |
if txt:
|
71 |
trivia_passages.append(txt)
|
72 |
|
|
|
73 |
all_passages = list(dict.fromkeys(wiki_passages + squad_passages + trivia_passages))
|
74 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
75 |
+
max_tokens = tokenizer.model_max_length
|
76 |
|
|
|
|
|
|
|
77 |
chunks = []
|
78 |
for p in all_passages:
|
79 |
toks = tokenizer.tokenize(p)
|
|
|
87 |
pickle.dump(chunks, f)
|
88 |
return chunks
|
89 |
|
90 |
+
|
91 |
# ββ 4. Build or load FAISS ββ
|
92 |
def load_faiss_index(passages):
|
93 |
embedder = SentenceTransformer(EMBEDDER_MODEL)
|
|
|
116 |
|
117 |
return embedder, reranker, index
|
118 |
|
119 |
+
|
120 |
# ββ 5. Initialize RAG components ββ
|
121 |
def setup_rag():
|
122 |
if os.path.exists(PCTX_PATH):
|
|
|
142 |
|
143 |
return passages, embedder, reranker, index, qa_pipe
|
144 |
|
145 |
+
|
146 |
# ββ 6. Retrieval & generation ββ
|
147 |
+
def retrieve(question, passages, embedder, reranker, index, k=20, rerank_k=5):
|
148 |
q_emb = embedder.encode([question], convert_to_numpy=True)
|
149 |
distances, idxs = index.search(q_emb, k)
|
150 |
|
|
|
168 |
return qa_pipe(prompt)[0]["generated_text"].strip()
|
169 |
|
170 |
def retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe):
|
171 |
+
contexts, dists = retrieve(question, passages, embedder, reranker, index)
|
172 |
if not contexts or dists[0] > DIST_THRESHOLD:
|
173 |
return "Sorry, I don't know.", []
|
174 |
return generate(question, contexts, qa_pipe), contexts
|
|
|
183 |
]
|
184 |
return ans, "\n\n---\n\n".join(snippets)
|
185 |
|
186 |
+
|
187 |
+
# ββ 7. Evaluation routines ββ
|
188 |
+
def retrieval_recall(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
|
189 |
+
hits = 0
|
190 |
+
for ex in dataset.select(range(num_samples)):
|
191 |
+
question = ex["question"]
|
192 |
+
gold_answers = ex["answers"]["text"]
|
193 |
+
|
194 |
+
if rerank_k:
|
195 |
+
ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
|
196 |
+
else:
|
197 |
+
q_emb = embedder.encode([question], convert_to_numpy=True)
|
198 |
+
distances, idxs = index.search(q_emb, k)
|
199 |
+
ctxs = [passages[i] for i in idxs[0]]
|
200 |
+
|
201 |
+
if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers):
|
202 |
+
hits += 1
|
203 |
+
|
204 |
+
recall = hits / num_samples
|
205 |
+
print(f"Retrieval Recall@{k} (rerank_k={rerank_k}): {recall:.3f} ({hits}/{num_samples})")
|
206 |
+
return recall
|
207 |
+
|
208 |
+
def retrieval_recall_answerable(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
|
209 |
+
hits, total = 0, 0
|
210 |
+
for ex in dataset.select(range(num_samples)):
|
211 |
+
gold = ex["answers"]["text"]
|
212 |
+
if not gold:
|
213 |
+
continue
|
214 |
+
total += 1
|
215 |
+
question = ex["question"]
|
216 |
+
|
217 |
+
if rerank_k:
|
218 |
+
ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
|
219 |
+
else:
|
220 |
+
q_emb = embedder.encode([question], convert_to_numpy=True)
|
221 |
+
distances, idxs = index.search(q_emb, k)
|
222 |
+
ctxs = [passages[i] for i in idxs[0]]
|
223 |
+
|
224 |
+
if any(any(ans in ctx for ctx in ctxs) for ans in gold):
|
225 |
+
hits += 1
|
226 |
+
|
227 |
+
recall = hits / total if total > 0 else 0.0
|
228 |
+
print(f"Retrieval Recall@{k} on answerable only (rerank_k={rerank_k}): {recall:.3f} ({hits}/{total})")
|
229 |
+
return recall
|
230 |
+
|
231 |
+
def qa_eval_answerable(dataset, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100):
|
232 |
+
squad_metric = evaluate.load("squad")
|
233 |
+
preds, refs = [], []
|
234 |
+
for ex in dataset.select(range(num_samples)):
|
235 |
+
gold = ex["answers"]["text"]
|
236 |
+
if not gold:
|
237 |
+
continue
|
238 |
+
qid = ex["id"]
|
239 |
+
answer, _ = retrieve_and_answer(ex["question"], passages, embedder, reranker, index, qa_pipe)
|
240 |
+
preds.append({"id": qid, "prediction_text": answer})
|
241 |
+
refs.append({"id": qid, "answers": ex["answers"]})
|
242 |
+
|
243 |
+
results = squad_metric.compute(predictions=preds, references=refs)
|
244 |
+
print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
|
245 |
+
return results
|
246 |
+
|
247 |
+
|
248 |
+
# ββ 8. Main entry ββ
|
249 |
def main():
|
250 |
passages, embedder, reranker, index, qa_pipe = setup_rag()
|
251 |
|
252 |
+
parser = argparse.ArgumentParser()
|
253 |
+
parser.add_argument(
|
254 |
+
"--eval", action="store_true",
|
255 |
+
help="Run retrieval/QA evaluations on SQuAD instead of launching the demo"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
)
|
257 |
+
args = parser.parse_args()
|
258 |
+
|
259 |
+
if args.eval:
|
260 |
+
squad = load_dataset("rajpurkar/squad_v2", split="validation")
|
261 |
+
retrieval_recall(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
|
262 |
+
retrieval_recall_answerable(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
|
263 |
+
qa_eval_answerable(squad, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100)
|
264 |
+
else:
|
265 |
+
demo = gr.Interface(
|
266 |
+
fn=lambda q: answer_and_contexts(q, passages, embedder, reranker, index, qa_pipe),
|
267 |
+
inputs=gr.Textbox(lines=1, placeholder="Ask me anythingβ¦", label="Question"),
|
268 |
+
outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Contexts")],
|
269 |
+
title="π RAG QA Demo",
|
270 |
+
description="Retrieval-Augmented QA with threshold and context preview",
|
271 |
+
examples=[
|
272 |
+
"When was Abraham Lincoln inaugurated?",
|
273 |
+
"What is the capital of France?",
|
274 |
+
"Who wrote '1984'?"
|
275 |
+
],
|
276 |
+
allow_flagging="never",
|
277 |
+
)
|
278 |
+
demo.launch()
|
279 |
|
280 |
if __name__ == "__main__":
|
281 |
main()
|