# EureCA / dspy / retrieve / chromadb_rm.py
# tonneli's picture
# Delete history
# f5776d3
"""
Retriever model for chromadb
"""
from typing import Optional, List, Union
import openai
import dspy
import backoff
from dsp.utils import dotdict
# Exception types that trigger a retry (via ``backoff``) of the embedding call.
# openai<1.0 exposes them under ``openai.error``; openai>=1.0 removed that
# module and exposes the error classes at the package top level instead.
try:
    import openai.error

    ERRORS = (
        openai.error.RateLimitError,
        openai.error.ServiceUnavailableError,
        openai.error.APIError,
    )
except (ImportError, AttributeError):
    # ImportError: ``openai.error`` does not exist in the installed release.
    # AttributeError: the module exists but lacks one of the names above.
    # Anything else is an unrelated failure and should not be masked.
    ERRORS = (openai.RateLimitError, openai.APIError)
# chromadb is an optional dependency: fail at import time with an actionable
# message, but keep the original ImportError chained as the cause so that a
# broken transitive dependency or an incompatible chromadb version (e.g. one
# without ``Embeddable``) is still visible in the traceback.
try:
    import chromadb
    from chromadb.config import Settings
    from chromadb.utils import embedding_functions
    from chromadb.api.types import (
        Embeddable,
        EmbeddingFunction,
    )
    import chromadb.utils.embedding_functions as ef
except ImportError as exc:
    raise ImportError(
        "The chromadb library is required to use ChromadbRM. Install it with `pip install dspy-ai[chromadb]`"
    ) from exc
class ChromadbRM(dspy.Retrieve):
    """
    A retrieval module that uses chromadb to return the top passages for a given query.

    Assumes that the chromadb index has been created and populated with the following metadata:
        - documents: The text of the passage

    Args:
        collection_name (str): chromadb collection name
        persist_directory (str): chromadb persist directory
        embedding_function (Optional[EmbeddingFunction[Embeddable]]): Optional function used to embed
            the queries. When omitted (or None), a fresh ``DefaultEmbeddingFunction`` is created.
        k (int, optional): The number of top passages to retrieve. Defaults to 7.

    Returns:
        List[dotdict]: ``forward`` returns one ``dotdict`` per retrieved passage, each with the
        passage text under the ``long_text`` key.

    Examples:
        Below is a code snippet that shows how to use this as the default retriever:
        ```python
        llm = dspy.OpenAI(model="gpt-3.5-turbo")
        retriever_model = ChromadbRM('collection_name', 'db_path')
        dspy.settings.configure(lm=llm, rm=retriever_model)
        # to test the retriever with "my query"
        retriever_model("my query")
        ```

        Below is a code snippet that shows how to use this in the forward() function of a module
        ```python
        self.retrieve = ChromadbRM('collection_name', 'db_path', k=num_passages)
        ```
    """

    def __init__(
        self,
        collection_name: str,
        persist_directory: str,
        embedding_function: Optional[EmbeddingFunction[Embeddable]] = None,
        k: int = 7,
    ):
        self._init_chromadb(collection_name, persist_directory)

        # Resolve the default here rather than in the signature: a default
        # written in the signature is evaluated once, when the class is
        # defined, and would be shared by every instance.
        if embedding_function is None:
            embedding_function = ef.DefaultEmbeddingFunction()
        self.ef = embedding_function

        super().__init__(k=k)

    def _init_chromadb(
        self,
        collection_name: str,
        persist_directory: str,
    ) -> chromadb.Collection:
        """Initialize the chromadb client and load (or create) the collection.

        The client and the collection are stored on the instance as
        ``_chromadb_client`` and ``_chromadb_collection``.

        Args:
            collection_name (str): chromadb collection name
            persist_directory (str): chromadb persist directory

        Returns:
            chromadb.Collection: The collection that queries will run against.
        """
        self._chromadb_client = chromadb.Client(
            Settings(
                persist_directory=persist_directory,
                is_persistent=True,
            )
        )
        self._chromadb_collection = self._chromadb_client.get_or_create_collection(
            name=collection_name,
        )
        return self._chromadb_collection

    @backoff.on_exception(
        backoff.expo,
        ERRORS,
        max_time=15,
    )
    def _get_embeddings(self, queries: List[str]) -> List[List[float]]:
        """Embed the queries with the configured embedding function.

        The call is retried with exponential backoff (for up to 15 seconds)
        when one of the exception types in ``ERRORS`` is raised.

        Args:
            queries (list): List of query strings to embed.

        Returns:
            List[List[float]]: List of embeddings corresponding to each query.
        """
        return self.ef(queries)

    def forward(
        self, query_or_queries: Union[str, List[str]], k: Optional[int] = None
    ) -> List[dotdict]:
        """Search the collection for the top passages matching the query.

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
                Empty queries are dropped before embedding.
            k (Optional[int]): Number of passages to retrieve; falls back to ``self.k``
                when None.

        Returns:
            List[dotdict]: One ``dotdict`` per retrieved passage, with the passage text
            under ``long_text``. Empty when no non-empty query was supplied.
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        queries = [q for q in queries if q]  # Filter empty queries
        if not queries:
            # Nothing to embed or search for; skip the embedding function and
            # the collection query instead of handing them an empty batch.
            return []

        embeddings = self._get_embeddings(queries)
        k = self.k if k is None else k

        results = self._chromadb_collection.query(
            query_embeddings=embeddings, n_results=k
        )

        # NOTE(review): only the result list at index 0 (the first query) is
        # returned; results for any further queries are discarded. Kept as-is
        # to preserve behavior for existing callers - confirm whether results
        # for all queries should be merged.
        return [dotdict({"long_text": doc}) for doc in results["documents"][0]]