File size: 3,785 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
from typing import Any, List, Optional, Union
from urllib.parse import quote_plus

import backoff
import dspy
from openai import (
    OpenAI,
    APITimeoutError,
    InternalServerError,
    RateLimitError,
    UnprocessableEntityError,
)

# pymongo is an optional dependency (provided by the `dspy-ai[mongodb]` extra);
# fail at import time with an actionable message rather than a bare
# ModuleNotFoundError deep inside MongoDBAtlasRM.__init__.
try:
    from pymongo import MongoClient
    from pymongo.errors import (
        ConnectionFailure,
        ConfigurationError,
        ServerSelectionTimeoutError,
        InvalidURI,
        OperationFailure,
    )
except ImportError:
    raise ImportError(
        "Please install the pymongo package by running `pip install dspy-ai[mongodb]`"
    )


def build_vector_search_pipeline(
    index_name: str,
    query_vector: List[float],
    num_candidates: int,
    limit: int,
    path: str = "embedding",
) -> List[dict[str, Any]]:
    """Build a MongoDB Atlas ``$vectorSearch`` aggregation pipeline.

    Args:
        index_name: Name of the Atlas vector search index to query.
        query_vector: Query embedding to search against the stored vectors.
        num_candidates: Number of approximate nearest-neighbor candidates
            Atlas should consider before applying ``limit``.
        limit: Maximum number of documents to return.
        path: Document field that stores the embeddings. Defaults to
            ``"embedding"`` (the value previously hard-coded here), so
            existing callers are unaffected.

    Returns:
        A two-stage pipeline: the ``$vectorSearch`` stage followed by a
        ``$project`` stage that drops ``_id`` and keeps ``text`` plus the
        similarity ``score`` (``$meta: "vectorSearchScore"``).
    """
    return [
        {
            "$vectorSearch": {
                "index": index_name,
                "path": path,
                "queryVector": query_vector,
                "numCandidates": num_candidates,
                "limit": limit,
            }
        },
        {"$project": {"_id": 0, "text": 1, "score": {"$meta": "vectorSearchScore"}}},
    ]


class Embedder:
    """Thin wrapper that turns text queries into embedding vectors.

    Only the ``"openai"`` provider is currently supported.
    """

    def __init__(self, provider: str, model: str):
        """Initialize the embedding client.

        Args:
            provider: Embedding backend name; must be ``"openai"``.
            model: Model identifier passed to the embeddings endpoint.

        Raises:
            ValueError: If ``provider`` is unsupported, or if the
                ``OPENAI_API_KEY`` environment variable is not set.
        """
        if provider != "openai":
            # Fail fast: the original code silently skipped setup for any
            # other provider, leaving the instance without `client`/`model`
            # and deferring failure to an AttributeError at call time.
            raise ValueError(f"Unsupported embedding provider: {provider!r}")
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("Environment variable OPENAI_API_KEY must be set")
        self.client = OpenAI()  # OpenAI() reads OPENAI_API_KEY from the env
        self.model = model

    # Retry transient API failures with exponential backoff for up to 15s.
    @backoff.on_exception(
        backoff.expo,
        (
            APITimeoutError,
            InternalServerError,
            RateLimitError,
            UnprocessableEntityError,
        ),
        max_time=15,
    )
    def __call__(self, queries: List[str]) -> List[List[float]]:
        """Embed ``queries`` and return one vector per input, in order."""
        embedding = self.client.embeddings.create(input=queries, model=self.model)
        return [result.embedding for result in embedding.data]


class MongoDBAtlasRM(dspy.Retrieve):
    """dspy retrieval module backed by MongoDB Atlas vector search.

    Reads Atlas credentials from the ``ATLAS_USERNAME``, ``ATLAS_PASSWORD``,
    and ``ATLAS_CLUSTER_URL`` environment variables.
    """

    def __init__(
        self,
        db_name: str,
        collection_name: str,
        index_name: str,
        k: int = 5,
        embedding_provider: str = "openai",
        embedding_model: str = "text-embedding-ada-002",
    ):
        """Connect to Atlas and set up the embedder.

        Args:
            db_name: Database to query.
            collection_name: Collection holding the embedded documents.
            index_name: Atlas vector search index name.
            k: Default number of passages to retrieve.
            embedding_provider: Backend for ``Embedder`` (``"openai"``).
            embedding_model: Embedding model identifier.

        Raises:
            ValueError: If any required environment variable is missing.
            ConnectionError: If the Atlas client cannot be constructed.
        """
        super().__init__(k=k)
        self.db_name = db_name
        self.collection_name = collection_name
        self.index_name = index_name
        self.username = os.getenv("ATLAS_USERNAME")
        self.password = os.getenv("ATLAS_PASSWORD")
        self.cluster_url = os.getenv("ATLAS_CLUSTER_URL")
        if not self.username:
            raise ValueError("Environment variable ATLAS_USERNAME must be set")
        if not self.password:
            raise ValueError("Environment variable ATLAS_PASSWORD must be set")
        if not self.cluster_url:
            raise ValueError("Environment variable ATLAS_CLUSTER_URL must be set")
        try:
            # Percent-encode the credentials: MongoDB connection strings
            # require reserved characters (e.g. '@', ':', '/', '%') in the
            # username/password to be escaped, otherwise the URI is invalid
            # or silently misparsed.
            self.client = MongoClient(
                f"mongodb+srv://{quote_plus(self.username)}:{quote_plus(self.password)}"
                f"@{self.cluster_url}/{self.db_name}"
                "?retryWrites=true&w=majority"
            )
        except (
            InvalidURI,
            ConfigurationError,
            ConnectionFailure,
            ServerSelectionTimeoutError,
            OperationFailure,
        ) as e:
            raise ConnectionError("Failed to connect to MongoDB Atlas") from e

        self.embedder = Embedder(provider=embedding_provider, model=embedding_model)

    def forward(self, query_or_queries: str) -> dspy.Prediction:
        """Retrieve the top-``k`` passages for a single query string.

        The query is embedded, then run through a ``$vectorSearch``
        pipeline that considers ``10 * k`` candidates and returns ``k``
        passages with their similarity scores.
        """
        query_vector = self.embedder([query_or_queries])
        pipeline = build_vector_search_pipeline(
            index_name=self.index_name,
            query_vector=query_vector[0],
            num_candidates=self.k * 10,
            limit=self.k,
        )
        contents = self.client[self.db_name][self.collection_name].aggregate(pipeline)
        return dspy.Prediction(passages=list(contents))