Spaces:
Runtime error
Runtime error
File size: 8,459 Bytes
e112463 fa9ac7e c6dd20e e112463 c8a1687 5b7d0e6 cb9fd37 e112463 05dabf4 1ea867d c8a1687 1ea867d 6a09028 1ea867d 5b7d0e6 1ea867d 5b7d0e6 c6dd20e 1ea867d 71e7dd8 1ea867d 6a09028 c8a1687 1ea867d c8a1687 1ea867d 5b7d0e6 1ea867d 5b7d0e6 c6dd20e 06bca0c 1ea867d c8a1687 06bca0c c6dd20e fa9ac7e c6dd20e 06bca0c c6dd20e 5b7d0e6 c6dd20e 5b7d0e6 c6dd20e e112463 c6dd20e e112463 fa9ac7e 6a09028 c6dd20e fa9ac7e f97aa81 c6dd20e fa9ac7e 06bca0c f97aa81 fa9ac7e e112463 fa9ac7e e112463 6a09028 1f22b14 c6dd20e 6a09028 fa9ac7e 6a09028 e112463 6a09028 e112463 c8a1687 6a09028 c8a1687 c6dd20e c8a1687 fa9ac7e 6a09028 c8a1687 6a09028 c6dd20e c8a1687 6a09028 c8a1687 fa9ac7e c8a1687 fa9ac7e 1a12df2 c8a1687 1a12df2 6a09028 c8a1687 6a09028 c6dd20e 6a09028 c8a1687 c6dd20e c8a1687 c6dd20e 463196d c8a1687 6a09028 c8a1687 fa9ac7e 5b7d0e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
import logging
from dataclasses import dataclass, field
from functools import lru_cache
import numpy as np
import pandas as pd
from openai.embeddings_utils import cosine_similarity, get_embedding
from buster.completers import get_completer
from buster.formatter import (
Response,
ResponseFormatter,
Source,
response_formatter_factory,
)
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@dataclass
class BusterConfig:
    """Configuration object for a chatbot.

    Attributes:
        documents_file: Path to the file containing the documents and their embeddings.
        embedding_model: OpenAI model to use to get embeddings.
        top_k: Max number of documents to retrieve, ordered by cosine similarity.
        thresh: Threshold on cosine similarity for a document to be considered a match.
        max_words: Maximum number of words the retrieved documents can total; truncated otherwise.
        unknown_threshold: Similarity threshold against the "I don't know" embedding
            used to decide whether a completion is relevant; set to 0 to deactivate.
        completer_cfg: Configuration for the completer: completer name, the text placed
            before the documents and before the user prompt, and the kwargs forwarded
            to the OpenAI Completion call.
        response_format: The type of format to render links with, e.g. slack or markdown.
        unknown_prompt: Prompt used to generate the "I don't know" embedding that
            completions are compared to.
        response_footnote: Generic footnote appended to the chatbot's reply.
        source: The source of the documents to consider.
    """

    documents_file: str = "buster/data/document_embeddings.tar.gz"
    embedding_model: str = "text-embedding-ada-002"
    top_k: int = 3
    thresh: float = 0.7
    max_words: int = 3000
    unknown_threshold: float = 0.9  # set to 0 to deactivate
    completer_cfg: dict = field(
        # TODO: Put all this in its own config with sane defaults?
        default_factory=lambda: {
            "name": "GPT3",
            "text_before_documents": "You are a chatbot answering questions.\n",
            "text_before_prompt": "Answer the following question:\n",
            "completion_kwargs": {
                "engine": "text-davinci-003",
                "max_tokens": 200,
                "temperature": None,
                "top_p": None,
                "frequency_penalty": 1,
                "presence_penalty": 1,
            },
        }
    )
    response_format: str = "slack"
    unknown_prompt: str = "I Don't know how to answer your question."
    response_footnote: str = "I'm a bot 🤖 and not always perfect."
    source: str = ""
from buster.retriever import Retriever
class Buster:
    """Chatbot that answers a user question from documents matched by a retriever.

    The pipeline is: embed the question, retrieve and rank matching documents,
    prompt the completer with the documents + question, then check the generated
    answer's similarity to a canned "I don't know" embedding to decide whether
    the answer is relevant.
    """

    def __init__(self, cfg: BusterConfig, retriever: Retriever):
        self._unk_embedding = None
        self.cfg = cfg
        self.update_cfg(cfg)
        self.retriever = retriever

    @property
    def unk_embedding(self):
        """Embedding of the prompt-engineered "I don't know" response."""
        return self._unk_embedding

    @unk_embedding.setter
    def unk_embedding(self, embedding):
        # NOTE: the original setter returned the value; a setter's return value
        # is always discarded by Python, so it was dropped.
        logger.info("Setting new UNK embedding...")
        self._unk_embedding = embedding

    def update_cfg(self, cfg: BusterConfig):
        """Set a new config and rebuild everything derived from it.

        Rebuilds the completer, the "I don't know" embedding, and the response
        formatter so they stay consistent with the new config.
        """
        logger.info(f"Updating config to {cfg.source}:\n{cfg}")
        self.cfg = cfg
        self.completer = get_completer(cfg.completer_cfg)
        self.unk_embedding = self.get_embedding(self.cfg.unknown_prompt, engine=self.cfg.embedding_model)
        self.response_formatter = response_formatter_factory(
            format=self.cfg.response_format, response_footnote=self.cfg.response_footnote
        )
        logger.info("Config Updated.")

    @staticmethod
    @lru_cache
    def _cached_embedding(query: str, engine: str):
        # Cache lives on the class, keyed only on (query, engine).
        # Decorating the instance method directly with lru_cache (as before)
        # would key the cache on `self` and keep every Buster instance alive
        # for the lifetime of the cache (flake8-bugbear B019).
        logger.info("generating embedding")
        return get_embedding(query, engine=engine)

    def get_embedding(self, query: str, engine: str):
        """Return the embedding of `query` from `engine`, memoized across calls."""
        return self._cached_embedding(query, engine)

    def rank_documents(
        self,
        query: str,
        top_k: int,
        thresh: float,
        engine: str,
        source: str,
    ) -> pd.DataFrame:
        """Compare the question to the series of documents and return the best matching documents.

        Args:
            query: The user question.
            top_k: Max number of documents to retrieve.
            thresh: Minimum cosine similarity to keep a document; falsy disables filtering.
            engine: OpenAI embedding model name.
            source: Document source to restrict retrieval to.
        """
        query_embedding = self.get_embedding(
            query,
            engine=engine,
        )
        matched_documents = self.retriever.retrieve(query_embedding, top_k=top_k, source=source)

        # log matched_documents to the console
        logger.info(f"matched documents before thresh: {matched_documents}")

        # filter out matched_documents using a threshold
        if thresh:
            matched_documents = matched_documents[matched_documents.similarity > thresh]
            logger.info(f"matched documents after thresh: {matched_documents}")

        return matched_documents

    def prepare_documents(self, matched_documents: pd.DataFrame, max_words: int) -> str:
        """Concatenate the matched documents into one plaintext string, truncated to max_words.

        Each document is wrapped in <DOCUMENT> ... </DOCUMENT> delimiters so the
        completer can tell documents apart.
        """
        documents_list = matched_documents.content.to_list()
        # FIX: the closing tag was "<\DOCUMENT>" — a typo containing the
        # invalid escape sequence "\D" — corrected to "</DOCUMENT>".
        # Built with join instead of quadratic += concatenation.
        documents_str = "".join(f"<DOCUMENT> {doc} </DOCUMENT>" for doc in documents_list)

        # truncate the documents to fit
        # TODO: increase to actual token count
        word_count = len(documents_str.split(" "))
        if word_count > max_words:
            logger.info("truncating documents to fit...")
            documents_str = " ".join(documents_str.split(" ")[0:max_words])
            logger.info(f"Documents after truncation: {documents_str}")

        return documents_str

    def add_sources(
        self,
        matched_documents: pd.DataFrame,
    ):
        """Yield a Source per matched document (similarity is rescaled to a percentage)."""
        sources = (
            Source(
                source=dct["source"], title=dct["title"], url=dct["url"], question_similarity=dct["similarity"] * 100
            )
            for dct in matched_documents.to_dict(orient="records")
        )
        return sources

    def check_response_relevance(
        self, completion: str, engine: str, unk_embedding: np.ndarray, unk_threshold: float
    ) -> bool:
        """Check to see if a response is relevant to the chatbot's knowledge or not.

        We assume we've prompt-engineered our bot to say a response is unrelated to the
        context if it isn't relevant. Here, we compare the embedding of the response to
        the embedding of the prompt-engineered "I don't know" embedding.

        Set the unk_threshold to 0 to essentially turn off this feature.
        """
        response_embedding = self.get_embedding(
            completion,
            engine=engine,
        )
        score = cosine_similarity(response_embedding, unk_embedding)
        logger.info(f"UNK score: {score}")

        # Likely that the answer is meaningful, add the top sources
        return score < unk_threshold

    def process_input(self, user_input: str, formatter: ResponseFormatter = None) -> str:
        """Main function to process the input question and generate a formatted output.

        Args:
            user_input: The raw user question.
            formatter: Optional formatter override; defaults to the one built from
                the config. (FIX: this argument was previously accepted but ignored.)
        """
        logger.info(f"User Input:\n{user_input}")

        if formatter is None:
            formatter = self.response_formatter

        # We make sure there is always a newline at the end of the question to avoid completing the question.
        if not user_input.endswith("\n"):
            user_input += "\n"

        matched_documents = self.rank_documents(
            query=user_input,
            top_k=self.cfg.top_k,
            thresh=self.cfg.thresh,
            engine=self.cfg.embedding_model,
            source=self.cfg.source,
        )

        if len(matched_documents) == 0:
            # Nothing matched: reply with the canned "I don't know" answer.
            response = Response(self.cfg.unknown_prompt)
            sources = tuple()
            return formatter(response, sources)

        # generate a completion
        documents: str = self.prepare_documents(matched_documents, max_words=self.cfg.max_words)
        response: Response = self.completer.generate_response(user_input, documents)
        logger.info(f"GPT Response:\n{response.text}")
        sources = self.add_sources(matched_documents)

        # check for relevance
        relevant = self.check_response_relevance(
            completion=response.text,
            engine=self.cfg.embedding_model,
            unk_embedding=self.unk_embedding,
            unk_threshold=self.cfg.unknown_threshold,
        )
        if not relevant:
            # answer generated was the chatbot saying it doesn't know how to answer
            # override completion with generic "I don't know"
            response = Response(text=self.cfg.unknown_prompt)
            sources = tuple()

        return formatter(response, sources)
|