import logging
from dataclasses import dataclass, field

import numpy as np
import openai
import pandas as pd
from omegaconf import OmegaConf
from openai.embeddings_utils import cosine_similarity, get_embedding


logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def load_documents(path: str) -> pd.DataFrame:
    logger.info(f"loading embeddings from {path}...")
    df = pd.read_csv(path)
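    # embeddings are stored in the csv as stringified lists; eval them back
    # into python lists, then convert to numpy arrays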
    df["embedding"] = df.embedding.apply(eval).apply(np.array)
    logger.info(f"embeddings loaded.")
    return df


class Chatbot:
    def __init__(self, cfg: OmegaConf):
        # TODO: right now, the cfg is being passed as an omegaconf, is this what we want?
        self.cfg = cfg
        self._init_documents()
        self._init_unk_embedding()

    def _init_documents(self):
        self.documents = load_documents(self.cfg.documents_csv)

    def _init_unk_embedding(self):
        """Precompute the embedding of the "unknown" prompt, used to detect non-answers when formatting the response."""
        logger.info("Generating UNK embedding...")
        unknown_prompt = self.cfg.unknown_prompt
        engine = self.cfg.embedding_model
        self.unk_embedding = get_embedding(
            unknown_prompt,
            engine=engine,
        )

    def rank_documents(
        self,
        documents: pd.DataFrame,
        query: str,
    ) -> pd.DataFrame:
        """
        Compare the question to the series of documents and return the best matching documents.
        """
        top_k = self.cfg.top_k
        thresh = self.cfg.thresh
        engine = self.cfg.embedding_model

        query_embedding = get_embedding(
            query,
            engine=engine,
        )
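        # score every document against the query embedding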
        documents["similarity"] = documents.embedding.apply(lambda x: cosine_similarity(x, query_embedding))

        # sort the matched_documents by score
        matched_documents = documents.sort_values("similarity", ascending=False)

        # limit search to top_k matched_documents.
        top_k = len(matched_documents) if top_k == -1 else top_k
        matched_documents = matched_documents.head(top_k)

        # log matched_documents to the console
        logger.info(f"matched documents before thresh: {matched_documents}")

        # filter out matched_documents using a threshold
        if thresh:
            matched_documents = matched_documents[matched_documents.similarity > thresh]
            logger.info(f"matched documents after thresh: {matched_documents}")

        return matched_documents

    def prepare_prompt(self, question: str, candidates: pd.DataFrame) -> str:
        """
        Prepare the prompt with prompt engineering.
        """

        max_chars = self.cfg.max_chars
        text_before_prompt = self.cfg.text_before_prompt

        documents_list = candidates.text.to_list()
        documents_str = " ".join(documents_list)
        if len(documents_str) > max_chars:
            logger.info("truncating documents to fit...")
            documents_str = documents_str[0:max_chars]

        return documents_str + text_before_prompt + question

    def generate_response(self, prompt: str, matched_documents: pd.DataFrame) -> str:
        """
        Generate a response based on the retrieved documents.
        """
        if len(matched_documents) == 0:
            # No matching documents were retrieved, return
            response_text = "I did not find any relevant documentation related to your question."
            return response_text

        logger.info(f"querying GPT...")
        # Call the API to generate a response
        try:
            completion_kwargs = self.cfg.completion_kwargs
            completion_kwargs["prompt"] = prompt
            response = openai.Completion.create(**completion_kwargs)

            # Get the response text
            response_text = response["choices"][0]["text"]
            logger.info(f"GPT Response:\n{response_text}")
            return response_text

        except Exception:
            # log the full traceback and return a generic response instead.
            logger.exception("Error while generating GPT response")
            response_text = "Oops, something went wrong. Try again later!"
            return response_text

    def add_sources(self, response: str, matched_documents: pd.DataFrame) -> str:
        """
        Add sources from the matched documents to the response.
        """
        sep = self.cfg.separator  # e.g. "\n"
        link_format = self.cfg.link_format

        urls = matched_documents.url.to_list()
        names = matched_documents.name.to_list()
        similarities = matched_documents.similarity.to_list()

        response += f"{sep}{sep}Here are the sources I used to answer your question:\n"
        for url, name, similarity in zip(urls, names, similarities):
            if link_format == "markdown":
                response += f"{sep}[{name}]({url}){sep}"
            elif link_format == "slack":
                response += f"• <{url}|{name}>, score: {similarity:2.3f}{sep}"
            else:
                raise ValueError(f"{link_format} is not a valid link format.")

        return response

    def format_response(self, response: str, matched_documents: pd.DataFrame) -> str:
        """
        Format the response by adding the sources if necessary, and a disclaimer prompt.
        """

        sep = self.cfg.separator
        text_after_response = self.cfg.text_after_response

        if len(matched_documents) > 0:
            # we have matched documents, now we check to see if the answer is meaningful
            response_embedding = get_embedding(
                response,
                # use the same embedding model as the unk embedding for a fair comparison
                engine=self.cfg.embedding_model,
            )
            score = cosine_similarity(response_embedding, self.unk_embedding)
            logger.info(f"UNK score: {score}")
            if score < 0.9:
                # Liekly that the answer is meaningful, add the top sources
                response = self.add_sources(response, matched_documents=matched_documents)

        response += f"{sep}{sep}{sep}{text_after_response}{sep}"

        return response

    def process_input(self, question: str) -> str:
        """
        Main function to process the input question and generate a formatted output.
        """

        logger.info(f"User Question:\n{question}")

        matched_documents = self.rank_documents(documents=self.documents, query=question)
        prompt = self.prepare_prompt(question, matched_documents)
        response = self.generate_response(prompt, matched_documents)
        formatted_output = self.format_response(response, matched_documents)

        return formatted_output


@dataclass
class ChatbotConfig:
    """Configuration object for a chatbot.

    documents_csv: Path to the csv file containing the documents and their embeddings.
    embedding_model: OpenAI model to use to get embeddings.
    top_k: Max number of documents to retrieve, ordered by cosine similarity.
    thresh: Minimum cosine similarity for a document to be considered a match.
    max_chars: Maximum number of characters the retrieved documents can span; truncated otherwise.
    completion_kwargs: kwargs passed to openai.Completion.create().
    separator: the separator to use, can be either "\\n" or "<p>" depending on rendering.
    link_format: the type of format to render links with, e.g. slack or markdown
    unknown_prompt: Prompt to use to generate the "I don't know" embedding to compare to.
    text_before_prompt: Text to prompt GPT with before the user prompt, but after the documentation.
    text_after_response: Generic response to append to the chatbot's reply.
    """

    documents_csv: str = "buster/data/document_embeddings.csv"
    embedding_model: str = "text-embedding-ada-002"
    top_k: int = 3
    thresh: float = 0.7
    max_chars: int = 3000

    completion_kwargs: dict = field(
        default_factory=lambda: {
            "engine": "text-davinci-003",
            "max_tokens": 200,
            "temperature": None,
            "top_p": None,
            "frequency_penalty": 1,
            "presence_penalty": 1,
        }
    )
    separator: str = "\n"
    link_format: str = "slack"
    unknown_prompt: str = "I don't know how to answer your question."
    text_before_prompt: str = "Answer the following question:\n"
    text_after_response: str = "I'm a chatbot, bleep bloop."
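

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original module).
# Assumes the documents csv with precomputed embeddings exists at the default
# path and that a valid OpenAI API key is configured in the environment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = OmegaConf.structured(ChatbotConfig())
    chatbot = Chatbot(cfg)
    answer = chatbot.process_input("How do I use this library?")  # sample question
    print(answer)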