import json
from typing import Union

# from datasets import Dataset
from dsp.utils import dotdict


class PyseriniRetriever:
    """Wrapper for retrieval with Pyserini. Supports Pyserini prebuilt faiss indexes as well as your own local faiss index."""

    def __init__(self,
                 query_encoder: str = 'castorini/dkrr-dpr-nq-retriever',
                 index: str = 'wikipedia-dpr-dkrr-nq',
                 dataset=None,
                 id_field: str = '_id',
                 text_fields: list[str] = ['text']) -> None:
        """
        Args:
            query_encoder (`str`):
                Hugging Face model used to encode queries.
            index (`str`):
                Either the name of a Pyserini prebuilt index or a local path to a faiss index.
            dataset (`Dataset`):
                Only required when using a local faiss index. This should be the dataset from which the faiss index was built.
            id_field (`str`):
                The name of the id field of the dataset used for retrieval.
            text_fields (`list[str]`):
                The names of the text fields of the dataset used for retrieval.
        """

        # Keep pyserini as an optional dependency.
        from pyserini.search import FaissSearcher
        from pyserini.prebuilt_index_info import TF_INDEX_INFO, FAISS_INDEX_INFO, IMPACT_INDEX_INFO

        self.encoder = FaissSearcher._init_encoder_from_str(query_encoder)
        self.dataset = dataset
        self.id_field = id_field
        self.text_fields = text_fields

        if index in TF_INDEX_INFO or index in FAISS_INDEX_INFO or index in IMPACT_INDEX_INFO:
            # A recognized prebuilt index is downloaded (and cached) by Pyserini.
            self.searcher = FaissSearcher.from_prebuilt_index(index, self.encoder)
        else:
            # Otherwise, treat `index` as a local faiss index directory.
            self.searcher = FaissSearcher(index_dir=index, query_encoder=self.encoder)

            # A local faiss index cannot resolve docids back to text, so the
            # accompanying dataset is required to map each docid to its row.
            assert self.dataset is not None
            self.dataset_id_to_index = {}
            for i, docid in enumerate(self.dataset[self.id_field]):
                self.dataset_id_to_index[docid] = i

    def __call__(
        self, query: str, k: int = 10, threads: int = 16,
    ) -> Union[list[str], list[dotdict]]:
        """Retrieve the top-k passages for `query`."""
        hits = self.searcher.search(query, k=k, threads=threads)

        topk = []
        for rank, hit in enumerate(hits, start=1):
            if self.dataset is not None:
                # Local index: look up the passage text in the provided dataset.
                row = self.dataset_id_to_index[hit.docid]
                text = ' '.join(self.dataset[field][row] for field in self.text_fields)
                pid = self.dataset[self.id_field][row]
            else:
                # Pyserini prebuilt faiss indexes can perform docid lookup directly.
                psg = json.loads(self.searcher.doc(hit.docid).raw())
                text = ' '.join(psg[field] for field in self.text_fields)
                pid = psg[self.id_field]

            topk.append({
                'text': text,
                'long_text': text,
                'pid': pid,
                'score': hit.score,
                'rank': rank,
            })

        return [dotdict(psg) for psg in topk]
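

# A minimal usage sketch (not part of the original module): it assumes the
# optional `pyserini` and `faiss` dependencies are installed, that the default
# prebuilt index ('wikipedia-dpr-dkrr-nq') can be downloaded, and that its
# stored documents expose the default '_id' and 'text' fields.
if __name__ == '__main__':
    retriever = PyseriniRetriever()  # defaults: DKRR DPR query encoder + prebuilt NQ faiss index

    # Retrieve the top-3 passages for a query; each result is a dotdict with
    # 'text', 'long_text', 'pid', 'score', and 'rank' keys.
    for passage in retriever('Who wrote The Old Man and the Sea?', k=3):
        print(passage.rank, round(passage.score, 3), passage.long_text[:100])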