Spaces:
Runtime error
Runtime error
File size: 9,530 Bytes
e112463 2118bb2 fa9ac7e 0e4a27a e112463 2118bb2 e112463 71e7dd8 5b7d0e6 cb9fd37 e112463 2118bb2 05dabf4 1ea867d 6a09028 1ea867d 5b7d0e6 1ea867d 5b7d0e6 1ea867d 71e7dd8 1ea867d 6a09028 1ea867d 5b7d0e6 1ea867d 8b5fed9 5b7d0e6 1ea867d fa9ac7e 1ea867d fa9ac7e 5b7d0e6 e112463 fa9ac7e c5f5dc3 71e7dd8 c5f5dc3 e112463 fa9ac7e 6a09028 fa9ac7e 6a09028 fa9ac7e e112463 fa9ac7e 6a09028 fa9ac7e f97aa81 fa9ac7e 71e7dd8 f97aa81 fa9ac7e e112463 fa9ac7e e112463 6a09028 1f22b14 fa9ac7e 6a09028 fa9ac7e 6a09028 e112463 6a09028 e112463 6a09028 fa9ac7e 6a09028 fa9ac7e 6a09028 fa9ac7e 0e4a27a fa9ac7e 6a09028 fa9ac7e 0e4a27a fa9ac7e 6a09028 0e4a27a fa9ac7e 0e4a27a 6a09028 f9dc69d 6a09028 0e4a27a 5b7d0e6 0e4a27a fa9ac7e 5b7d0e6 0e4a27a e112463 0e4a27a fa9ac7e 6a09028 5b7d0e6 fa9ac7e 1a12df2 6a09028 0e4a27a fa9ac7e 5b7d0e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 |
import logging
import os
from dataclasses import dataclass, field
from typing import Iterable
import numpy as np
import openai
import pandas as pd
import promptlayer
from openai.embeddings_utils import cosine_similarity, get_embedding
from buster.documents import get_documents_manager_from_extension
from buster.formatter import (
Response,
ResponseFormatter,
Source,
response_formatter_factory,
)
# Module-level setup: logging, optional PromptLayer instrumentation, and the OpenAI API key.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Check if an API key exists for promptlayer, if it does, use it
promptlayer_api_key = os.environ.get("PROMPTLAYER_API_KEY")
if promptlayer_api_key:
    logger.info("Enabling prompt layer...")
    promptlayer.api_key = promptlayer_api_key

    # replace openai with the promptlayer wrapper
    # NOTE: this rebinds the module-level `openai` name, so all later openai.* calls
    # in this module go through PromptLayer's instrumented client.
    openai = promptlayer.openai

# Set the OpenAI key on whichever client (plain or promptlayer-wrapped) is in use.
openai.api_key = os.environ.get("OPENAI_API_KEY")
@dataclass
class ChatbotConfig:
    """Configuration object for a chatbot.

    documents_file: Path to the file containing the documents and their embeddings.
    embedding_model: OpenAI model to use to get embeddings.
    top_k: Max number of documents to retrieve, ordered by cosine similarity
    thresh: threshold for cosine similarity to be considered
    max_words: maximum number of words the retrieved documents can be. Will truncate otherwise.
    unknown_threshold: similarity threshold against the "I don't know" embedding above which
        a generated answer is considered irrelevant and replaced. Set to 0 to deactivate.
    completion_kwargs: kwargs for the OpenAI.Completion() method
    separator: the separator to use, can be either "\n" or <p> depending on rendering.
    response_format: the type of format to render links with, e.g. slack or markdown
    unknown_prompt: Prompt to use to generate the "I don't know" embedding to compare to.
    text_before_documents: Text to prompt GPT with before the retrieved documents.
    text_before_prompt: Text to prompt GPT with before the user prompt, but after the documentation.
    response_footnote: Generic response to add to the chatbot's reply.
    """

    documents_file: str = "buster/data/document_embeddings.tar.gz"
    embedding_model: str = "text-embedding-ada-002"
    top_k: int = 3
    thresh: float = 0.7
    max_words: int = 3000
    unknown_threshold: float = 0.9  # set to 0 to deactivate
    # default_factory avoids the mutable-default pitfall for the kwargs dict.
    completion_kwargs: dict = field(
        default_factory=lambda: {
            "engine": "text-davinci-003",
            "max_tokens": 200,
            "temperature": None,
            "top_p": None,
            "frequency_penalty": 1,
            "presence_penalty": 1,
        }
    )
    separator: str = "\n"
    response_format: str = "slack"
    unknown_prompt: str = "I Don't know how to answer your question."
    text_before_documents: str = "You are a chatbot answering questions.\n"
    text_before_prompt: str = "Answer the following question:\n"
    response_footnote: str = "I'm a bot 🤖 and not always perfect."
class Chatbot:
    """Retrieval-augmented chatbot: embeds a question, retrieves matching documents,
    prompts GPT with them, filters out irrelevant answers, and formats the reply."""

    def __init__(self, cfg: ChatbotConfig):
        # TODO: right now, the cfg is being passed as an omegaconf, is this what we want?
        self.cfg = cfg
        self._init_documents()
        self._init_unk_embedding()
        self._init_response_formatter()

    def _init_response_formatter(self):
        # Build the formatter that renders (response, sources) in the configured
        # output format (e.g. "slack"), appending the configured footnote.
        self.response_formatter = response_formatter_factory(
            format=self.cfg.response_format, response_footnote=self.cfg.response_footnote
        )

    def _init_documents(self):
        # Load the documents manager (chosen by file extension) from the configured path.
        filepath = self.cfg.documents_file
        logger.info(f"loading embeddings from {filepath}...")
        self.documents = get_documents_manager_from_extension(filepath)(filepath)
        logger.info(f"embeddings loaded.")

    def _init_unk_embedding(self):
        # Embed the "I don't know" prompt once at startup; check_response_relevance()
        # compares each generated answer against it.
        logger.info("Generating UNK embedding...")
        self.unk_embedding = get_embedding(
            self.cfg.unknown_prompt,
            engine=self.cfg.embedding_model,
        )

    def rank_documents(
        self,
        query: str,
        top_k: int,
        thresh: float,
        engine: str,
    ) -> pd.DataFrame:
        """
        Compare the question to the series of documents and return the best matching documents.

        query: the user question to embed.
        top_k: max number of documents to retrieve, ranked by cosine similarity.
        thresh: minimum similarity to keep a document; falsy (e.g. 0) disables filtering.
        engine: OpenAI embedding model name.
        """
        query_embedding = get_embedding(
            query,
            engine=engine,
        )
        matched_documents = self.documents.retrieve(query_embedding, top_k)

        # log matched_documents to the console
        logger.info(f"matched documents before thresh: {matched_documents}")

        # filter out matched_documents using a threshold
        if thresh:
            matched_documents = matched_documents[matched_documents.similarity > thresh]
            logger.info(f"matched documents after thresh: {matched_documents}")

        return matched_documents

    def prepare_documents(self, matched_documents: pd.DataFrame, max_words: int) -> str:
        """Join the matched documents' content into one string, truncated to max_words words."""
        # gather the documents in one large plaintext variable
        documents_list = matched_documents.content.to_list()
        documents_str = " ".join(documents_list)

        # truncate the documents to fit
        # TODO: increase to actual token count
        word_count = len(documents_str.split(" "))
        if word_count > max_words:
            logger.info("truncating documents to fit...")
            documents_str = " ".join(documents_str.split(" ")[0:max_words])
            logger.info(f"Documents after truncation: {documents_str}")

        return documents_str

    def prepare_prompt(
        self,
        question: str,
        matched_documents: pd.DataFrame,
        text_before_prompt: str,
        text_before_documents: str,
    ) -> str:
        """
        Prepare the prompt with prompt engineering.

        Layout: <text_before_documents><documents><text_before_prompt><question>.
        """
        documents_str: str = self.prepare_documents(matched_documents, max_words=self.cfg.max_words)
        return text_before_documents + documents_str + text_before_prompt + question

    def get_gpt_response(self, **completion_kwargs) -> Response:
        """Call the OpenAI Completion API; on any failure, return a generic error Response."""
        # Call the API to generate a response
        logger.info(f"querying GPT...")
        try:
            response = openai.Completion.create(**completion_kwargs)
        except Exception as e:
            # log the error and return a generic response instead.
            # NOTE(review): positional args presumably map to (text, error, error_msg) on
            # Response — confirm against Response's definition in buster.formatter.
            logger.exception("Error connecting to OpenAI API. See traceback:")
            return Response("", True, "We're having trouble connecting to OpenAI right now... Try again soon!")

        text = response["choices"][0]["text"]
        return Response(text)

    def generate_response(
        self, prompt: str, matched_documents: pd.DataFrame, unknown_prompt: str
    ) -> tuple[Response, Iterable[Source]]:
        """
        Generate a response based on the retrieved documents.

        Returns (response, sources); sources is empty when no documents matched or
        when the answer is judged irrelevant.
        """
        if len(matched_documents) == 0:
            # No matching documents were retrieved, return
            sources = tuple()
            return Response(unknown_prompt), sources

        logger.info(f"Prompt: {prompt}")
        response = self.get_gpt_response(prompt=prompt, **self.cfg.completion_kwargs)
        # NOTE(review): relies on Response truthiness; if Response defines no __bool__,
        # this branch always runs, including for the error Response above — confirm.
        if response:
            logger.info(f"GPT Response:\n{response.text}")
            relevant = self.check_response_relevance(
                response=response.text,
                engine=self.cfg.embedding_model,
                unk_embedding=self.unk_embedding,
                unk_threshold=self.cfg.unknown_threshold,
            )
            if relevant:
                # Lazily build Source entries from the matched rows.
                sources = (
                    Source(dct["source"], dct["url"], dct["similarity"])
                    for dct in matched_documents.to_dict(orient="records")
                )
            else:
                # Override the answer with a generic unknown prompt, without sources.
                response = Response(text=self.cfg.unknown_prompt)
                sources = tuple()

        return response, sources

    def check_response_relevance(
        self, response: str, engine: str, unk_embedding: np.ndarray, unk_threshold: float
    ) -> bool:
        """Check to see if a response is relevant to the chatbot's knowledge or not.

        We assume we've prompt-engineered our bot to say a response is unrelated to the context if it isn't relevant.
        Here, we compare the embedding of the response to the embedding of the prompt-engineered "I don't know" embedding.

        set the unk_threshold to 0 to essentially turn off this feature.
        """
        response_embedding = get_embedding(
            response,
            engine=engine,
        )
        score = cosine_similarity(response_embedding, unk_embedding)
        logger.info(f"UNK score: {score}")

        # Likely that the answer is meaningful, add the top sources
        return score < unk_threshold

    def process_input(self, question: str, formatter: ResponseFormatter = None) -> str:
        """
        Main function to process the input question and generate a formatted output.

        NOTE(review): the `formatter` parameter is unused; self.response_formatter is
        used instead — confirm whether it was meant to override the instance formatter.
        """
        logger.info(f"User Question:\n{question}")

        # We make sure there is always a newline at the end of the question to avoid completing the question.
        if not question.endswith("\n"):
            question += "\n"

        matched_documents = self.rank_documents(
            query=question,
            top_k=self.cfg.top_k,
            thresh=self.cfg.thresh,
            engine=self.cfg.embedding_model,
        )
        prompt = self.prepare_prompt(
            question=question,
            matched_documents=matched_documents,
            text_before_prompt=self.cfg.text_before_prompt,
            text_before_documents=self.cfg.text_before_documents,
        )
        response, sources = self.generate_response(prompt, matched_documents, self.cfg.unknown_prompt)

        return self.response_formatter(response, sources)
|