# PoliticsToYou/src/chatbot.py
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_hub import HuggingFaceHub
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_community.docstore.in_memory import InMemoryDocstore
from faiss import IndexFlatL2
#import functools
import pandas as pd
# Load environment variables from the .env file
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
# Define the embedding model, LLM, prompt templates, and the default vector store
embeddings = HuggingFaceEmbeddings(model_name="paraphrase-multilingual-MiniLM-L12-v2") # Remove embedding input parameter from functions?
llm = HuggingFaceHub(
# ToDo: Try different models here
repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
# repo_id="CohereForAI/c4ai-command-r-v01", # too large 69gb
# repo_id="CohereForAI/c4ai-command-r-v01-4bit", # too large 22gb
# repo_id="meta-llama/Meta-Llama-3-8B", # too large 16 gb
task="text-generation",
model_kwargs={
"max_new_tokens": 512,
"top_k": 30,
"temperature": 0.1,
"repetition_penalty": 1.03,
}
)
# ToDo: Experiment with different templates
prompt_test = ChatPromptTemplate.from_template("""<s>[INST]
Instruction: Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:
Context: {context}
Question: {input}
[/INST]"""
)
prompt_de = ChatPromptTemplate.from_template("""Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:
<context>
{context}
</context>
Frage: {input}
"""
# Returns the answer in German
)
prompt_en = ChatPromptTemplate.from_template("""Answer the following question in English and solely based on the provided context:
<context>
{context}
</context>
Question: {input}
"""
# Returns the answer in English
)
db_all = FAISS.load_local(folder_path="./src/FAISS", index_name="speeches_1949_09_12",
embeddings=embeddings, allow_dangerous_deserialization=True)
def get_vectorstore(inputs, embeddings):
"""
Combine multiple FAISS vector stores into a single vector store based on the specified inputs.
Parameters
----------
inputs : list of str
A list of strings specifying which vector stores to combine. Each string represents a specific
        index or the special keyword "All". If "All" is the first entry in the list,
        the pre-defined vector store covering all speeches is returned directly.
embeddings : Embeddings
        The embeddings instance used to load the vector stores and to determine the
        dimensionality of the empty index that the selected stores are merged into.
Returns
-------
FAISS
A FAISS vector store that combines the specified indices into a single vector store.
"""
# Default folder path
folder_path = "./src/FAISS"
if inputs[0] == "All" or inputs[0] is None:
return db_all
# Initialize empty db
embedding_function = embeddings
dimensions = len(embedding_function.embed_query("dummy"))
db = FAISS(
embedding_function=embedding_function,
index=IndexFlatL2(dimensions),
docstore=InMemoryDocstore(),
index_to_docstore_id={},
normalize_L2=False
)
    # Retrieve inputs such as "20. Legislaturperiode", "19. Legislaturperiode", ...
    for selection in inputs:
        # Ignore "All" if the user selected it alongside specific legislatures
        if selection == "All":
            continue
        # Retrieve the selected index and merge the vector stores
        index = selection.split(".")[0]
        index_name = f'{index}_legislature'
local_db = FAISS.load_local(folder_path=folder_path, index_name=index_name,
embeddings=embeddings, allow_dangerous_deserialization=True)
db.merge_from(local_db)
print('Successfully merged inputs')
return db
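
# A minimal usage sketch for get_vectorstore (the selection string is an illustrative
# assumption; it presumes per-legislature indices such as "20_legislature" exist in ./src/FAISS):
#
#     db_20 = get_vectorstore(["20. Legislaturperiode"], embeddings)
#     retriever = db_20.as_retriever()
#     docs = retriever.invoke("Klimapolitik")
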
def RAG(llm, prompt, db, question):
"""
Apply Retrieval-Augmented Generation (RAG) by providing the context and the question to the
language model using a predefined template.
Parameters:
----------
llm : LanguageModel
An instance of the language model to be used for generating responses.
    prompt : ChatPromptTemplate
        A predefined prompt template that structures how the context and question are presented to the language model.
db : VectorStore
A vector store instance that supports retrieval of relevant documents based on the input question.
question : str
The question or query to be answered by the language model.
    Returns:
    -------
    dict
        The output of the retrieval chain: a dictionary containing the input, the
        retrieved context, and the generated answer under the key 'answer'.
"""
# Create a document chain using the provided language model and prompt template
document_chain = create_stuff_documents_chain(llm=llm, prompt=prompt)
# Convert the vector store into a retriever
retriever = db.as_retriever()
# Create a retrieval chain that integrates the retriever with the document chain
retrieval_chain = create_retrieval_chain(retriever, document_chain)
# Invoke the retrieval chain with the input question to get the final response
response = retrieval_chain.invoke({"input": question})
return response
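
# Sketch of calling RAG directly against the preloaded full index (the question is
# only an illustrative placeholder):
#
#     result = RAG(llm=llm, prompt=prompt_de, db=db_all,
#                  question="Was wurde zur Energiepolitik gesagt?")
#     print(result["answer"])
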
def chatbot(message, history, db_inputs, prompt_language, llm=llm):
"""
Generate a response from the chatbot based on the provided message, history, database inputs, prompt language, and LLM model.
Parameters:
-----------
message : str
The message or question to be answered by the chatbot.
    history : list
        The history of previous interactions; required by the chat interface but not used directly here.
db_inputs : list
A list of strings specifying which vector stores to combine. Each string represents a specific index or a special keyword "All".
prompt_language : str
The language of the prompt to be used for generating the response. Should be either "DE" for German or "EN" for English.
llm : LLM, optional
An instance of the Language Model to be used for generating the response. Defaults to the global variable `llm`.
Returns:
--------
str
The response generated by the chatbot.
"""
db = get_vectorstore(inputs = db_inputs, embeddings=embeddings)
# Select prompt based on user input
if prompt_language == "DE":
prompt = prompt_de
raw_response = RAG(llm=llm, prompt=prompt, db=db, question=message)
        # Post-processing is only needed because Mistral echoes the prompt structure
        # (including the input content) in its output
        try:
            response = raw_response['answer'].split("Antwort: ")[1]
        except IndexError:
            response = raw_response['answer']
return response
else:
prompt = prompt_en
raw_response = RAG(llm=llm, prompt=prompt, db=db, question=message)
        # Post-processing is only needed because Mistral echoes the prompt structure
        # (including the input content) in its output
        try:
            response = raw_response['answer'].split("Answer: ")[1]
        except IndexError:
            response = raw_response['answer']
return response
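
# Example of how the chat interface is expected to call chatbot (argument values
# are illustrative assumptions, not taken from the UI code):
#
#     answer = chatbot("Was wurde zum Mindestlohn gesagt?", history=[],
#                      db_inputs=["All"], prompt_language="DE")
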
def keyword_search(query, n=10, embeddings=embeddings, method="ss", party_filter="All"):
"""
Retrieve speech contents based on keywords using a specified method.
Parameters:
----------
query : str
The keyword(s) to search for in the speech contents.
n : int, optional
The number of speech contents to retrieve (default is 10).
embeddings : Embeddings, optional
        An instance of embeddings used to embed the query (defaults to the module-level `embeddings`).
method : str, optional
The method used for retrieving speech contents. Options are 'ss' (semantic search) and 'mmr'
(maximal marginal relevance) (default is 'ss').
party_filter : str, optional
A filter for retrieving speech contents by party affiliation. Specify 'All' to retrieve
speeches from all parties (default is 'All').
Returns:
-------
pandas.DataFrame
A DataFrame containing the speech contents, dates, and party affiliations.
"""
db = get_vectorstore(inputs=["All"], embeddings=embeddings)
query_embedding = embeddings.embed_query(query)
# Maximal Marginal Relevance
if method == "mmr":
        mmr_data = []
        results = db.max_marginal_relevance_search_with_score_by_vector(query_embedding, k=n)
        for doc, score in results:
            party = doc.metadata["party"]
            if party != party_filter and party_filter != 'All':
                continue
            mmr_data.append({'Speech Content': doc.page_content,
                             'Date': doc.metadata["date"],
                             'Party': party,
                             'Relevance': round(score, ndigits=2)})
        # Build the DataFrame once instead of concatenating inside the loop
        df_res = pd.DataFrame(mmr_data, columns=['Speech Content', 'Date', 'Party', 'Relevance'])
        # Lower L2 distance means higher relevance, so sort ascending
        df_res.sort_values('Relevance', inplace=True, ascending=True)
# Similarity Search
elif method == "ss":
kws_data = []
results = db.similarity_search_by_vector(query_embedding, k=n)
for doc in results:
party = doc.metadata["party"]
if party != party_filter and party_filter != 'All':
continue
speech_content = doc.page_content
            speech_date = doc.metadata["date"]
            # The stored date may be a datetime or already a formatted string
            if hasattr(speech_date, "strftime"):
                speech_date = speech_date.strftime("%Y-%m-%d")
kws_entry = {'Speech Content': speech_content,
'Date': speech_date,
'Party': party}
kws_data.append(kws_entry)
df_res = pd.DataFrame(kws_data)
return df_res
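

# Small manual test; it only runs when the FAISS indices and a valid
# HUGGINGFACEHUB_API_TOKEN are available, and the query and party filter
# below are illustrative assumptions.
if __name__ == "__main__":
    # Semantic keyword search over all speeches, restricted to one party
    df_demo = keyword_search("Klimawandel", n=5, method="ss", party_filter="SPD")
    print(df_demo.head())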