Ramon Meffert
committed on
Commit
·
ab5dfc2
1
Parent(s):
fa8dc75
Add reader
Browse files- main.py +47 -18
- src/readers/dpr_reader.py +27 -0
- src/retrievers/{fais_retriever.py → faiss_retriever.py} +10 -9
- src/utils/preprocessing.py +35 -0
main.py
CHANGED
|
@@ -1,12 +1,21 @@
|
|
| 1 |
from datasets import DatasetDict, load_dataset
|
| 2 |
|
| 3 |
-
from src.
|
|
|
|
| 4 |
from src.utils.log import get_logger
|
| 5 |
-
from src.evaluation import evaluate
|
| 6 |
from typing import cast
|
| 7 |
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
if __name__ == '__main__':
|
| 12 |
dataset_name = "GroNLP/ik-nlp-22_slp"
|
|
@@ -15,24 +24,44 @@ if __name__ == '__main__':
|
|
| 15 |
|
| 16 |
questions_test = questions["test"]
|
| 17 |
|
| 18 |
-
logger.info(questions)
|
| 19 |
|
| 20 |
# Initialize retriever
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
-
f"Example q: {example_q} answer: {result['text'][0]}")
|
| 29 |
|
| 30 |
-
for i, score in enumerate(scores):
|
| 31 |
-
|
| 32 |
-
|
| 33 |
|
| 34 |
-
# Compute overall performance
|
| 35 |
-
exact_match, f1_score = evaluate(
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
f"F1-score: {f1_score:.02f}")
|
|
|
|
| 1 |
from datasets import DatasetDict, load_dataset
|
| 2 |
|
| 3 |
+
from src.readers.dpr_reader import DprReader
|
| 4 |
+
from src.retrievers.faiss_retriever import FaissRetriever
|
| 5 |
from src.utils.log import get_logger
|
| 6 |
+
# from src.evaluation import evaluate
|
| 7 |
from typing import cast
|
| 8 |
|
| 9 |
+
from src.utils.preprocessing import result_to_reader_input
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import transformers
|
| 13 |
+
import os
|
| 14 |
|
| 15 |
+
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
|
| 16 |
+
|
| 17 |
+
logger = get_logger()
|
| 18 |
+
transformers.logging.set_verbosity_error()
|
| 19 |
|
| 20 |
if __name__ == '__main__':
|
| 21 |
dataset_name = "GroNLP/ik-nlp-22_slp"
|
|
|
|
| 24 |
|
| 25 |
questions_test = questions["test"]
|
| 26 |
|
| 27 |
+
# logger.info(questions)
|
| 28 |
|
| 29 |
# Initialize retriever
|
| 30 |
+
retriever = FaissRetriever()
|
| 31 |
+
|
| 32 |
+
# Retrieve example
|
| 33 |
+
example_q = questions_test.shuffle()["question"][0]
|
| 34 |
+
scores, result = retriever.retrieve(example_q)
|
| 35 |
+
|
| 36 |
+
reader_input = result_to_reader_input(result)
|
| 37 |
+
|
| 38 |
+
# Initialize reader
|
| 39 |
+
reader = DprReader()
|
| 40 |
+
answers = reader.read(example_q, reader_input)
|
| 41 |
+
|
| 42 |
+
# Calculate softmaxed scores for readable output
|
| 43 |
+
sm = torch.nn.Softmax(dim=0)
|
| 44 |
+
document_scores = sm(torch.Tensor(
|
| 45 |
+
[pred.relevance_score for pred in answers]))
|
| 46 |
+
span_scores = sm(torch.Tensor(
|
| 47 |
+
[pred.span_score for pred in answers]))
|
| 48 |
|
| 49 |
+
print(example_q)
|
| 50 |
+
for answer_i, answer in enumerate(answers):
|
| 51 |
+
print(f"[{answer_i + 1}]: {answer.text}")
|
| 52 |
+
print(f"\tDocument {answer.doc_id}", end='')
|
| 53 |
+
print(f"\t(score {document_scores[answer_i] * 100:.02f})")
|
| 54 |
+
print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
|
| 55 |
+
print(f"\t(score {span_scores[answer_i] * 100:.02f})")
|
| 56 |
+
print() # Newline
|
| 57 |
|
| 58 |
+
# print(f"Example q: {example_q} answer: {result['text'][0]}")
|
|
|
|
| 59 |
|
| 60 |
+
# for i, score in enumerate(scores):
|
| 61 |
+
# print(f"Result {i+1} (score: {score:.02f}):")
|
| 62 |
+
# print(result['text'][i])
|
| 63 |
|
| 64 |
+
# # Compute overall performance
|
| 65 |
+
# exact_match, f1_score = evaluate(
|
| 66 |
+
# r, questions_test["question"], questions_test["answer"])
|
| 67 |
+
# print(f"Exact match: {exact_match:.02f}\n", f"F1-score: {f1_score:.02f}")
|
|
|
src/readers/dpr_reader.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Tuple

import torch
from transformers import DPRReader, DPRReaderTokenizer
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class DprReader():
    """Extractive reader that locates answer spans in retrieved passages.

    Wraps Facebook's DPR reader checkpoint (single-nq-base): given a
    question and a set of titled passages, it scores candidate answer
    spans and returns the best ones.
    """

    def __init__(self) -> None:
        # Tokenizer and model must come from the same checkpoint so the
        # input encoding matches what the model was trained on.
        self._tokenizer = DPRReaderTokenizer.from_pretrained(
            "facebook/dpr-reader-single-nq-base")
        self._model = DPRReader.from_pretrained(
            "facebook/dpr-reader-single-nq-base"
        )

    def read(self, query: str, context: Dict[str, List[str]]) -> List[Tuple]:
        """Extract the best answer spans for a query from the given passages.

        Args:
            query (str): The question to answer.
            context (Dict[str, List[str]]): Retrieved passages as parallel
                'titles' and 'texts' lists.

        Returns:
            List[Tuple]: Best predicted spans, each carrying the span text,
            document id, relevance/span scores and start/end indices.
        """
        encoded_inputs = self._tokenizer(
            questions=query,
            titles=context['titles'],
            texts=context['texts'],
            return_tensors='pt',
            truncation=True,
            padding=True
        )
        # Inference only: disable autograd so no gradient graph is built,
        # saving memory and time.
        with torch.no_grad():
            outputs = self._model(**encoded_inputs)

        predicted_spans = self._tokenizer.decode_best_spans(
            encoded_inputs, outputs)

        return predicted_spans
|
src/retrievers/{fais_retriever.py → faiss_retriever.py}
RENAMED
|
@@ -13,15 +13,15 @@ from transformers import (
|
|
| 13 |
from src.retrievers.base_retriever import Retriever
|
| 14 |
from src.utils.log import get_logger
|
| 15 |
|
| 16 |
-
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
| 17 |
# Hacky fix for FAISS error on macOS
|
| 18 |
# See https://stackoverflow.com/a/63374568/4545692
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
logger = get_logger()
|
| 22 |
|
| 23 |
|
| 24 |
-
class
|
| 25 |
"""A class used to retrieve relevant documents based on some query.
|
| 26 |
based on https://huggingface.co/docs/datasets/faiss_es#faiss.
|
| 27 |
"""
|
|
@@ -56,14 +56,16 @@ class FAISRetriever(Retriever):
|
|
| 56 |
self.dataset_name = dataset_name
|
| 57 |
self.dataset = self._init_dataset(dataset_name)
|
| 58 |
|
| 59 |
-
def _init_dataset(
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
| 62 |
"""Loads the dataset and adds FAISS embeddings.
|
| 63 |
|
| 64 |
Args:
|
| 65 |
dataset (str): A HuggingFace dataset name.
|
| 66 |
-
fname (str): The name to use to save the embeddings to disk for
|
| 67 |
faster loading after the first run.
|
| 68 |
|
| 69 |
Returns:
|
|
@@ -73,9 +75,8 @@ class FAISRetriever(Retriever):
|
|
| 73 |
# Load dataset
|
| 74 |
ds = load_dataset(dataset_name, name="paragraphs")[
|
| 75 |
"train"] # type: ignore
|
| 76 |
-
logger.info(ds)
|
| 77 |
|
| 78 |
-
if os.path.exists(embedding_path):
|
| 79 |
# If we already have FAISS embeddings, load them from disk
|
| 80 |
ds.load_faiss_index('embeddings', embedding_path) # type: ignore
|
| 81 |
return ds
|
|
@@ -95,7 +96,7 @@ class FAISRetriever(Retriever):
|
|
| 95 |
ds_with_embeddings.add_faiss_index(column="embeddings")
|
| 96 |
|
| 97 |
# save dataset w/ embeddings
|
| 98 |
-
os.makedirs("./models/", exist_ok=True)
|
| 99 |
ds_with_embeddings.save_faiss_index("embeddings", embedding_path)
|
| 100 |
|
| 101 |
return ds_with_embeddings
|
|
|
|
| 13 |
from src.retrievers.base_retriever import Retriever
|
| 14 |
from src.utils.log import get_logger
|
| 15 |
|
|
|
|
| 16 |
# Hacky fix for FAISS error on macOS
|
| 17 |
# See https://stackoverflow.com/a/63374568/4545692
|
| 18 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
| 19 |
|
| 20 |
|
| 21 |
logger = get_logger()
|
| 22 |
|
| 23 |
|
| 24 |
+
class FaissRetriever(Retriever):
|
| 25 |
"""A class used to retrieve relevant documents based on some query.
|
| 26 |
based on https://huggingface.co/docs/datasets/faiss_es#faiss.
|
| 27 |
"""
|
|
|
|
| 56 |
self.dataset_name = dataset_name
|
| 57 |
self.dataset = self._init_dataset(dataset_name)
|
| 58 |
|
| 59 |
+
def _init_dataset(
|
| 60 |
+
self,
|
| 61 |
+
dataset_name: str,
|
| 62 |
+
embedding_path: str = "./src/models/paragraphs_embedding.faiss",
|
| 63 |
+
force_new_embedding: bool = False):
|
| 64 |
"""Loads the dataset and adds FAISS embeddings.
|
| 65 |
|
| 66 |
Args:
|
| 67 |
dataset (str): A HuggingFace dataset name.
|
| 68 |
+
fname (str): The name to use to save the embeddings to disk for
|
| 69 |
faster loading after the first run.
|
| 70 |
|
| 71 |
Returns:
|
|
|
|
| 75 |
# Load dataset
|
| 76 |
ds = load_dataset(dataset_name, name="paragraphs")[
|
| 77 |
"train"] # type: ignore
|
|
|
|
| 78 |
|
| 79 |
+
if not force_new_embedding and os.path.exists(embedding_path):
|
| 80 |
# If we already have FAISS embeddings, load them from disk
|
| 81 |
ds.load_faiss_index('embeddings', embedding_path) # type: ignore
|
| 82 |
return ds
|
|
|
|
| 96 |
ds_with_embeddings.add_faiss_index(column="embeddings")
|
| 97 |
|
| 98 |
# save dataset w/ embeddings
|
| 99 |
+
os.makedirs("./src/models/", exist_ok=True)
|
| 100 |
ds_with_embeddings.save_faiss_index("embeddings", embedding_path)
|
| 101 |
|
| 102 |
return ds_with_embeddings
|
src/utils/preprocessing.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def result_to_reader_input(result: Dict[str, List[str]]) \
        -> Dict[str, List[str]]:
    """Takes the output of the retriever and turns it into a format the reader
    understands.

    Args:
        result (Dict[str, List[str]]): The result from the retriever

    Returns:
        Dict[str, List[str]]: Parallel 'titles' and 'texts' lists with one
        entry per retrieved passage. Each title is the most specific
        available heading (subsection > section > chapter), where missing
        headings are stored as the literal string 'nan'.
    """
    # One entry per retrieved passage. Count via the 'text' column — the
    # column we actually copy — rather than an arbitrary other column,
    # so the output stays consistent even if columns differ in length.
    num_entries = len(result['text'])

    # Prepare result
    reader_result: Dict[str, List[str]] = {
        'titles': [],
        'texts': []
    }

    for n in range(num_entries):
        # Get the most specific title; fall back from subsection to
        # section to chapter (missing values are the string 'nan').
        if result['subsection'][n] != 'nan':
            title = result['subsection'][n]
        elif result['section'][n] != 'nan':
            title = result['section'][n]
        else:
            title = result['chapter'][n]

        reader_result['titles'].append(title)
        reader_result['texts'].append(result['text'][n])

    return reader_result
|