File size: 5,672 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from collections import defaultdict
from typing import List, Union
import dspy
from typing import Optional
import json
import os
import requests

from dsp.utils import dotdict

# Tags sent to Vectara as contextConfig startTag/endTag; they delimit the matched
# snippet inside the returned text and are stripped before passages are used.
START_SNIPPET = "<%START%>"
END_SNIPPET = "<%END%>"


def remove_snippet(s: str) -> str:
    """Return ``s`` with every START_SNIPPET and END_SNIPPET marker removed.

    START_SNIPPET occurrences are removed first, then END_SNIPPET occurrences.
    """
    cleaned = s
    for marker in (START_SNIPPET, END_SNIPPET):
        cleaned = cleaned.replace(marker, "")
    return cleaned

class VectaraRM(dspy.Retrieve):
    """
    A retrieval module that uses Vectara to return the top passages for a given query.

    Assumes that a Vectara corpus has been created and populated with the following payload:
        - document: The text of the passage

    Args:
        vectara_customer_id (str, optional): Vectara Customer ID. Defaults to the
            VECTARA_CUSTOMER_ID environment variable ("" if unset).
        vectara_corpus_id (str, optional): Vectara Corpus ID. Defaults to the
            VECTARA_CORPUS_ID environment variable ("" if unset).
        vectara_api_key (str, optional): Vectara API Key. Defaults to the
            VECTARA_API_KEY environment variable ("" if unset).
        k (int, optional): The default number of top passages to retrieve. Defaults to 5.

    Examples:
        Below is a code snippet that shows how to use Vectara as the default retriever:
        ```python
        llm = dspy.OpenAI(model="gpt-3.5-turbo")
        retriever_model = VectaraRM("<VECTARA_CUSTOMER_ID>", "<VECTARA_CORPUS_ID>", "<VECTARA_API_KEY>")
        dspy.settings.configure(lm=llm, rm=retriever_model)
        ```

        Below is a code snippet that shows how to use Vectara in the forward() function of a module
        ```python
        self.retrieve = VectaraRM("<VECTARA_CUSTOMER_ID>", "<VECTARA_CORPUS_ID>", "<VECTARA_API_KEY>", k=num_passages)
        ```
    """

    def __init__(
        self,
        vectara_customer_id: Optional[str] = None,
        vectara_corpus_id: Optional[str] = None,
        vectara_api_key: Optional[str] = None,
        k: int = 5,
    ):
        # Fall back to environment variables; an unset variable becomes "".
        if vectara_customer_id is None:
            vectara_customer_id = os.environ.get("VECTARA_CUSTOMER_ID", "")
        if vectara_corpus_id is None:
            vectara_corpus_id = os.environ.get("VECTARA_CORPUS_ID", "")
        if vectara_api_key is None:
            vectara_api_key = os.environ.get("VECTARA_API_KEY", "")

        self._vectara_customer_id = vectara_customer_id
        self._vectara_corpus_id = vectara_corpus_id
        self._vectara_api_key = vectara_api_key
        # Sentences of surrounding context requested before/after each match
        # (sent as contextConfig.sentencesBefore / sentencesAfter).
        self._n_sentences_before = self._n_sentences_after = 2
        # HTTP timeout passed to requests.post (seconds).
        self._vectara_timeout = 120
        super().__init__(k=k)

    def _vectara_query(
        self,
        query: str,
        limit: int = 3,
    ) -> List[dict]:
        """Query the Vectara corpus for the top matching passages.

        Args:
            query: Query string.
            limit: Maximum number of results to request (``numResults``).

        Returns:
            A list of ``{"text": ..., "score": ...}`` dicts in the order Vectara
            returned them, with snippet markers stripped from ``text``. Returns an
            empty list if the request does not return HTTP 200 or the response
            contains no result set.
        """
        corpus_key = {
            "customerId": self._vectara_customer_id,
            "corpusId": self._vectara_corpus_id,
            # Hybrid-search weight for lexical matching; see Vectara's
            # lexicalInterpolationConfig docs for the exact semantics.
            "lexicalInterpolationConfig": {"lambda": 0.025},
        }

        data = {
            "query": [
                {
                    "query": query,
                    "start": 0,
                    "numResults": limit,
                    "contextConfig": {
                        "sentencesBefore": self._n_sentences_before,
                        "sentencesAfter": self._n_sentences_after,
                        "startTag": START_SNIPPET,
                        "endTag": END_SNIPPET,
                    },
                    "corpusKey": [corpus_key],
                }
            ]
        }

        headers = {
            "x-api-key": self._vectara_api_key,
            "customer-id": self._vectara_customer_id,
            "Content-Type": "application/json",
            "X-Source": "dspy",
        }

        response = requests.post(
            headers=headers,
            url="https://api.vectara.io/v1/query",
            data=json.dumps(data),
            timeout=self._vectara_timeout,
        )

        if response.status_code != 200:
            # Best-effort: report the failure and yield no passages.
            print(
                f"Query failed (code {response.status_code}, reason {response.reason}, "
                f"details {response.text})"
            )
            return []

        result = response.json()
        # Tolerate a missing/empty responseSet instead of raising KeyError/IndexError.
        response_sets = result.get("responseSet") or []
        if not response_sets:
            return []
        responses = response_sets[0].get("response") or []

        return [
            {
                "text": remove_snippet(x["text"]),
                "score": x["score"],
            }
            for x in responses
        ]

    def forward(
        self,
        query_or_queries: Union[str, List[str]],
        k: Optional[int] = None,
    ) -> List[dotdict]:
        """Search with Vectara for the top ``k`` passages for the query or queries.

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
                Empty queries are ignored.
            k (Optional[int]): The number of top passages to retrieve. Defaults to self.k.

        Returns:
            List[dotdict]: Up to ``k`` passages, each a ``dotdict`` with a ``long_text``
            field, ordered by descending (summed) score.
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        queries = [q for q in queries if q]  # Filter empty queries
        k = k if k is not None else self.k

        # With several queries, request 3*k results per query (presumably so the
        # merged ranking still has enough distinct candidates for the top k).
        limit = 3 * k if len(queries) > 1 else k

        # A passage returned for more than one query accumulates its scores.
        passages = defaultdict(float)
        for query in queries:
            for res in self._vectara_query(query, limit=limit):
                passages[res["text"]] += res["score"]

        sorted_passages = sorted(
            passages.items(), key=lambda x: x[1], reverse=True)[:k]

        return [dotdict({"long_text": passage}) for passage, _ in sorted_passages]