File size: 3,234 Bytes
83870cc 51dabd6 1fb8ae3 51a31d4 51dabd6 b06298d 51a31d4 b7158e7 b06298d 8bbe3aa 51a31d4 ab5dfc2 51a31d4 83870cc 51a31d4 83870cc 51a31d4 ab5dfc2 8bbe3aa e9df5ab 8bbe3aa 83870cc 8bbe3aa 83870cc 8bbe3aa 83870cc 8bbe3aa 83870cc 8bbe3aa e9df5ab 1fb8ae3 8bbe3aa 1fb8ae3 ab5dfc2 1fb8ae3 e9df5ab b7158e7 1fb8ae3 83870cc 1fb8ae3 83870cc 1fb8ae3 8bbe3aa 83870cc ab5dfc2 1fb8ae3 8bbe3aa 1fb8ae3 8bbe3aa b06298d 8bbe3aa 83870cc 8bbe3aa 1fb8ae3 83870cc 2827202 8bbe3aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import os
import os.path
import torch
from datasets import DatasetDict, load_dataset
from transformers import (
DPRContextEncoder,
DPRContextEncoderTokenizer,
DPRQuestionEncoder,
DPRQuestionEncoderTokenizer,
)
from src.retrievers.base_retriever import RetrieveType, Retriever
from src.utils.log import get_logger
from src.utils.preprocessing import remove_formulas
from src.utils.timing import timeit
# Workaround for a FAISS crash on macOS caused by a duplicate OpenMP runtime.
# See https://stackoverflow.com/a/63374568/4545692
os.environ.update({"KMP_DUPLICATE_LIB_OK": "True"})

# Shared module-level logger.
logger = get_logger()
class FaissRetriever(Retriever):
    """Retrieve relevant documents for a query using DPR embeddings + FAISS.

    Based on https://huggingface.co/docs/datasets/faiss_es#faiss.
    """

    def __init__(self, paragraphs: DatasetDict, embedding_path: str = "./src/models/paragraphs_embedding.faiss") -> None:
        """Build (or load) a FAISS index over the given paragraphs.

        Args:
            paragraphs: dataset whose "train" split holds a "text" column.
            embedding_path: file the FAISS index is loaded from / saved to.
        """
        # NOTE(review): this disables autograd process-wide, not just for
        # this object — confirm no caller needs gradients afterwards.
        torch.set_grad_enabled(False)

        # Context (passage) encoding and tokenization
        self.ctx_encoder = DPRContextEncoder.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )
        self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )

        # Question encoding and tokenization
        self.q_encoder = DPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )

        self.paragraphs = paragraphs
        self.embedding_path = embedding_path
        self.index = self._init_index()

    def _init_index(self, force_new_embedding: bool = False):
        """Load a cached FAISS index, or embed the corpus and persist a new one.

        Args:
            force_new_embedding: when True, re-embed even if a cached index
                exists at ``self.embedding_path``.

        Returns:
            The "train" split with a FAISS index attached on "embeddings".
        """
        ds = self.paragraphs["train"]
        ds = ds.map(remove_formulas)

        if not force_new_embedding and os.path.exists(self.embedding_path):
            ds.load_faiss_index(
                "embeddings", self.embedding_path)  # type: ignore
            return ds

        def embed(row):
            # Encode a single paragraph into a dense DPR vector.
            tok = self.ctx_tokenizer(
                row["text"], return_tensors="pt", truncation=True)
            enc = self.ctx_encoder(**tok)[0][0].numpy()
            return {"embeddings": enc}

        # Add FAISS embeddings
        index = ds.map(embed)  # type: ignore
        index.add_faiss_index(column="embeddings")

        # Save the index next to the configured path. The directory was
        # previously hard-coded to "./src/models/", which broke saving
        # whenever a custom embedding_path was supplied.
        target_dir = os.path.dirname(self.embedding_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        index.save_faiss_index(
            "embeddings", self.embedding_path)
        return index

    @timeit("faissretriever.retrieve")
    def retrieve(self, query: str, k: int = 5) -> RetrieveType:
        """Return the ``k`` paragraphs nearest to ``query``.

        Args:
            query: free-text question to search with.
            k: number of nearest examples to fetch.

        Returns:
            Tuple ``(scores, results)`` from the FAISS nearest-example search.
        """
        def embed(q):
            # Encode the question into a dense DPR vector.
            tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
            return self.q_encoder(**tok)[0][0].numpy()

        question_embedding = embed(query)
        scores, results = self.index.get_nearest_examples(
            "embeddings", question_embedding, k=k
        )
        return scores, results
|