Spaces:
Sleeping
Sleeping
File size: 3,888 Bytes
ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 ff2408d 86211f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import os
import pickle
import numpy as np
import faiss
import torch
from datasets import load_dataset
import evaluate
# Import RAG setup and retrieval logic from app.py
from app import setup_rag, retrieve, retrieve_and_answer
def retrieval_recall(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
    """
    Compute raw Retrieval Recall@k on the first num_samples examples.

    An example is a hit when any of its gold answer strings occurs as a
    substring of any retrieved context. Examples whose gold answer list is
    empty can never hit, so they count as misses in this "raw" figure.

    If rerank_k is set (truthy), contexts come from `retrieve` (dense search
    followed by cross-encoder reranking). Otherwise, use the FAISS index only
    (top-k) without reranking.

    Args:
        dataset: object exposing `select(indices)` that yields examples with
            `"question"` and `"answers"["text"]` entries.
        passages: sequence of passage strings addressed by index position.
        embedder: object exposing `encode(list_of_str, convert_to_numpy=True)`.
        reranker: forwarded to `retrieve`; unused on the FAISS-only path.
        index: object exposing `search(query_embeddings, k)`.
        k: number of candidates requested from the index.
        rerank_k: number of contexts kept after reranking, or None/0 to skip
            reranking.
        num_samples: how many leading examples of `dataset` to evaluate.

    Returns:
        hits / num_samples as a float, or 0.0 when num_samples is not positive.
    """
    hits = 0
    for ex in dataset.select(range(num_samples)):
        question = ex["question"]
        gold_answers = ex["answers"]["text"]
        if rerank_k:
            # use two-stage retrieval (dense + rerank)
            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
        else:
            # single-stage: FAISS only
            q_emb = embedder.encode([question], convert_to_numpy=True)
            _, idxs = index.search(q_emb, k)
            # FAISS pads the label row with -1 when it finds fewer than k
            # neighbours; without this filter passages[-1] would be counted
            # as a retrieved context and could produce a false hit.
            ctxs = [passages[i] for i in idxs[0] if i >= 0]
        # check if any gold answer appears in any retrieved context
        if any(ans in ctx for ans in gold_answers for ctx in ctxs):
            hits += 1
    # Guard the empty evaluation so num_samples=0 reports 0.0 instead of raising.
    recall = hits / num_samples if num_samples > 0 else 0.0
    print(f"Retrieval Recall@{k} (rerank_k={rerank_k}): {recall:.3f} ({hits}/{num_samples})")
    return recall
def retrieval_recall_answerable(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
    """
    Retrieval Recall@k evaluated only on answerable questions (answers list non-empty).

    The first num_samples examples are selected and the unanswerable ones are
    then skipped, so the denominator is the number of answerable examples in
    that slice, which may be smaller than num_samples.

    An example is a hit when any gold answer string occurs as a substring of
    any retrieved context. With a truthy rerank_k the contexts come from
    `retrieve`; otherwise the top-k passages of `index.search` are used.

    Args:
        dataset: object exposing `select(indices)` that yields examples with
            `"question"` and `"answers"["text"]` entries.
        passages: sequence of passage strings addressed by index position.
        embedder: object exposing `encode(list_of_str, convert_to_numpy=True)`.
        reranker: forwarded to `retrieve`; unused on the FAISS-only path.
        index: object exposing `search(query_embeddings, k)`.
        k: number of candidates requested from the index.
        rerank_k: number of contexts kept after reranking, or None/0 to skip
            reranking.
        num_samples: how many leading examples of `dataset` to scan.

    Returns:
        hits / answerable-count as a float, or 0.0 when the slice contains no
        answerable example.
    """
    hits = 0
    total = 0
    for ex in dataset.select(range(num_samples)):
        gold = ex["answers"]["text"]
        if not gold:
            # unanswerable question: excluded from both numerator and denominator
            continue
        total += 1
        question = ex["question"]
        if rerank_k:
            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
        else:
            q_emb = embedder.encode([question], convert_to_numpy=True)
            _, idxs = index.search(q_emb, k)
            # FAISS pads the label row with -1 when it finds fewer than k
            # neighbours; without this filter passages[-1] would be counted
            # as a retrieved context and could produce a false hit.
            ctxs = [passages[i] for i in idxs[0] if i >= 0]
        if any(ans in ctx for ans in gold for ctx in ctxs):
            hits += 1
    recall = hits / total if total > 0 else 0.0
    print(f"Retrieval Recall@{k} on answerable only (rerank_k={rerank_k}): {recall:.3f} ({hits}/{total})")
    return recall
def qa_eval_answerable(dataset, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100):
    """
    End-to-end QA EM/F1 on answerable subset using retrieve_and_answer.

    The first num_samples examples are selected and those with an empty gold
    answer list are dropped. Each remaining question is answered with
    `retrieve_and_answer` and scored with the `squad` metric from `evaluate`.

    Args:
        dataset: object exposing `select(indices)` that yields examples with
            `"id"`, `"question"` and `"answers"` entries.
        passages, embedder, reranker, index, qa_pipe: forwarded unchanged to
            `retrieve_and_answer`.
        k: accepted for signature parity with the recall functions.
            NOTE(review): it is not forwarded, so `retrieve_and_answer` runs
            with its own retrieval defaults -- confirm this is intended.
        num_samples: how many leading examples of `dataset` to scan.

    Returns:
        The dict produced by the metric's `compute` (read here via its
        `"exact_match"` and `"f1"` keys). When the slice holds no answerable
        example, `{"exact_match": 0.0, "f1": 0.0}` is returned instead.
    """
    # Filtering is cheap, so do it up front: this lets the empty case return
    # before the metric is loaded, while the metric still loads before the
    # expensive retrieve-and-answer loop.
    answerable = [ex for ex in dataset.select(range(num_samples)) if ex["answers"]["text"]]
    if not answerable:
        # The SQuAD scorer divides by the number of references, so calling it
        # with empty lists would fail; report zero scores instead.
        print(f"Answerable-only QA: no answerable examples in the first {num_samples} samples")
        return {"exact_match": 0.0, "f1": 0.0}

    squad_metric = evaluate.load("squad")
    preds = []
    refs = []
    for ex in answerable:
        qid = ex["id"]
        # retrieve and generate
        answer, _ = retrieve_and_answer(
            ex["question"], passages, embedder, reranker, index, qa_pipe
        )
        preds.append({"id": qid, "prediction_text": answer})
        refs.append({"id": qid, "answers": ex["answers"]})
    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results
def main():
    """Set up the RAG components, load SQuAD v2 validation, and run the three evaluations."""
    # Unpack the five components returned by setup_rag.
    passages, embedder, reranker, index, qa_pipe = setup_rag()
    squad = load_dataset("rajpurkar/squad_v2", split="validation")

    # Every evaluation takes the same leading positional arguments.
    shared = (squad, passages, embedder, reranker, index)

    # Raw recall first, then recall restricted to answerable questions.
    for recall_fn in (retrieval_recall, retrieval_recall_answerable):
        recall_fn(*shared, k=20, rerank_k=5, num_samples=100)

    # End-to-end EM/F1 additionally needs the QA pipeline.
    qa_eval_answerable(*shared, qa_pipe, k=20, num_samples=100)


if __name__ == "__main__":
    main()
|