VictorTomas09 committed (verified)
Commit 86211f8 · Parent(s): 970694a

Update Evaluators

Files changed (1): Evaluators (+39, -22)
Evaluators CHANGED
@@ -7,52 +7,61 @@ from datasets import load_dataset
 import evaluate
 
 # Import RAG setup and retrieval logic from app.py
-from app import setup_rag, retrieve
+from app import setup_rag, retrieve, retrieve_and_answer
 
 
-def retrieval_recall(dataset, passages, embedder, index, k=20, rerank_k=None, num_samples=100):
+def retrieval_recall(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
     """
     Compute raw Retrieval Recall@k on the first num_samples examples.
-    If rerank_k is set, also apply cross-encoder reranking.
+    If rerank_k is set, apply cross-encoder reranking via `retrieve`.
+    Otherwise, use the FAISS index only (top-k) without reranking.
     """
     hits = 0
     for ex in dataset.select(range(num_samples)):
         question = ex["question"]
         gold_answers = ex["answers"]["text"]
-        # get top-k retrieved contexts
+
         if rerank_k:
-            ctxs, _ = retrieve(question, passages, embedder, index, k=k, rerank_k=rerank_k)
+            # use two-stage retrieval (dense + rerank)
+            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
         else:
-            # skip reranking: use top-k directly
+            # single-stage: FAISS only
             q_emb = embedder.encode([question], convert_to_numpy=True)
             distances, idxs = index.search(q_emb, k)
             ctxs = [passages[i] for i in idxs[0]]
-        # check if any gold span appears
+
+        # check if any gold answer appears in any retrieved context
         if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers):
             hits += 1
+
     recall = hits / num_samples
     print(f"Retrieval Recall@{k} (rerank_k={rerank_k}): {recall:.3f} ({hits}/{num_samples})")
     return recall
 
 
-def retrieval_recall_answerable(dataset, passages, embedder, index, k=20, rerank_k=None, num_samples=100):
+def retrieval_recall_answerable(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
     """
-    Retrieval Recall@k evaluated only on answerable questions.
+    Retrieval Recall@k evaluated only on answerable questions (answers list non-empty).
     """
-    hits, total = 0, 0
+    hits = 0
+    total = 0
     for ex in dataset.select(range(num_samples)):
-        if not ex["answers"]["text"]:
+        gold = ex["answers"]["text"]
+        if not gold:
             continue
         total += 1
         question = ex["question"]
+
         if rerank_k:
-            ctxs, _ = retrieve(question, passages, embedder, index, k=k, rerank_k=rerank_k)
+            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
         else:
             q_emb = embedder.encode([question], convert_to_numpy=True)
             distances, idxs = index.search(q_emb, k)
             ctxs = [passages[i] for i in idxs[0]]
-        if any(any(ans in ctx for ctx in ctxs) for ans in ex["answers"]["text"]):
+
+        if any(any(ans in ctx for ctx in ctxs) for ans in gold):
             hits += 1
+
     recall = hits / total if total > 0 else 0.0
     print(f"Retrieval Recall@{k} on answerable only (rerank_k={rerank_k}): {recall:.3f} ({hits}/{total})")
     return recall
@@ -60,34 +69,42 @@ def retrieval_recall_answerable(dataset, passages, embedder, index, k=20, rerank
 
 def qa_eval_answerable(dataset, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100):
     """
-    End-to-end QA EM/F1 on answerable subset using the retrieve_and_answer logic.
+    End-to-end QA EM/F1 on answerable subset using retrieve_and_answer.
     """
     squad_metric = evaluate.load("squad")
-    preds, refs = [], []
+    preds = []
+    refs = []
+
     for ex in dataset.select(range(num_samples)):
-        if not ex["answers"]["text"]:
+        gold = ex["answers"]["text"]
+        if not gold:
             continue
         qid = ex["id"]
         # retrieve and generate
-        answer, _ = retrieve_and_answer(ex["question"], passages, embedder, reranker, index, qa_pipe)
+        answer, _ = retrieve_and_answer(
+            ex["question"], passages, embedder, reranker, index, qa_pipe
+        )
         preds.append({"id": qid, "prediction_text": answer})
         refs.append({"id": qid, "answers": ex["answers"]})
+
     results = squad_metric.compute(predictions=preds, references=refs)
     print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
     return results
 
 
 def main():
-    # Setup RAG components
+    # 1) Setup RAG components
     passages, embedder, reranker, index, qa_pipe = setup_rag()
-    # Load SQuAD v2 validation set
+
+    # 2) Load SQuAD v2 validation split
     squad = load_dataset("rajpurkar/squad_v2", split="validation")
 
-    # Run evaluations
-    retrieval_recall(squad, passages, embedder, index, k=20, rerank_k=5, num_samples=100)
-    retrieval_recall_answerable(squad, passages, embedder, index, k=20, rerank_k=5, num_samples=100)
+    # 3) Run evaluations
+    retrieval_recall(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
+    retrieval_recall_answerable(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
     qa_eval_answerable(squad, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100)
 
 
 if __name__ == "__main__":
     main()
+
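For reference, a minimal sketch of the retrieve() interface in app.py that these evaluators now assume. Only the argument order, the k/rerank_k keywords, and the fact that the first element of the returned tuple is a list of contexts are taken from the call sites above; the body, the second return value, and the reranker.predict call (a sentence-transformers CrossEncoder is assumed) are illustrative assumptions, not the actual implementation in app.py.

import numpy as np

def retrieve(question, passages, embedder, reranker, index, k=20, rerank_k=None):
    """Hypothetical sketch: dense top-k retrieval via FAISS, optionally cross-encoder reranked."""
    # embed the question and take the top-k nearest passages from the FAISS index
    q_emb = embedder.encode([question], convert_to_numpy=True)
    _, idxs = index.search(q_emb, k)
    ctxs = [passages[i] for i in idxs[0]]
    scores = None
    if rerank_k:
        # score (question, passage) pairs with the cross-encoder and keep the best rerank_k
        pair_scores = reranker.predict([(question, ctx) for ctx in ctxs])
        order = np.argsort(pair_scores)[::-1][:rerank_k]
        ctxs = [ctxs[i] for i in order]
        scores = [float(pair_scores[i]) for i in order]
    return ctxs, scores

The evaluators discard the second return value (ctxs, _ = retrieve(...)), so treating it as a list of scores here is an assumption for illustration only.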