File size: 7,051 Bytes
e112463
fa9ac7e
c6dd20e
e112463
 
 
 
 
d16a006
6a4ac5a
d16a006
 
cb9fd37
e112463
 
 
05dabf4
6a4ac5a
 
 
6008655
6a4ac5a
 
 
1ea867d
c8a1687
d16a006
 
1ea867d
d16a006
 
 
 
1ea867d
d16a006
 
 
 
 
 
 
c8a1687
 
d16a006
 
 
 
 
c8a1687
d16a006
c8a1687
 
 
 
 
 
1ea867d
 
 
 
c8a1687
06bca0c
c6dd20e
 
 
06bca0c
c6dd20e
 
 
 
 
 
 
 
 
 
5b7d0e6
c6dd20e
 
d16a006
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a4ac5a
c6dd20e
e112463
c6dd20e
 
 
 
e112463
fa9ac7e
 
 
6a09028
 
 
c6dd20e
fa9ac7e
 
 
 
f97aa81
c6dd20e
fa9ac7e
 
 
06bca0c
f97aa81
fa9ac7e
 
e112463
fa9ac7e
d16a006
 
fa9ac7e
 
e112463
6a09028
6a4ac5a
6a09028
 
 
 
 
 
 
 
c6dd20e
6a4ac5a
6a09028
 
 
 
 
 
 
 
25a0d11
fa9ac7e
 
 
 
c8a1687
fa9ac7e
1a12df2
c8a1687
 
1a12df2
6a09028
c8a1687
d16a006
 
 
 
6a09028
c8a1687
 
6a4ac5a
 
 
6008655
6a4ac5a
463196d
6a4ac5a
 
 
 
c8a1687
 
6008655
6a4ac5a
d16a006
c8a1687
d16a006
6a09028
6008655
6a4ac5a
c8a1687
6a4ac5a
d16a006
fa9ac7e
6008655
6a4ac5a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import logging
from dataclasses import dataclass, field
from functools import lru_cache

import numpy as np
import pandas as pd
from openai.embeddings_utils import cosine_similarity, get_embedding

from buster.completers import completer_factory
from buster.completers.base import Completion
from buster.formatters.prompts import SystemPromptFormatter, prompt_formatter_factory
from buster.retriever import Retriever

logger = logging.getLogger(__name__)
# NOTE(review): configuring the root logger at import time is a side effect;
# library modules usually leave basicConfig() to the application entry point.
logging.basicConfig(level=logging.INFO)


@dataclass(slots=True)
class Response:
    """Bundle returned by Buster.process_input for a single user question."""

    # The generated answer (project type from buster.completers.base).
    completion: Completion
    # False when the answer scored too close to the "unknown" embedding,
    # i.e. the bot effectively said "I don't know".
    is_relevant: bool
    # Documents that passed the similarity threshold; replaced with an empty
    # frame when the response is judged not relevant.
    matched_documents: pd.DataFrame | None = None

@dataclass
class BusterConfig:
    """Configuration object for a chatbot."""

    # Embedding engine name, passed as `engine` to get_embedding for queries,
    # responses, and the unknown prompt.
    embedding_model: str = "text-embedding-ada-002"
    # Cosine-similarity cutoff against the unknown-prompt embedding: responses
    # scoring >= this are treated as "I don't know". Set to 0 to disable.
    unknown_threshold: float = 0.9
    # Prompt whose embedding defines what an "I don't know" answer looks like.
    unknown_prompt: str = "I Don't know how to answer your question."
    # Document source forwarded to the retriever (restricts the search space).
    document_source: str = ""
    # Retrieval parameters: top_k = number of candidates to fetch,
    # thresh = minimum similarity a candidate must have to be kept.
    retriever_cfg: dict = field(
        default_factory=lambda: {
            "top_k": 3,
            "thresh": 0.7,
        }
    )
    # Arguments consumed by prompt_formatter_factory (system-prompt framing).
    prompt_cfg: dict = field(
        default_factory=lambda: {
            "max_words": 3000,
            "text_before_documents": "You are a chatbot answering questions.\n",
            "text_before_prompt": "Answer the following question:\n",
        }
    )
    # Arguments consumed by completer_factory: which completer to build and
    # the keyword arguments forwarded to its API call.
    completion_cfg: dict = field(
        default_factory=lambda: {
            "name": "ChatGPT",
            "completion_kwargs": {
                "engine": "gpt-3.5-turbo",
                "max_tokens": 200,
                "temperature": None,
                "top_p": None,
                "frequency_penalty": 1,
                "presence_penalty": 1,
            },
        }
    )


class Buster:
    """Retrieval-augmented chatbot.

    Pipeline (see process_input): embed the user question, retrieve the
    best-matching documents, generate a completion using them as context,
    then compare the answer against the "unknown" embedding to decide
    whether the bot actually knew the answer.
    """

    def __init__(self, cfg: BusterConfig, retriever: Retriever):
        self._unk_embedding = None
        self.update_cfg(cfg)

        self.retriever = retriever

    @property
    def unk_embedding(self):
        """Embedding of cfg.unknown_prompt, refreshed on every update_cfg()."""
        return self._unk_embedding

    @unk_embedding.setter
    def unk_embedding(self, embedding):
        # BUGFIX: the original setter returned the embedding; property setters'
        # return values are ignored, so the return was dead code.
        logger.info("Setting new UNK embedding...")
        self._unk_embedding = embedding

    def update_cfg(self, cfg: BusterConfig):
        """Every time we set a new config, we update the things that need to be updated."""
        logger.info(f"Updating config to {cfg.document_source}:\n{cfg}")
        self._cfg = cfg
        self.embedding_model = cfg.embedding_model
        self.unknown_threshold = cfg.unknown_threshold
        self.unknown_prompt = cfg.unknown_prompt
        self.document_source = cfg.document_source

        self.retriever_cfg = cfg.retriever_cfg
        self.completion_cfg = cfg.completion_cfg
        self.prompt_cfg = cfg.prompt_cfg

        # Refresh the "I don't know" reference embedding for the (possibly new)
        # unknown prompt and embedding model.
        self.unk_embedding = self.get_embedding(self.unknown_prompt, engine=self.embedding_model)

        # Rebuild the completer and prompt formatter from the new config.
        self.completer = completer_factory(self.completion_cfg)
        self.prompt_formatter = prompt_formatter_factory(self.prompt_cfg)

        # BUGFIX: was an f-string with no placeholders.
        logger.info("Config Updated.")

    # BUGFIX: this was an @lru_cache-decorated *instance* method. That keys the
    # cache on `self` (identical queries re-embedded per instance) and keeps
    # every Buster instance alive for the cache's lifetime (ruff B019). A
    # cached staticmethod keyed on (query, engine) shares one cache and has no
    # leak; `self.get_embedding(...)` call sites are unchanged. maxsize=128
    # matches the bare-@lru_cache default of the original.
    @staticmethod
    @lru_cache(maxsize=128)
    def get_embedding(query: str, engine: str):
        """Return the embedding of `query` from `engine`, caching repeats."""
        logger.info("generating embedding")
        return get_embedding(query, engine=engine)

    def rank_documents(
        self,
        query: str,
        top_k: int,
        thresh: float,
        engine: str,
        source: str,
    ) -> pd.DataFrame:
        """
        Compare the question to the series of documents and return the best matching documents.

        Args:
            query: the user question.
            top_k: number of candidate documents to retrieve.
                (BUGFIX: was annotated `float`; it is a count.)
            thresh: minimum similarity a candidate must exceed to be kept.
            engine: embedding engine name, e.g. "text-embedding-ada-002".
            source: document source the retriever should search.

        Returns:
            The retrieved documents whose `similarity` column exceeds `thresh`.
        """
        query_embedding = self.get_embedding(
            query,
            engine=engine,
        )
        matched_documents = self.retriever.retrieve(query_embedding, top_k=top_k, source=source)

        # log matched_documents to the console
        logger.info(f"matched documents before thresh: {matched_documents}")

        # filter out matched_documents using a threshold
        matched_documents = matched_documents[matched_documents.similarity > thresh]
        logger.info(f"matched documents after thresh: {matched_documents}")

        return matched_documents

    def check_response_relevance(
        self, completion_text: str, engine: str, unk_embedding: np.ndarray, unk_threshold: float
    ) -> bool:
        """Check to see if a response is relevant to the chatbot's knowledge or not.

        We assume we've prompt-engineered our bot to say a response is unrelated to the context if it isn't relevant.
        Here, we compare the embedding of the response to the embedding of the prompt-engineered "I don't know" embedding.

        Set the unk_threshold to 0 to essentially turn off this feature.

        (BUGFIX: `unk_embedding` was annotated `np.array`, which is a factory
        function, not a type; the correct annotation is `np.ndarray`.)
        """
        response_embedding = self.get_embedding(
            completion_text,
            engine=engine,
        )
        score = cosine_similarity(response_embedding, unk_embedding)
        logger.info(f"UNK score: {score}")

        # Low similarity to the "I don't know" embedding means the answer
        # likely carries real information.
        return score < unk_threshold

    def process_input(self, user_input: str) -> Response:
        """
        Main function to process the input question and generate a formatted output.
        """
        logger.info(f"User Input:\n{user_input}")

        # We make sure there is always a newline at the end of the question to avoid completing the question.
        if not user_input.endswith("\n"):
            user_input += "\n"

        matched_documents = self.rank_documents(
            query=user_input,
            top_k=self.retriever_cfg["top_k"],
            thresh=self.retriever_cfg["thresh"],
            engine=self.embedding_model,
            source=self.document_source,
        )

        if matched_documents.empty:
            # Nothing cleared the similarity threshold: short-circuit with a
            # canned completion and an empty (but correctly-columned) frame.
            logger.warning("No documents found...")
            completion = Completion(text="No documents found.")
            matched_documents = pd.DataFrame(columns=matched_documents.columns)
            return Response(completion=completion, matched_documents=matched_documents, is_relevant=False)

        # prepare the prompt from the matched documents, then generate
        system_prompt = self.prompt_formatter.format(matched_documents)
        completion: Completion = self.completer.generate_response(user_input=user_input, system_prompt=system_prompt)
        logger.info(f"GPT Response:\n{completion.text}")

        # check whether the answer is a disguised "I don't know"
        is_relevant = self.check_response_relevance(
            completion_text=completion.text,
            engine=self.embedding_model,
            unk_embedding=self.unk_embedding,
            unk_threshold=self.unknown_threshold,
        )
        if not is_relevant:
            # The bot said it doesn't know: don't surface sources for an
            # answer that carries no information.
            matched_documents = pd.DataFrame(columns=matched_documents.columns)

        return Response(completion=completion, matched_documents=matched_documents, is_relevant=is_relevant)