File size: 3,244 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from collections import defaultdict
from typing import List, Union, Optional
import dspy
from dsp.utils import dotdict

# Optional-dependency guard: QdrantRM needs both the Qdrant client and
# fastembed. fastembed is not referenced directly in this module; it is
# imported here so a missing install fails at import time with a clear hint.
try:
    from qdrant_client import QdrantClient
    import fastembed
except ImportError as exc:
    # Chain the original error so the traceback shows which package was missing.
    raise ImportError(
        "The 'qdrant' extra is required to use QdrantRM. Install it with `pip install dspy-ai[qdrant]`"
    ) from exc


class QdrantRM(dspy.Retrieve):
    """
    A retrieval module that uses Qdrant to return the top passages for a given query.

    Assumes that a Qdrant collection has been created and populated with the following payload:
        - document: The text of the passage

    Args:
        qdrant_collection_name (str): The name of the Qdrant collection.
        qdrant_client (QdrantClient): A QdrantClient instance.
        k (int, optional): The default number of top passages to retrieve. Defaults to 3.

    Examples:
        Below is a code snippet that shows how to use Qdrant as the default retriever:
        ```python
        from qdrant_client import QdrantClient

        llm = dspy.OpenAI(model="gpt-3.5-turbo")
        qdrant_client = QdrantClient()
        retriever_model = QdrantRM("my_collection_name", qdrant_client=qdrant_client)
        dspy.settings.configure(lm=llm, rm=retriever_model)
        ```

        Below is a code snippet that shows how to use Qdrant in the forward() function of a module
        ```python
        self.retrieve = QdrantRM("my_collection_name", qdrant_client=qdrant_client, k=num_passages)
        ```
    """

    def __init__(
        self,
        qdrant_collection_name: str,
        qdrant_client: QdrantClient,
        k: int = 3,
    ):
        """Store the collection name and client, then initialise the base retriever.

        Args:
            qdrant_collection_name (str): The name of the Qdrant collection to query.
            qdrant_client (QdrantClient): The client used to run the queries.
            k (int, optional): The default number of passages to return. Defaults to 3.
        """
        self._qdrant_collection_name = qdrant_collection_name
        self._qdrant_client = qdrant_client

        super().__init__(k=k)

    def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> List[dotdict]:
        """Search Qdrant for the top passages matching one or more queries.

        All non-empty queries are sent in a single batch. Results are merged
        across queries: a passage returned more than once has its scores
        summed, and the merged list is sorted by that sum, highest first.

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
                Empty queries are ignored.
            k (Optional[int]): The number of top passages to retrieve, used both as the
                per-query limit and as the size of the merged result. Defaults to self.k.
        Returns:
            List[dotdict]: At most ``k`` passages, each a dotdict with the passage text
            under ``long_text``. Empty when no non-empty query was given.
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        queries = [q for q in queries if q]  # Filter empty queries

        # Nothing to search for: skip the backend call entirely.
        if not queries:
            return []

        k = k if k is not None else self.k
        batch_results = self._qdrant_client.query_batch(
            self._qdrant_collection_name, query_texts=queries, limit=k)

        passages_scores = defaultdict(float)
        for batch in batch_results:
            for result in batch:
                # If a passage is returned multiple times, the score is accumulated.
                passages_scores[result.document] += result.score

        # Sort passages by their accumulated scores in descending order and
        # keep only the k best across all queries.
        sorted_passages = sorted(
            passages_scores.items(), key=lambda x: x[1], reverse=True)[:k]

        # Wrap each sorted passage in a dotdict with 'long_text'
        return [dotdict({"long_text": passage}) for passage, _ in sorted_passages]