|
import os
from typing import Any, List, Optional, Union
from urllib.parse import quote_plus

import backoff
import dspy
from openai import (
    APITimeoutError,
    InternalServerError,
    OpenAI,
    RateLimitError,
    UnprocessableEntityError,
)
|
|
|
try: |
|
from pymongo import MongoClient |
|
from pymongo.errors import ( |
|
ConnectionFailure, |
|
ConfigurationError, |
|
ServerSelectionTimeoutError, |
|
InvalidURI, |
|
OperationFailure, |
|
) |
|
except ImportError: |
|
raise ImportError( |
|
"Please install the pymongo package by running `pip install dspy-ai[mongodb]`" |
|
) |
|
|
|
|
|
def build_vector_search_pipeline(
    index_name: str, query_vector: List[float], num_candidates: int, limit: int
) -> List[dict[str, Any]]:
    """Assemble a two-stage Atlas ``$vectorSearch`` aggregation pipeline.

    The first stage searches the ``embedding`` field of the collection using
    the given index. The second stage projects each hit down to its ``text``
    and its vector-search score, dropping ``_id``.

    Args:
        index_name: Name of the Atlas Vector Search index to query.
        query_vector: Embedding of the query text.
        num_candidates: Value passed as the stage's ``numCandidates``.
        limit: Maximum number of documents the search stage returns.

    Returns:
        A list of pipeline stages suitable for ``Collection.aggregate``.
    """
    search_spec = {
        "index": index_name,
        "path": "embedding",
        "queryVector": query_vector,
        "numCandidates": num_candidates,
        "limit": limit,
    }
    # Keep only the passage text plus the score Atlas attaches to each hit.
    projection = {"_id": 0, "text": 1, "score": {"$meta": "vectorSearchScore"}}

    search_stage = {"$vectorSearch": search_spec}
    project_stage = {"$project": projection}
    return [search_stage, project_stage]
|
|
|
|
|
class Embedder:
    """Callable that turns a batch of text queries into embedding vectors.

    Only the ``"openai"`` provider is implemented.
    """

    def __init__(self, provider: str, model: str):
        """Create an embedder for ``provider`` using ``model``.

        Args:
            provider: Embedding backend; only ``"openai"`` is supported.
            model: Model name passed through to the provider's embeddings API.

        Raises:
            ValueError: If ``provider`` is not supported, or if the
                ``OPENAI_API_KEY`` environment variable is not set.
        """
        if provider != "openai":
            # Fail fast: an unknown provider used to leave ``self.client``
            # unset, so the error only surfaced later as an AttributeError
            # inside ``__call__``.
            raise ValueError(
                f"Unsupported embedding provider {provider!r}; "
                "only 'openai' is supported"
            )

        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("Environment variable OPENAI_API_KEY must be set")

        # Pass the key we just validated so the client uses exactly that value.
        self.client = OpenAI(api_key=api_key)
        self.model = model

    @backoff.on_exception(
        backoff.expo,
        (
            APITimeoutError,
            InternalServerError,
            RateLimitError,
            UnprocessableEntityError,
        ),
        max_time=15,
    )
    def __call__(self, queries: List[str]) -> List[List[float]]:
        """Embed ``queries`` with a single embeddings API request.

        The listed OpenAI errors are retried with exponential backoff, giving
        up after a total of 15 seconds.

        Args:
            queries: Texts to embed.

        Returns:
            One embedding vector per input, in the same order as ``queries``.
        """
        response = self.client.embeddings.create(input=queries, model=self.model)
        # Each result carries the index of the input it belongs to. Ordering by
        # that index guarantees the output lines up with ``queries``.
        ordered = sorted(response.data, key=lambda item: item.index)
        return [item.embedding for item in ordered]
|
|
|
|
|
class MongoDBAtlasRM(dspy.Retrieve):
    """DSPy retrieval module backed by MongoDB Atlas Vector Search.

    Credentials are read from the ``ATLAS_USERNAME``, ``ATLAS_PASSWORD`` and
    ``ATLAS_CLUSTER_URL`` environment variables. The search pipeline queries
    the ``embedding`` field through ``index_name`` and returns each hit's
    ``text`` together with its search score.
    """

    # Atlas does not accept a $vectorSearch numCandidates above this value.
    _MAX_NUM_CANDIDATES = 10_000

    def __init__(
        self,
        db_name: str,
        collection_name: str,
        index_name: str,
        k: int = 5,
        embedding_provider: str = "openai",
        embedding_model: str = "text-embedding-ada-002",
    ):
        """Configure the Atlas client and the query embedder.

        Args:
            db_name: Database that holds the collection.
            collection_name: Collection to search.
            index_name: Atlas Vector Search index defined on ``embedding``.
            k: Default number of passages to return per query.
            embedding_provider: Provider passed to ``Embedder``.
            embedding_model: Model passed to ``Embedder``.

        Raises:
            ValueError: If a required environment variable is missing, or if
                ``Embedder`` rejects the provider configuration.
            ConnectionError: If ``MongoClient`` raises one of the caught
                pymongo errors while being constructed.
        """
        super().__init__(k=k)
        self.db_name = db_name
        self.collection_name = collection_name
        self.index_name = index_name

        self.username = os.getenv("ATLAS_USERNAME")
        self.password = os.getenv("ATLAS_PASSWORD")
        self.cluster_url = os.getenv("ATLAS_CLUSTER_URL")
        if not self.username:
            raise ValueError("Environment variable ATLAS_USERNAME must be set")
        if not self.password:
            raise ValueError("Environment variable ATLAS_PASSWORD must be set")
        if not self.cluster_url:
            raise ValueError("Environment variable ATLAS_CLUSTER_URL must be set")

        # NOTE: per the PyMongo docs, MongoClient connects in the background.
        # Many connectivity problems therefore surface on the first query
        # rather than here.
        try:
            self.client = MongoClient(self._connection_uri())
        except (
            InvalidURI,
            ConfigurationError,
            ConnectionFailure,
            ServerSelectionTimeoutError,
            OperationFailure,
        ) as e:
            raise ConnectionError("Failed to connect to MongoDB Atlas") from e

        self.embedder = Embedder(provider=embedding_provider, model=embedding_model)

    def _connection_uri(self) -> str:
        """Build the ``mongodb+srv`` connection URI from the stored settings.

        Username and password are percent-escaped with ``quote_plus``, as PyMongo
        requires. Unescaped credentials containing characters such as ``@``,
        ``:``, ``/`` or ``%`` would otherwise corrupt the URI.
        """
        user = quote_plus(self.username)
        password = quote_plus(self.password)
        return (
            f"mongodb+srv://{user}:{password}@{self.cluster_url}/{self.db_name}"
            "?retryWrites=true&w=majority"
        )

    def forward(
        self,
        query_or_queries: Union[str, List[str]],
        k: Optional[int] = None,
    ) -> dspy.Prediction:
        """Retrieve the top passages for one query or a list of queries.

        Args:
            query_or_queries: A single query string or a list of them. Empty
                queries are skipped.
            k: Passages to return per query; defaults to ``self.k``.

        Returns:
            A ``dspy.Prediction`` whose ``passages`` attribute is the list of
            projected result documents (``text`` and ``score``). Results are
            grouped by query, in the order the queries were given. If no
            non-empty query remains, ``passages`` is an empty list.
        """
        if isinstance(query_or_queries, str):
            queries = [query_or_queries]
        else:
            queries = list(query_or_queries)
        # The embeddings API rejects empty input, so drop blank queries up front.
        queries = [q for q in queries if q]
        if not queries:
            return dspy.Prediction(passages=[])

        k = self.k if k is None else k
        # Oversample candidates (10x k), but never past Atlas's upper bound.
        num_candidates = min(k * 10, self._MAX_NUM_CANDIDATES)

        # Embed all queries in one request rather than one call per query.
        query_vectors = self.embedder(queries)
        collection = self.client[self.db_name][self.collection_name]

        passages = []
        for vector in query_vectors:
            pipeline = build_vector_search_pipeline(
                index_name=self.index_name,
                query_vector=vector,
                num_candidates=num_candidates,
                limit=k,
            )
            passages.extend(collection.aggregate(pipeline))
        return dspy.Prediction(passages=passages)
|
|