Spaces:
Runtime error
Runtime error
File size: 7,051 Bytes
e112463 fa9ac7e c6dd20e e112463 d16a006 6a4ac5a d16a006 cb9fd37 e112463 05dabf4 6a4ac5a 6008655 6a4ac5a 1ea867d c8a1687 d16a006 1ea867d d16a006 1ea867d d16a006 c8a1687 d16a006 c8a1687 d16a006 c8a1687 1ea867d c8a1687 06bca0c c6dd20e 06bca0c c6dd20e 5b7d0e6 c6dd20e d16a006 6a4ac5a c6dd20e e112463 c6dd20e e112463 fa9ac7e 6a09028 c6dd20e fa9ac7e f97aa81 c6dd20e fa9ac7e 06bca0c f97aa81 fa9ac7e e112463 fa9ac7e d16a006 fa9ac7e e112463 6a09028 6a4ac5a 6a09028 c6dd20e 6a4ac5a 6a09028 25a0d11 fa9ac7e c8a1687 fa9ac7e 1a12df2 c8a1687 1a12df2 6a09028 c8a1687 d16a006 6a09028 c8a1687 6a4ac5a 6008655 6a4ac5a 463196d 6a4ac5a c8a1687 6008655 6a4ac5a d16a006 c8a1687 d16a006 6a09028 6008655 6a4ac5a c8a1687 6a4ac5a d16a006 fa9ac7e 6008655 6a4ac5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
import logging
from dataclasses import dataclass, field
from functools import lru_cache
import numpy as np
import pandas as pd
from openai.embeddings_utils import cosine_similarity, get_embedding
from buster.completers import completer_factory
from buster.completers.base import Completion
from buster.formatters.prompts import SystemPromptFormatter, prompt_formatter_factory
from buster.retriever import Retriever
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@dataclass(slots=True)
class Response:
completion: Completion
is_relevant: bool
matched_documents: pd.DataFrame | None = None
@dataclass
class BusterConfig:
"""Configuration object for a chatbot."""
embedding_model: str = "text-embedding-ada-002"
unknown_threshold: float = 0.9
unknown_prompt: str = "I Don't know how to answer your question."
document_source: str = ""
retriever_cfg: dict = field(
default_factory=lambda: {
"top_k": 3,
"thresh": 0.7,
}
)
prompt_cfg: dict = field(
default_factory=lambda: {
"max_words": 3000,
"text_before_documents": "You are a chatbot answering questions.\n",
"text_before_prompt": "Answer the following question:\n",
}
)
completion_cfg: dict = field(
default_factory=lambda: {
"name": "ChatGPT",
"completion_kwargs": {
"engine": "gpt-3.5-turbo",
"max_tokens": 200,
"temperature": None,
"top_p": None,
"frequency_penalty": 1,
"presence_penalty": 1,
},
}
)
class Buster:
def __init__(self, cfg: BusterConfig, retriever: Retriever):
self._unk_embedding = None
self.update_cfg(cfg)
self.retriever = retriever
@property
def unk_embedding(self):
return self._unk_embedding
@unk_embedding.setter
def unk_embedding(self, embedding):
logger.info("Setting new UNK embedding...")
self._unk_embedding = embedding
return self._unk_embedding
def update_cfg(self, cfg: BusterConfig):
"""Every time we set a new config, we update the things that need to be updated."""
logger.info(f"Updating config to {cfg.document_source}:\n{cfg}")
self._cfg = cfg
self.embedding_model = cfg.embedding_model
self.unknown_threshold = cfg.unknown_threshold
self.unknown_prompt = cfg.unknown_prompt
self.document_source = cfg.document_source
self.retriever_cfg = cfg.retriever_cfg
self.completion_cfg = cfg.completion_cfg
self.prompt_cfg = cfg.prompt_cfg
# set the unk. embedding
self.unk_embedding = self.get_embedding(self.unknown_prompt, engine=self.embedding_model)
# update completer and formatter cfg
self.completer = completer_factory(self.completion_cfg)
self.prompt_formatter = prompt_formatter_factory(self.prompt_cfg)
logger.info(f"Config Updated.")
@lru_cache
def get_embedding(self, query: str, engine: str):
logger.info("generating embedding")
return get_embedding(query, engine=engine)
def rank_documents(
self,
query: str,
top_k: float,
thresh: float,
engine: str,
source: str,
) -> pd.DataFrame:
"""
Compare the question to the series of documents and return the best matching documents.
"""
query_embedding = self.get_embedding(
query,
engine=engine,
)
matched_documents = self.retriever.retrieve(query_embedding, top_k=top_k, source=source)
# log matched_documents to the console
logger.info(f"matched documents before thresh: {matched_documents}")
# filter out matched_documents using a threshold
matched_documents = matched_documents[matched_documents.similarity > thresh]
logger.info(f"matched documents after thresh: {matched_documents}")
return matched_documents
def check_response_relevance(
self, completion_text: str, engine: str, unk_embedding: np.array, unk_threshold: float
) -> bool:
"""Check to see if a response is relevant to the chatbot's knowledge or not.
We assume we've prompt-engineered our bot to say a response is unrelated to the context if it isn't relevant.
Here, we compare the embedding of the response to the embedding of the prompt-engineered "I don't know" embedding.
set the unk_threshold to 0 to essentially turn off this feature.
"""
response_embedding = self.get_embedding(
completion_text,
engine=engine,
)
score = cosine_similarity(response_embedding, unk_embedding)
logger.info(f"UNK score: {score}")
# Likely that the answer is meaningful, add the top sources
return score < unk_threshold
def process_input(self, user_input: str) -> Response:
"""
Main function to process the input question and generate a formatted output.
"""
logger.info(f"User Input:\n{user_input}")
# We make sure there is always a newline at the end of the question to avoid completing the question.
if not user_input.endswith("\n"):
user_input += "\n"
matched_documents = self.rank_documents(
query=user_input,
top_k=self.retriever_cfg["top_k"],
thresh=self.retriever_cfg["thresh"],
engine=self.embedding_model,
source=self.document_source,
)
if len(matched_documents) == 0:
logger.warning("No documents found...")
completion = Completion(text="No documents found.")
matched_documents = pd.DataFrame(columns=matched_documents.columns)
response = Response(completion=completion, matched_documents=matched_documents, is_relevant=False)
return response
# prepare the prompt
system_prompt = self.prompt_formatter.format(matched_documents)
completion: Completion = self.completer.generate_response(user_input=user_input, system_prompt=system_prompt)
logger.info(f"GPT Response:\n{completion.text}")
# check for relevance
is_relevant = self.check_response_relevance(
completion_text=completion.text,
engine=self.embedding_model,
unk_embedding=self.unk_embedding,
unk_threshold=self.unknown_threshold,
)
if not is_relevant:
matched_documents = pd.DataFrame(columns=matched_documents.columns)
# answer generated was the chatbot saying it doesn't know how to answer
# uncomment override completion with unknown prompt
# completion = Completion(text=self.unknown_prompt)
response = Response(completion=completion, matched_documents=matched_documents, is_relevant=is_relevant)
return response
|