# EureCA / dspy / retrieve / marqo_rm.py
# tonneli's picture
# Delete history
# f5776d3
from collections import defaultdict
from typing import List, Union
import dspy
try:
import marqo
except ImportError:
raise ImportError(
"The 'marqo' extra is required to use MarqoRM. Install it with `pip install dspy-ai[marqo]`"
)
class MarqoRM(dspy.Retrieve):
    """
    A retrieval module that uses Marqo to return the top passages for a given query.

    Assumes that a Marqo index has been created and populated with the following payload:
        - document: The text of the passage

    Args:
        marqo_index_name (str): The name of the marqo index.
        marqo_client (marqo.client.Client): A marqo client instance.
        k (int, optional): The number of top passages to retrieve. Defaults to 3.

    Examples:
        Below is a code snippet that shows how to use Marqo as the default retriever:
        ```python
        import marqo
        marqo_client = marqo.Client(url="http://0.0.0.0:8882")

        llm = dspy.OpenAI(model="gpt-3.5-turbo")
        retriever_model = MarqoRM("my_index_name", marqo_client=marqo_client)
        dspy.settings.configure(lm=llm, rm=retriever_model)
        ```

        Below is a code snippet that shows how to use Marqo in the forward() function of a module
        ```python
        self.retrieve = MarqoRM("my_index_name", marqo_client=marqo_client, k=num_passages)
        ```
    """

    def __init__(
        self,
        marqo_index_name: str,
        marqo_client: marqo.client.Client,
        k: int = 3,
    ):
        self._marqo_index_name = marqo_index_name
        self._marqo_client = marqo_client
        super().__init__(k=k)

    def forward(
        self,
        query_or_queries: Union[str, List[str]],
        k: Union[int, None] = None,
    ) -> dspy.Prediction:
        """Search with Marqo for the top passages for the given query or queries.

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
                Empty queries are skipped.
            k (int, optional): Number of passages to retrieve for this call. When
                omitted (None), the instance-level ``self.k`` is used.

        Returns:
            dspy.Prediction: An object containing the retrieved passages.

        Raises:
            KeyError: If a search response has no 'hits' entry, or a hit lacks
                the 'document' or '_score' field.
        """
        # A per-call k takes precedence; None means "use the configured default".
        limit = self.k if k is None else k

        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        queries = [q for q in queries if q]  # drop empty / falsy queries

        index = self._marqo_client.index(self._marqo_index_name)
        all_query_results = [index.search(q=query, limit=limit) for query in queries]

        # A passage returned for several queries accumulates the sum of its
        # scores, so passages matching more queries rank higher.
        passages = defaultdict(float)
        for result_dict in all_query_results:
            for result in result_dict['hits']:
                passages[result['document']] += result['_score']

        # Highest accumulated score first, truncated to the requested count.
        sorted_passages = sorted(
            passages.items(), key=lambda x: x[1], reverse=True)[:limit]
        return dspy.Prediction(passages=[passage for passage, _ in sorted_passages])