import logging
import os
from dataclasses import dataclass, field
from typing import Iterable

import numpy as np
import openai
import pandas as pd
import promptlayer
from openai.embeddings_utils import cosine_similarity, get_embedding

from buster.documents import get_documents_manager_from_extension
from buster.formatter import (
    Response,
    ResponseFormatter,
    Source,
    response_formatter_factory,
)
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Check if an API key exists for promptlayer; if it does, use it
promptlayer_api_key = os.environ.get("PROMPTLAYER_API_KEY")
if promptlayer_api_key:
    logger.info("Enabling prompt layer...")
    promptlayer.api_key = promptlayer_api_key

    # replace openai with the promptlayer wrapper
    openai = promptlayer.openai

openai.api_key = os.environ.get("OPENAI_API_KEY")


@dataclass
class ChatbotConfig:
    """Configuration object for a chatbot.

    documents_file: Path to the file containing the documents and their embeddings.
    embedding_model: OpenAI model to use to get embeddings.
    top_k: Max number of documents to retrieve, ordered by cosine similarity.
    thresh: Threshold for cosine similarity to be considered a match.
    max_words: Maximum number of words the retrieved documents can be. Will truncate otherwise.
    unknown_threshold: Similarity threshold to the "I don't know" embedding above which a response is considered a non-answer.
    completion_kwargs: kwargs for the openai.Completion.create() method.
    separator: The separator to use, can be either "\n" or <p> depending on rendering.
    response_format: The type of format to render links with, e.g. slack or markdown.
    unknown_prompt: Prompt to use to generate the "I don't know" embedding to compare to.
    text_before_documents: Text to prompt GPT with before the retrieved documents.
    text_before_prompt: Text to prompt GPT with before the user prompt, but after the documentation.
    response_footnote: Generic response to add to the chatbot's reply.
    """
    documents_file: str = "buster/data/document_embeddings.tar.gz"
    embedding_model: str = "text-embedding-ada-002"
    top_k: int = 3
    thresh: float = 0.7
    max_words: int = 3000
    unknown_threshold: float = 0.9  # set to 1 to deactivate

    completion_kwargs: dict = field(
        default_factory=lambda: {
            "engine": "text-davinci-003",
            "max_tokens": 200,
            "temperature": None,
            "top_p": None,
            "frequency_penalty": 1,
            "presence_penalty": 1,
        }
    )
    separator: str = "\n"
    response_format: str = "slack"
    unknown_prompt: str = "I don't know how to answer your question."
    text_before_documents: str = "You are a chatbot answering questions.\n"
    text_before_prompt: str = "Answer the following question:\n"
    response_footnote: str = "I'm a bot 🤖 and not always perfect."
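
# A minimal usage sketch (hypothetical values): since ChatbotConfig is a
# dataclass, any subset of fields can be overridden at construction time, e.g.
# ChatbotConfig(top_k=5, response_format="markdown"); fields left out keep the
# defaults above.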


class Chatbot:
    def __init__(self, cfg: ChatbotConfig):
        # TODO: right now, the cfg is being passed as an omegaconf, is this what we want?
        self.cfg = cfg
        self._init_documents()
        self._init_unk_embedding()
        self._init_response_formatter()

    def _init_response_formatter(self):
        self.response_formatter = response_formatter_factory(
            format=self.cfg.response_format, response_footnote=self.cfg.response_footnote
        )

    def _init_documents(self):
        filepath = self.cfg.documents_file
        logger.info(f"loading embeddings from {filepath}...")
        self.documents = get_documents_manager_from_extension(filepath)(filepath)
        logger.info("embeddings loaded.")

    def _init_unk_embedding(self):
        logger.info("Generating UNK embedding...")
        self.unk_embedding = get_embedding(
            self.cfg.unknown_prompt,
            engine=self.cfg.embedding_model,
        )

    def rank_documents(
        self,
        query: str,
        top_k: int,
        thresh: float,
        engine: str,
    ) -> pd.DataFrame:
        """
        Compare the question to the series of documents and return the best matching documents.
        """
        query_embedding = get_embedding(
            query,
            engine=engine,
        )
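        # Retrieval is delegated to the documents manager, which ranks the stored
        # document embeddings by cosine similarity against the query embedding.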
        matched_documents = self.documents.retrieve(query_embedding, top_k)

        # log matched_documents to the console
        logger.info(f"matched documents before thresh: {matched_documents}")

        # filter out matched_documents using a threshold
        if thresh:
            matched_documents = matched_documents[matched_documents.similarity > thresh]
            logger.info(f"matched documents after thresh: {matched_documents}")

        return matched_documents

    def prepare_documents(self, matched_documents: pd.DataFrame, max_words: int) -> str:
        # gather the documents in one large plaintext variable
        documents_list = matched_documents.content.to_list()
        documents_str = " ".join(documents_list)

        # truncate the documents to fit
        # TODO: increase to actual token count
        word_count = len(documents_str.split(" "))
        if word_count > max_words:
            logger.info("truncating documents to fit...")
            documents_str = " ".join(documents_str.split(" ")[0:max_words])
            logger.info(f"Documents after truncation: {documents_str}")

        return documents_str

    def prepare_prompt(
        self,
        question: str,
        matched_documents: pd.DataFrame,
        text_before_prompt: str,
        text_before_documents: str,
    ) -> str:
        """
        Prepare the prompt with prompt engineering.
        """
        documents_str: str = self.prepare_documents(matched_documents, max_words=self.cfg.max_words)
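        # Final prompt layout:
        #   <text_before_documents><documents><text_before_prompt><question>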
        return text_before_documents + documents_str + text_before_prompt + question

    def get_gpt_response(self, **completion_kwargs) -> Response:
        # Call the API to generate a response
        logger.info("querying GPT...")
        try:
            response = openai.Completion.create(**completion_kwargs)
        except Exception:
            # log the error and return a generic response instead.
            logger.exception("Error connecting to OpenAI API. See traceback:")
            return Response("", True, "We're having trouble connecting to OpenAI right now... Try again soon!")

        text = response["choices"][0]["text"]
        return Response(text)

    def generate_response(
        self, prompt: str, matched_documents: pd.DataFrame, unknown_prompt: str
    ) -> tuple[Response, Iterable[Source]]:
        """
        Generate a response based on the retrieved documents.
        """
        if len(matched_documents) == 0:
            # No matching documents were retrieved, return the unknown prompt without sources.
            sources = tuple()
            return Response(unknown_prompt), sources

        logger.info(f"Prompt: {prompt}")
        response = self.get_gpt_response(prompt=prompt, **self.cfg.completion_kwargs)
        if response:
            logger.info(f"GPT Response:\n{response.text}")
        relevant = self.check_response_relevance(
            response=response.text,
            engine=self.cfg.embedding_model,
            unk_embedding=self.unk_embedding,
            unk_threshold=self.cfg.unknown_threshold,
        )
        if relevant:
            sources = (
                Source(dct["source"], dct["url"], dct["similarity"])
                for dct in matched_documents.to_dict(orient="records")
            )
        else:
            # Override the answer with a generic unknown prompt, without sources.
            response = Response(text=self.cfg.unknown_prompt)
            sources = tuple()

        return response, sources

    def check_response_relevance(
        self, response: str, engine: str, unk_embedding: np.ndarray, unk_threshold: float
    ) -> bool:
        """Check to see if a response is relevant to the chatbot's knowledge or not.

        We assume we've prompt-engineered our bot to say a response is unrelated to the context if it isn't relevant.
        Here, we compare the embedding of the response to the prompt-engineered "I don't know" embedding.

        Set unk_threshold to 1 to essentially turn off this feature.
        """
        response_embedding = get_embedding(
            response,
            engine=engine,
        )
        score = cosine_similarity(response_embedding, unk_embedding)
        logger.info(f"UNK score: {score}")
        # Below the threshold, the answer is likely meaningful; keep the top sources.
        return score < unk_threshold

    def process_input(self, question: str, formatter: ResponseFormatter = None) -> str:
        """
        Main function to process the input question and generate a formatted output.
        """
        logger.info(f"User Question:\n{question}")

        # We make sure there is always a newline at the end of the question to avoid completing the question.
        if not question.endswith("\n"):
            question += "\n"

        matched_documents = self.rank_documents(
            query=question,
            top_k=self.cfg.top_k,
            thresh=self.cfg.thresh,
            engine=self.cfg.embedding_model,
        )
        prompt = self.prepare_prompt(
            question=question,
            matched_documents=matched_documents,
            text_before_prompt=self.cfg.text_before_prompt,
            text_before_documents=self.cfg.text_before_documents,
        )
        response, sources = self.generate_response(prompt, matched_documents, self.cfg.unknown_prompt)

        # Use the formatter passed in, falling back to the one built from the config.
        formatter = formatter or self.response_formatter
        return formatter(response, sources)
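

# A minimal sketch of driving this module end to end, assuming OPENAI_API_KEY
# is set and an embeddings file exists at the default path from ChatbotConfig;
# the question below is hypothetical and any string works.
if __name__ == "__main__":
    chatbot = Chatbot(ChatbotConfig())
    print(chatbot.process_input("How do I install the library?"))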