from typing import Union

import json

from dsp.utils import dotdict


class PyseriniRetriever:
    """Wrapper for retrieval with Pyserini. Supports using either Pyserini prebuilt Faiss indexes or your own Faiss index."""

    def __init__(self,
                 query_encoder: str = 'castorini/dkrr-dpr-nq-retriever',
                 index: str = 'wikipedia-dpr-dkrr-nq',
                 dataset=None,
                 id_field: str = '_id',
                 text_fields: list[str] = ['text']) -> None:
""" |
|
Args: |
|
|
|
query_encoder (`str`): |
|
Huggingface model to encode queries |
|
index (`str`): |
|
Either a prebuilt index from pyserini or a local path to a faiss index |
|
dataset (`Dataset`): |
|
Only required when using a local faiss index. The dataset should be the one that has been put into the faiss index. |
|
id_field (`str`): |
|
The name of the id field of the dataset used for retrieval. |
|
text_fields (`list[str]`): |
|
A list of the names of the text fields for the dataset used for retrieval. |
|
""" |

        # Import here so pyserini remains an optional dependency.
        from pyserini.search import FaissSearcher
        from pyserini.prebuilt_index_info import TF_INDEX_INFO, FAISS_INDEX_INFO, IMPACT_INDEX_INFO

        self.encoder = FaissSearcher._init_encoder_from_str(query_encoder)
        self.dataset = dataset
        self.id_field = id_field
        self.text_fields = text_fields

        if index in TF_INDEX_INFO or index in FAISS_INDEX_INFO or index in IMPACT_INDEX_INFO:
            self.searcher = FaissSearcher.from_prebuilt_index(index, self.encoder)
        else:
            self.searcher = FaissSearcher(index_dir=index, query_encoder=self.encoder)
            assert self.dataset is not None, 'A dataset is required when using a local Faiss index.'
            # Map each document id to its row in the dataset for lookup at query time.
            self.dataset_id_to_index = {}
            for i, docid in enumerate(self.dataset[self.id_field]):
                self.dataset_id_to_index[docid] = i

    def __call__(
        self, query: str, k: int = 10, threads: int = 16,
    ) -> Union[list[str], list[dotdict]]:
        hits = self.searcher.search(query, k=k, threads=threads)

        topk = []
        for rank, hit in enumerate(hits, start=1):
            if self.dataset is not None:
                # Local index: look up the passage text in the user-provided dataset.
                row = self.dataset_id_to_index[hit.docid]
                text = ' '.join(self.dataset[field][row] for field in self.text_fields)
                pid = self.dataset[self.id_field][row]
            else:
                # Prebuilt index: fetch the raw stored document and parse its JSON.
                psg = json.loads(self.searcher.doc(hit.docid).raw())
                text = ' '.join(psg[field] for field in self.text_fields)
                pid = psg[self.id_field]

            topk.append({
                'text': text,
                'long_text': text,
                'pid': pid,
                'score': hit.score,
                'rank': rank,
            })

        return [dotdict(psg) for psg in topk]
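

# Minimal usage sketch (illustrative; not part of the original module). It assumes
# `pyserini` and its Faiss dependency are installed and that the default prebuilt
# index and query encoder can be downloaded. The query string is only an example.
if __name__ == '__main__':
    # Prebuilt index: no dataset needed; passage text is read from the index's doc store.
    retriever = PyseriniRetriever()
    for passage in retriever('who wrote the declaration of independence?', k=3):
        print(passage.rank, round(passage.score, 3), passage.pid, passage.long_text[:80])

    # Local Faiss index (hypothetical path and dataset): pass the same dataset that was
    # encoded into the index so passage text can be recovered by document id.
    # retriever = PyseriniRetriever(
    #     index='/path/to/local/faiss_index',
    #     dataset=my_dataset,            # e.g. a datasets.Dataset with '_id' and 'text' columns
    #     id_field='_id',
    #     text_fields=['text'],
    # )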