File size: 2,814 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from collections import defaultdict
from typing import List, Optional, Union

import dspy

try:
    import marqo
except ImportError:
    raise ImportError(
        "The 'marqo' extra is required to use MarqoRM. Install it with `pip install dspy-ai[marqo]`"
    )

class MarqoRM(dspy.Retrieve):
    """
    A retrieval module that uses Marqo to return the top passages for a given query.

    Assumes that a Marqo index has been created and populated with documents that
    store the passage text in a single field (``document`` by default; configurable
    via ``page_content``).

    Args:
        marqo_index_name (str): The name of the marqo index.
        marqo_client (marqo.client.Client): A marqo client instance.
        k (int, optional): The number of top passages to retrieve. Defaults to 3.
        page_content (str, optional): Name of the field in each search hit that holds
            the passage text. Defaults to ``"document"``.

    Examples:
        Below is a code snippet that shows how to use Marqo as the default retriever:
        ```python
        import marqo
        marqo_client = marqo.Client(url="http://0.0.0.0:8882")

        llm = dspy.OpenAI(model="gpt-3.5-turbo")
        retriever_model = MarqoRM("my_index_name", marqo_client=marqo_client)
        dspy.settings.configure(lm=llm, rm=retriever_model)
        ```

        Below is a code snippet that shows how to use Marqo in the forward() function of a module:
        ```python
        self.retrieve = MarqoRM("my_index_name", marqo_client=marqo_client, k=num_passages)
        ```
    """

    def __init__(
        self,
        marqo_index_name: str,
        marqo_client: marqo.client.Client,
        k: int = 3,
        page_content: str = "document",
    ):
        self._marqo_index_name = marqo_index_name
        self._marqo_client = marqo_client
        # Field of each search hit that carries the passage text (was hard-coded).
        self._page_content = page_content

        super().__init__(k=k)

    def forward(
        self,
        query_or_queries: Union[str, List[str]],
        k: Optional[int] = None,
        **kwargs,
    ) -> dspy.Prediction:
        """Search Marqo and return the top passages for the query or queries.

        Each non-empty query is searched separately. When several queries return the
        same passage text, their ``_score`` values are summed, and the passages with
        the highest combined score are returned first.

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
                Empty queries are skipped.
            k (Optional[int]): Number of passages to return; also used as the
                per-query search limit. Falls back to ``self.k`` when ``None``.
            **kwargs: Extra keyword arguments forwarded unchanged to the Marqo index
                ``search`` call (do not pass ``q`` or ``limit`` here).

        Returns:
            dspy.Prediction: An object whose ``passages`` field holds at most ``k``
            passage strings, highest combined score first.
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        queries = [q for q in queries if q]
        # `is None` (not truthiness) so an explicit k=0 is respected.
        limit = self.k if k is None else k

        if not queries:
            return dspy.Prediction(passages=[])

        # Resolve the index handle once instead of once per query.
        index = self._marqo_client.index(self._marqo_index_name)

        scores = defaultdict(float)
        for query in queries:
            response = index.search(q=query, limit=limit, **kwargs)
            for hit in response["hits"]:
                scores[hit[self._page_content]] += hit["_score"]

        # sorted() is stable, so equal scores keep first-seen order.
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:limit]
        return dspy.Prediction(passages=[passage for passage, _ in ranked])