File size: 3,254 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Clarifai as retriver to retrieve hits"""
import os
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Union

import requests

import dspy
from dsp.utils import dotdict

try:
    from clarifai.client.search import Search
except ImportError as err:
    raise ImportError(
        "Clarifai is not installed. Install it using `pip install clarifai`"
    ) from err


class ClarifaiRM(dspy.Retrieve):
    """
    Retrieval module that uses Clarifai to return the top-k relevant passages for a query.

    Assumes the source documents have already been ingested into a Clarifai app, where
    they are indexed and stored as text inputs.

    Args:
        clarifai_user_id (str): Clarifai unique user_id.
        clarfiai_app_id (str): Clarifai App ID, where the documents are stored.
            (The misspelled parameter name is kept for backward compatibility.)
        clarifai_pat (str, optional): Clarifai PAT key. When omitted, it is read from the
            ``CLARIFAI_PAT`` environment variable (a ``KeyError`` is raised if unset).
        k (int): Default number of passages to retrieve per query.

    Examples:
        ```python
        llm = dspy.Clarifai(model=MODEL_URL, api_key="YOUR CLARIFAI_PAT")
        retriever_model = ClarifaiRM(clarifai_user_id="USER_ID", clarfiai_app_id="APP_ID", clarifai_pat="YOUR CLARIFAI_PAT")
        dspy.settings.configure(lm=llm, rm=retriever_model)
        ```
    """

    # Seconds to wait for the server when downloading a hit's text. Without a timeout a
    # stalled connection would block its worker thread (and thus forward()) forever.
    request_timeout: float = 30.0

    # Upper bound on concurrent text downloads.
    max_workers: int = 10

    def __init__(
        self,
        clarifai_user_id: str,
        clarfiai_app_id: str,
        clarifai_pat: Optional[str] = None,
        k: int = 3,
    ):
        self.app_id = clarfiai_app_id
        self.user_id = clarifai_user_id
        # An explicit argument wins; otherwise fall back to the environment.
        self.pat = (
            clarifai_pat if clarifai_pat is not None else os.environ["CLARIFAI_PAT"]
        )
        self.k = k
        self.clarifai_search = self._make_search(k)
        super().__init__(k=k)

    def _make_search(self, top_k: int) -> Search:
        """Build a Clarifai ``Search`` client for this app returning up to ``top_k`` hits."""
        return Search(
            user_id=self.user_id, app_id=self.app_id, top_k=top_k, pat=self.pat
        )

    def retrieve_hits(self, hits):
        """Download and return the raw text content of a single search hit.

        Args:
            hits: One hit from a Clarifai search response. Its text is fetched from
                ``hits.input.data.text.url`` using this module's PAT.

        Returns:
            str: The decoded body of the hit's text URL.

        Raises:
            requests.HTTPError: If the download returns an error status, so that an
                error page is never passed along as if it were document text.
            requests.Timeout: If the server does not respond within ``request_timeout``.
        """
        header = {"Authorization": f"Key {self.pat}"}
        response = requests.get(
            hits.input.data.text.url, headers=header, timeout=self.request_timeout
        )
        response.raise_for_status()
        # The hosted text may not declare a charset; let requests infer it from content.
        response.encoding = response.apparent_encoding
        return response.text

    def forward(
        self,
        query_or_queries: Union[str, List[str]],
        k: Optional[int] = None,
        **kwargs,
    ) -> List[dotdict]:
        """Search Clarifai for each query and return the text of its top-k hits.

        Args:
            query_or_queries: A single query or a list of queries. Empty queries are
                skipped.
            k: Number of passages to return per query. Defaults to the ``k`` given at
                construction time.
            **kwargs: Extra keyword arguments forwarded to ``Search.query``
                (for example ``filters``).

        Returns:
            List[dotdict]: One ``dotdict({"long_text": text})`` per hit. Results for all
            queries are concatenated in query order, each in Clarifai's ranking order.
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        queries = [q for q in queries if q]
        k = self.k if k is None else k
        # The shared client is fixed to self.k results; build another only when the
        # caller asks for a different k, so the per-call argument is honored.
        search = self.clarifai_search if k == self.k else self._make_search(k)

        passages = []
        # One pool for all queries rather than a new pool per query.
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            for query in queries:
                search_response = search.query(ranks=[{"text_raw": query}], **kwargs)
                # The response is paged; flatten every page's hits into one list.
                hits = [hit for page in search_response for hit in page.hits]
                # executor.map yields results in input order, preserving the ranking.
                texts = executor.map(self.retrieve_hits, hits[:k])
                passages.extend(dotdict({"long_text": text}) for text in texts)

        return passages