File size: 5,672 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
from collections import defaultdict
from typing import List, Union
import dspy
from typing import Optional
import json
import os
import requests
from dsp.utils import dotdict
# Tags sent to Vectara as the snippet start/end markers in query results.
START_SNIPPET = "<%START%>"
END_SNIPPET = "<%END%>"


def remove_snippet(s: str) -> str:
    """Return ``s`` with every START_SNIPPET and END_SNIPPET marker removed.

    START_SNIPPET occurrences are stripped first, then END_SNIPPET occurrences.
    """
    cleaned = s
    for marker in (START_SNIPPET, END_SNIPPET):
        cleaned = cleaned.replace(marker, "")
    return cleaned
class VectaraRM(dspy.Retrieve):
    """
    A retrieval module that uses Vectara to return the top passages for a given query.

    Assumes that a Vectara corpus has been created and populated with the following payload:
        - document: The text of the passage

    Args:
        vectara_customer_id (str, optional): Vectara Customer ID. Defaults to the
            VECTARA_CUSTOMER_ID environment variable.
        vectara_corpus_id (str, optional): Vectara Corpus ID. Defaults to the
            VECTARA_CORPUS_ID environment variable.
        vectara_api_key (str, optional): Vectara API Key. Defaults to the
            VECTARA_API_KEY environment variable.
        k (int, optional): The default number of top passages to retrieve. Defaults to 5.

    Examples:
        Below is a code snippet that shows how to use Vectara as the default retriever:
        ```python
        llm = dspy.OpenAI(model="gpt-3.5-turbo")
        retriever_model = VectaraRM("<VECTARA_CUSTOMER_ID>", "<VECTARA_CORPUS_ID>", "<VECTARA_API_KEY>")
        dspy.settings.configure(lm=llm, rm=retriever_model)
        ```

        Below is a code snippet that shows how to use Vectara in the forward() function of a module:
        ```python
        self.retrieve = VectaraRM("<VECTARA_CUSTOMER_ID>", "<VECTARA_CORPUS_ID>", "<VECTARA_API_KEY>", k=num_passages)
        ```
    """

    def __init__(
        self,
        vectara_customer_id: Optional[str] = None,
        vectara_corpus_id: Optional[str] = None,
        vectara_api_key: Optional[str] = None,
        k: int = 5,
    ):
        # Fall back to environment variables. An unset variable yields "" and
        # no validation happens here; a bad credential surfaces at query time.
        if vectara_customer_id is None:
            vectara_customer_id = os.environ.get("VECTARA_CUSTOMER_ID", "")
        if vectara_corpus_id is None:
            vectara_corpus_id = os.environ.get("VECTARA_CORPUS_ID", "")
        if vectara_api_key is None:
            vectara_api_key = os.environ.get("VECTARA_API_KEY", "")

        self._vectara_customer_id = vectara_customer_id
        self._vectara_corpus_id = vectara_corpus_id
        self._vectara_api_key = vectara_api_key
        # Sent as contextConfig.sentencesBefore / sentencesAfter in each query.
        self._n_sentences_before = self._n_sentences_after = 2
        # Timeout (seconds) passed to requests.post for each query.
        self._vectara_timeout = 120
        super().__init__(k=k)

    def _vectara_query(
        self,
        query: str,
        limit: int = 3,
    ) -> List[dict]:
        """Query the Vectara corpus for the top matching passages.

        Args:
            query: Query string.
            limit: Value sent as ``numResults`` (maximum number of results requested).

        Returns:
            A list of ``{"text": ..., "score": ...}`` dicts in the order Vectara
            returned them, with the snippet start/end tags stripped from ``text``.
            Returns an empty list if the HTTP status is not 200 or the response
            contains no result set.
        """
        corpus_key = {
            "customerId": self._vectara_customer_id,
            "corpusId": self._vectara_corpus_id,
            # Hard-coded hybrid-search setting; presumably weights lexical
            # (keyword) matching against neural ranking -- see Vectara docs.
            "lexicalInterpolationConfig": {"lambda": 0.025},
        }
        data = {
            "query": [
                {
                    "query": query,
                    "start": 0,
                    "numResults": limit,
                    "contextConfig": {
                        "sentencesBefore": self._n_sentences_before,
                        "sentencesAfter": self._n_sentences_after,
                        "startTag": START_SNIPPET,
                        "endTag": END_SNIPPET,
                    },
                    "corpusKey": [corpus_key],
                }
            ]
        }
        headers = {
            "x-api-key": self._vectara_api_key,
            "customer-id": self._vectara_customer_id,
            "Content-Type": "application/json",
            "X-Source": "dspy",
        }
        response = requests.post(
            headers=headers,
            url="https://api.vectara.io/v1/query",
            data=json.dumps(data),
            timeout=self._vectara_timeout,
        )
        if response.status_code != 200:
            # Best-effort: report the failure and yield no passages instead of raising.
            print(
                f"Query failed (code {response.status_code}, reason {response.reason}, "
                f"details {response.text})"
            )
            return []

        result = response.json()
        # Guard against an empty/missing result set rather than raising
        # KeyError/IndexError; treated the same as a failed query.
        response_sets = result.get("responseSet") or []
        if not response_sets:
            return []
        responses = response_sets[0].get("response") or []
        return [
            {"text": remove_snippet(x["text"]), "score": x["score"]}
            for x in responses
        ]

    def forward(
        self,
        query_or_queries: Union[str, List[str]],
        k: Optional[int] = None,
    ) -> List[dotdict]:
        """Search with Vectara for the top k passages for one or more queries.

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
                Empty queries are ignored.
            k (Optional[int]): The number of top passages to retrieve. Defaults to self.k.

        Returns:
            List[dotdict]: Up to k passages, each a ``dotdict`` with a ``long_text``
            field, ordered by descending fused score. When several queries return
            the same passage text, its scores are summed.
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        queries = [q for q in queries if q]  # Filter empty queries
        k = k if k is not None else self.k

        # With several queries, request 3*k results per query; presumably so
        # the fused ranking below has more candidates to pick the top k from.
        limit = 3 * k if len(queries) > 1 else k

        # Fuse results across queries by summing scores per distinct passage text.
        passages = defaultdict(float)
        for query in queries:
            for res in self._vectara_query(query, limit=limit):
                passages[res["text"]] += res["score"]

        sorted_passages = sorted(
            passages.items(), key=lambda x: x[1], reverse=True
        )[:k]
        return [dotdict({"long_text": passage}) for passage, _ in sorted_passages]
|