# EureCA / dspy / retrieve / mongodb_atlas_rm.py
# Commit f5776d3 (tonneli)
import os
from typing import Any, List, Optional, Union
from urllib.parse import quote_plus

import backoff
import dspy
from openai import (
    APITimeoutError,
    InternalServerError,
    OpenAI,
    RateLimitError,
    UnprocessableEntityError,
)

try:
    from pymongo import MongoClient
    from pymongo.errors import (
        ConfigurationError,
        ConnectionFailure,
        InvalidURI,
        OperationFailure,
        ServerSelectionTimeoutError,
    )
except ImportError:
    raise ImportError(
        "Please install the pymongo package by running `pip install dspy-ai[mongodb]`"
    )
def build_vector_search_pipeline(
    index_name: str,
    query_vector: List[float],
    num_candidates: int,
    limit: int,
    *,
    path: str = "embedding",
) -> List[dict[str, Any]]:
    """Build an Atlas ``$vectorSearch`` aggregation pipeline.

    Args:
        index_name: Name of the Atlas Vector Search index to query.
        query_vector: Embedding of the query text.
        num_candidates: Value for the stage's ``numCandidates`` option
            (how many nearest-neighbour candidates to consider).
        limit: Maximum number of documents the search stage returns.
        path: Document field that holds the stored embeddings. Defaults to
            ``"embedding"``, the field this module has always searched.

    Returns:
        A two-stage pipeline: the ``$vectorSearch`` stage, followed by a
        ``$project`` stage that keeps each document's ``text`` field plus its
        ``vectorSearchScore`` (as ``score``) and drops ``_id``.
    """
    search_stage = {
        "$vectorSearch": {
            "index": index_name,
            "path": path,
            "queryVector": query_vector,
            "numCandidates": num_candidates,
            "limit": limit,
        }
    }
    project_stage = {
        "$project": {"_id": 0, "text": 1, "score": {"$meta": "vectorSearchScore"}}
    }
    return [search_stage, project_stage]
class Embedder:
    """Convert text into embedding vectors using a hosted embedding API.

    Only the ``"openai"`` provider is supported; it requires the
    ``OPENAI_API_KEY`` environment variable to be set.
    """

    def __init__(self, provider: str, model: str):
        """Create an embedder.

        Args:
            provider: Embedding backend; must be ``"openai"``.
            model: Name of the embedding model passed to the provider.

        Raises:
            ValueError: If ``provider`` is unsupported, or if
                ``OPENAI_API_KEY`` is unset or empty.
        """
        # Fail fast: an unknown provider would otherwise leave ``client`` and
        # ``model`` unset and only surface as an AttributeError on first call.
        if provider != "openai":
            raise ValueError(f"Unsupported embedding provider: {provider!r}")
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("Environment variable OPENAI_API_KEY must be set")
        self.client = OpenAI(api_key=api_key)
        self.model = model

    @backoff.on_exception(
        backoff.expo,
        (
            APITimeoutError,
            InternalServerError,
            RateLimitError,
            UnprocessableEntityError,
        ),
        max_time=15,
    )
    def __call__(self, queries: List[str]) -> List[List[float]]:
        """Embed ``queries`` and return one embedding vector per item.

        Timeouts, internal server errors, rate-limit errors and
        unprocessable-entity errors from OpenAI are retried with exponential
        backoff for at most 15 seconds in total before being re-raised.
        """
        response = self.client.embeddings.create(input=queries, model=self.model)
        return [item.embedding for item in response.data]
class MongoDBAtlasRM(dspy.Retrieve):
    """DSPy retriever backed by MongoDB Atlas Vector Search.

    Credentials are read from the ``ATLAS_USERNAME``, ``ATLAS_PASSWORD`` and
    ``ATLAS_CLUSTER_URL`` environment variables. Queries are embedded with
    ``Embedder`` and searched through the Atlas index ``index_name`` on
    ``db_name.collection_name``.
    """

    def __init__(
        self,
        db_name: str,
        collection_name: str,
        index_name: str,
        k: int = 5,
        embedding_provider: str = "openai",
        embedding_model: str = "text-embedding-ada-002",
    ):
        """Configure the retriever and create the MongoDB client.

        Args:
            db_name: Database containing the searched collection.
            collection_name: Collection holding the documents and embeddings.
            index_name: Atlas Vector Search index to query.
            k: Default number of passages returned per query.
            embedding_provider: Provider passed to ``Embedder``.
            embedding_model: Model passed to ``Embedder``.

        Raises:
            ValueError: If a required environment variable is missing.
            ConnectionError: If the client cannot be created.
        """
        super().__init__(k=k)
        self.db_name = db_name
        self.collection_name = collection_name
        self.index_name = index_name
        self.username = os.getenv("ATLAS_USERNAME")
        self.password = os.getenv("ATLAS_PASSWORD")
        self.cluster_url = os.getenv("ATLAS_CLUSTER_URL")
        if not self.username:
            raise ValueError("Environment variable ATLAS_USERNAME must be set")
        if not self.password:
            raise ValueError("Environment variable ATLAS_PASSWORD must be set")
        if not self.cluster_url:
            raise ValueError("Environment variable ATLAS_CLUSTER_URL must be set")
        # Credentials must be percent-escaped: characters such as '@', ':' or
        # '/' in a username/password would otherwise corrupt the URI.
        uri = (
            f"mongodb+srv://{quote_plus(self.username)}:{quote_plus(self.password)}"
            f"@{self.cluster_url}/{self.db_name}?retryWrites=true&w=majority"
        )
        try:
            # NOTE(review): per the PyMongo docs, MongoClient connects lazily,
            # so server/auth failures usually surface on the first query; this
            # mainly catches malformed URIs and SRV/DNS configuration errors.
            self.client = MongoClient(uri)
        except (
            InvalidURI,
            ConfigurationError,
            ConnectionFailure,
            ServerSelectionTimeoutError,
            OperationFailure,
        ) as e:
            raise ConnectionError("Failed to connect to MongoDB Atlas") from e
        self.embedder = Embedder(provider=embedding_provider, model=embedding_model)

    def forward(
        self,
        query_or_queries: Union[str, List[str]],
        k: Optional[int] = None,
    ) -> dspy.Prediction:
        """Retrieve the top-``k`` passages for one query or several queries.

        Args:
            query_or_queries: A single query string or a list of them.
            k: Passages per query; defaults to ``self.k``.

        Returns:
            A ``dspy.Prediction`` whose ``passages`` is a list of
            ``{"text": ..., "score": ...}`` dicts. For a list of queries the
            per-query results are concatenated in query order.
        """
        k = self.k if k is None else k
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else list(query_or_queries)
        )
        if not queries:
            # Nothing to embed; avoid sending an empty request to the provider.
            return dspy.Prediction(passages=[])
        query_vectors = self.embedder(queries)
        collection = self.client[self.db_name][self.collection_name]
        passages: List[dict] = []
        for query_vector in query_vectors:
            pipeline = build_vector_search_pipeline(
                index_name=self.index_name,
                query_vector=query_vector,
                num_candidates=k * 10,
                limit=k,
            )
            passages.extend(collection.aggregate(pipeline))
        return dspy.Prediction(passages=passages)