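"""Evaluation script for the RAG pipeline defined in app.py.

Measures retrieval recall (overall and on the answerable-only subset) and
end-to-end QA exact match / F1 on the SQuAD v2 validation split.
"""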
from datasets import load_dataset
import evaluate

# Import RAG setup and retrieval logic from app.py
from app import setup_rag, retrieve, retrieve_and_answer
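# Interfaces assumed from their usage in this script (all defined in app.py):
#   setup_rag() -> (passages, embedder, reranker, index, qa_pipe)
#   retrieve(question, passages, embedder, reranker, index, k, rerank_k) -> (contexts, ...)
#   retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe) -> (answer, ...)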


def retrieval_recall(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
    """
    Compute raw Retrieval Recall@k on the first num_samples examples.
    If rerank_k is set, apply cross-encoder reranking via `retrieve`.
    Otherwise, use the FAISS index only (top-k) without reranking.
    """
    hits = 0
    for ex in dataset.select(range(num_samples)):
        question = ex["question"]
        gold_answers = ex["answers"]["text"]

        if rerank_k:
            # use two-stage retrieval (dense + rerank)
            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
        else:
            # single-stage: FAISS only
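            # NOTE: assumes the index was built from this embedder's vectors; for an
            # inner-product index (e.g. faiss.IndexFlatIP), query embeddings may need
            # L2-normalization for scores to behave like cosine similarity.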
            q_emb = embedder.encode([question], convert_to_numpy=True)
            _, idxs = index.search(q_emb, k)
            ctxs = [passages[i] for i in idxs[0]]

        # hit if any gold answer string appears verbatim (case-sensitive) in a retrieved context
        if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers):
            hits += 1

    recall = hits / num_samples
    print(f"Retrieval Recall@{k} (rerank_k={rerank_k}): {recall:.3f} ({hits}/{num_samples})")
    return recall


def retrieval_recall_answerable(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
    """
    Retrieval Recall@k evaluated only on answerable questions (answers list non-empty).
    """
    hits = 0
    total = 0
    for ex in dataset.select(range(num_samples)):
        gold = ex["answers"]["text"]
        if not gold:
            continue
        total += 1
        question = ex["question"]

        if rerank_k:
            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
        else:
            q_emb = embedder.encode([question], convert_to_numpy=True)
            _, idxs = index.search(q_emb, k)
            ctxs = [passages[i] for i in idxs[0]]

        if any(any(ans in ctx for ctx in ctxs) for ans in gold):
            hits += 1

    recall = hits / total if total > 0 else 0.0
    print(f"Retrieval Recall@{k} on answerable only (rerank_k={rerank_k}): {recall:.3f} ({hits}/{total})")
    return recall


def qa_eval_answerable(dataset, passages, embedder, reranker, index, qa_pipe, num_samples=100):
    """
    End-to-end QA EM/F1 on the answerable subset, using retrieve_and_answer.
    """
    squad_metric = evaluate.load("squad")
    preds = []
    refs = []
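    # evaluate's "squad" metric expects predictions as {"id", "prediction_text"}
    # dicts and references as {"id", "answers"} dicts.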

    for ex in dataset.select(range(num_samples)):
        gold = ex["answers"]["text"]
        if not gold:
            continue
        qid = ex["id"]
        # retrieve and generate
        answer, _ = retrieve_and_answer(
            ex["question"], passages, embedder, reranker, index, qa_pipe
        )
        preds.append({"id": qid, "prediction_text": answer})
        refs.append({"id": qid, "answers": ex["answers"]})

    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results


def main():
    # 1) Set up RAG components
    passages, embedder, reranker, index, qa_pipe = setup_rag()

    # 2) Load SQuAD v2 validation split
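    #    (unanswerable questions carry an empty answers["text"] list)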
    squad = load_dataset("rajpurkar/squad_v2", split="validation")

    # 3) Run evaluations
    retrieval_recall(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
    retrieval_recall_answerable(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
    qa_eval_answerable(squad, passages, embedder, reranker, index, qa_pipe, num_samples=100)


if __name__ == "__main__":
    main()