# EureCA/dsp/modules/pyserini.py
# tonneli's picture
# Delete history
# f5776d3
from typing import Union
import json
# from datasets import Dataset
from dsp.utils import dotdict
class PyseriniRetriever:
    """Wrapper for retrieval with Pyserini.

    Supports using either pyserini prebuilt faiss indexes or your own faiss index.
    When a ``dataset`` is supplied, passage text and ids are read from it;
    otherwise they are read from the documents stored alongside the index.
    """

    def __init__(self,
                 query_encoder: str = 'castorini/dkrr-dpr-nq-retriever',
                 index: str = 'wikipedia-dpr-dkrr-nq',
                 dataset=None,
                 id_field: str = '_id',
                 text_fields: list[str] = ['text']) -> None:
        """
        Args:
            query_encoder (`str`):
                Huggingface model to encode queries
            index (`str`):
                Either a prebuilt index from pyserini or a local path to a faiss index
            dataset (`Dataset`):
                Only required when using a local faiss index. The dataset should be the one that has been put into the faiss index.
            id_field (`str`):
                The name of the id field of the dataset used for retrieval.
            text_fields (`list[str]`):
                A list of the names of the text fields for the dataset used for retrieval.

        Raises:
            ValueError: If `index` is not a known prebuilt index and no `dataset` is given.
        """
        # Keep pyserini as an optional dependency
        from pyserini.search import FaissSearcher
        from pyserini.prebuilt_index_info import TF_INDEX_INFO, FAISS_INDEX_INFO, IMPACT_INDEX_INFO

        is_prebuilt = (
            index in TF_INDEX_INFO
            or index in FAISS_INDEX_INFO
            or index in IMPACT_INDEX_INFO
        )

        # Validate with a real exception (an `assert` is stripped under `-O`),
        # and do so before the encoder and index are loaded.
        if not is_prebuilt and dataset is None:
            raise ValueError(
                f"'{index}' is not a pyserini prebuilt index, so it is treated as a "
                "local faiss index; a `dataset` is required in that case."
            )

        self.encoder = FaissSearcher._init_encoder_from_str(query_encoder)
        self.dataset = dataset
        self.id_field = id_field
        # Copy so the shared default list (or the caller's list) is never aliased.
        self.text_fields = list(text_fields)

        if is_prebuilt:
            self.searcher = FaissSearcher.from_prebuilt_index(index, self.encoder)
        else:
            self.searcher = FaissSearcher(index_dir=index, query_encoder=self.encoder)

        # Map document id -> row position in the dataset. This is built whenever a
        # dataset is supplied (not only for local indexes), because `__call__`
        # takes the dataset path whenever `self.dataset` is set.
        self.dataset_id_to_index = {}
        if self.dataset is not None:
            self.dataset_id_to_index = {
                docid: row for row, docid in enumerate(self.dataset[self.id_field])
            }

    def __call__(
        self, query: str, k: int = 10, threads: int = 16,
    ) -> Union[list[str], list[dotdict]]:
        """Search the index for `query` and return the retrieved passages.

        Args:
            query (`str`):
                The query passed to the searcher.
            k (`int`):
                Forwarded to the searcher as `k`.
            threads (`int`):
                Forwarded to the searcher as `threads`.

        Returns:
            A list with one `dotdict` per hit, in the order the searcher returned
            them. Each has the keys `text`, `long_text` (same value as `text`),
            `pid`, `score` and `rank` (starting at 1). The declared return
            annotation is wider than this; it is left unchanged for callers.
        """
        hits = self.searcher.search(query, k=k, threads=threads)

        passages = []
        for rank, hit in enumerate(hits, start=1):
            if self.dataset is not None:
                # Resolve the hit through the user-supplied dataset.
                row = self.dataset_id_to_index[hit.docid]
                text = ' '.join(self.dataset[field][row] for field in self.text_fields)
                pid = self.dataset[self.id_field][row]
            else:
                # Pyserini prebuilt faiss indexes can perform docid lookup
                psg = json.loads(self.searcher.doc(hit.docid).raw())
                text = ' '.join(psg[field] for field in self.text_fields)
                pid = psg[self.id_field]

            passages.append(dotdict({
                'text': text,
                'long_text': text,
                'pid': pid,
                'score': hit.score,
                'rank': rank,
            }))

        return passages