File size: 3,143 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import json
from typing import Optional, Union
# from datasets import Dataset

from dsp.utils import dotdict


class PyseriniRetriever:
    """Wrapper for retrieval with Pyserini. Supports using either pyserini prebuilt faiss indexes or your own faiss index."""

    def __init__(self,
                 query_encoder: str = 'castorini/dkrr-dpr-nq-retriever',
                 index: str = 'wikipedia-dpr-dkrr-nq',
                 dataset=None,
                 id_field: str = '_id',
                 text_fields: Optional[list[str]] = None) -> None:
        """
        Args:

            query_encoder (`str`):
                Huggingface model to encode queries
            index (`str`):
                Either a prebuilt index from pyserini or a local path to a faiss index
            dataset (`Dataset`, *optional*):
                Required when using a local faiss index; it should be the one that has been put into the
                faiss index. When given, passage ids and texts are read from it (for prebuilt indexes too)
                instead of from the index's stored documents.
            id_field (`str`):
                The name of the id field of the dataset used for retrieval.
            text_fields (`list[str]`, *optional*):
                A list of the names of the text fields for the dataset used for retrieval. Their values are
                joined with a single space. Defaults to `['text']`.

        Raises:
            ValueError: If `index` is not a pyserini prebuilt index (i.e. a local faiss index) and no
                `dataset` is given.
        """

        # Keep pyserini as an optional dependency
        from pyserini.search import FaissSearcher
        from pyserini.prebuilt_index_info import TF_INDEX_INFO, FAISS_INDEX_INFO, IMPACT_INDEX_INFO

        # NOTE(review): TF_INDEX_INFO / IMPACT_INDEX_INFO presumably list non-faiss prebuilt indexes;
        # confirm FaissSearcher.from_prebuilt_index actually accepts them.
        is_prebuilt = index in TF_INDEX_INFO or index in FAISS_INDEX_INFO or index in IMPACT_INDEX_INFO

        # Validate before loading the encoder and index. A real exception is used instead of
        # `assert`, because assertions are stripped when Python runs with -O.
        if not is_prebuilt and dataset is None:
            raise ValueError(
                f"A `dataset` is required when `index` ({index!r}) is a local faiss index "
                "rather than a pyserini prebuilt index."
            )

        self.encoder = FaissSearcher._init_encoder_from_str(query_encoder)
        self.dataset = dataset
        self.id_field = id_field
        # Fresh list per instance: never alias a caller-owned list or a shared default.
        self.text_fields = list(text_fields) if text_fields is not None else ['text']

        if is_prebuilt:
            self.searcher = FaissSearcher.from_prebuilt_index(index, self.encoder)
        else:
            self.searcher = FaissSearcher(index_dir=index, query_encoder=self.encoder)

        # docid -> row position in `dataset`. Built whenever a dataset is supplied so that
        # `__call__` can always resolve hits against it; None means "use the index's stored docs".
        self.dataset_id_to_index = (
            self._build_id_index(self.dataset[self.id_field]) if self.dataset is not None else None
        )

    @staticmethod
    def _build_id_index(ids) -> dict:
        """Map each id in `ids` to its position; when an id repeats, the last position wins."""
        return {docid: row for row, docid in enumerate(ids)}

    def __call__(
        self, query: str, k: int = 10, threads: int = 16,
    ) -> list[dotdict]:
        """Retrieve the top-`k` passages for `query`.

        Args:
            query: The search query string.
            k: Number of hits to request from the searcher.
            threads: Forwarded to `self.searcher.search`.

        Returns:
            One `dotdict` per hit, in the order returned by the searcher, with keys `text`,
            `long_text` (same value as `text`), `pid`, `score` and `rank` (1-based).

        Raises:
            KeyError: If a hit's docid is not present in the supplied dataset.
        """
        hits = self.searcher.search(query, k=k, threads=threads)

        passages = []
        for rank, hit in enumerate(hits, start=1):
            if self.dataset_id_to_index is not None:
                row = self.dataset_id_to_index[hit.docid]
                # NOTE(review): for a Hugging Face `Dataset`, `dataset[column]` may materialize the
                # whole column on each access; consider `dataset[row][field]` if only HF datasets are used.
                text = ' '.join(self.dataset[field][row] for field in self.text_fields)
                pid = self.dataset[self.id_field][row]
            else:
                # Pyserini prebuilt faiss indexes can perform docid lookup
                psg = json.loads(self.searcher.doc(hit.docid).raw())
                text = ' '.join(psg[field] for field in self.text_fields)
                pid = psg[self.id_field]

            passages.append(dotdict({
                'text': text,
                'long_text': text,
                'pid': pid,
                'score': hit.score,
                'rank': rank,
            }))

        return passages