Spaces:
Runtime error
Runtime error
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_community.llms.huggingface_hub import HuggingFaceHub | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.chains.combine_documents import create_stuff_documents_chain | |
| from langchain.chains import create_retrieval_chain | |
| from langchain_community.docstore.in_memory import InMemoryDocstore | |
| from faiss import IndexFlatL2 | |
| #import functools | |
| import pandas as pd | |
# Read environment variables from the nearest .env file (presumably this supplies
# the Hugging Face API token used by HuggingFaceHub -- confirm in deployment).
from dotenv import find_dotenv, load_dotenv

load_dotenv(dotenv_path=find_dotenv())

# Shared embedding model: used to load the FAISS indices and to embed search queries.
# TODO: consider dropping the `embeddings` parameter from the functions below and
# always using this module-level instance.
embeddings = HuggingFaceEmbeddings(model_name="paraphrase-multilingual-MiniLM-L12-v2")

# Hosted LLM used to generate answers.
# TODO: try different models here. Candidates that were considered:
#   - "mistralai/Mistral-7B-Instruct-v0.3"
#   - "mistralai/Mixtral-8x7B-Instruct-v0.1"
#   - "CohereForAI/c4ai-command-r-v01"        (too large: 69 GB)
#   - "CohereForAI/c4ai-command-r-v01-4bit"   (too large: 22 GB)
#   - "meta-llama/Meta-Llama-3-8B"            (too large: 16 GB)
llm = HuggingFaceHub(
    repo_id="mistralai/Ministral-8B-Instruct-2410",
    task="text-generation",
    model_kwargs=dict(
        max_new_tokens=512,
        top_k=30,
        temperature=0.1,
        repetition_penalty=1.03,
    ),
)
# Prompt templates for the retrieval chain. Each template takes `{context}`
# (the retrieved documents) and `{input}` (the user's question).
# TODO: experiment with different templates.

# Experimental prompt wrapped in <s>[INST] ... [/INST] tags (Mistral instruct format);
# the instruction asks for a German answer.
prompt_test = ChatPromptTemplate.from_template(
    """<s>[INST]
Instruction: Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:
Context: {context}
Question: {input}
[/INST]"""
)

# Asks for an answer in German, based solely on the supplied context.
prompt_de = ChatPromptTemplate.from_template(
    """Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:
<context>
{context}
</context>
Frage: {input}
"""
)

# Asks for an answer in English, based solely on the supplied context.
prompt_en = ChatPromptTemplate.from_template(
    """Answer the following question in English and solely based on the provided context:
<context>
{context}
</context>
Question: {input}
"""
)
# Pre-built FAISS index over all speeches; get_vectorstore returns it when "All" is selected.
# NOTE(security): allow_dangerous_deserialization=True unpickles the stored docstore.
# That is only acceptable for trusted, locally produced index files -- never point
# this at user-supplied paths.
db_all = FAISS.load_local(
    folder_path="./src/FAISS",
    embeddings=embeddings,
    index_name="speeches_1949_09_12",
    allow_dangerous_deserialization=True,
)
def get_vectorstore(inputs, embeddings):
    """
    Combine multiple FAISS vector stores into a single vector store based on the specified inputs.

    Parameters
    ----------
    inputs : list of str, str or None
        Selected legislatures, e.g. ``["20. Legislaturperiode", "19. Legislaturperiode"]``.
        The text before the first "." selects the index ``<text>_legislature`` in
        ``./src/FAISS``. A single string is treated as a one-element list.
        The pre-loaded store for all speeches (``db_all``) is returned when:
          - the selection is empty or None, or
          - its first entry is "All" or None.
        "All"/None entries later in the list are skipped.
    embeddings : Embeddings
        Embedding model used to load each selected index and to determine the
        vector dimension of the empty store the indices are merged into.

    Returns
    -------
    FAISS
        ``db_all``, or a new FAISS store containing the merged selected indices.
    """
    folder_path = "./src/FAISS"

    # A single selection may arrive as a bare string; iterating it would walk characters.
    if isinstance(inputs, str):
        inputs = [inputs]

    # Nothing selected, or "All" first: use the pre-loaded index over all speeches.
    if not inputs or inputs[0] == "All" or inputs[0] is None:
        return db_all

    # Start from an empty store whose dimension matches the embedding model.
    dimensions = len(embeddings.embed_query("dummy"))
    db = FAISS(
        embedding_function=embeddings,
        index=IndexFlatL2(dimensions),
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
        normalize_L2=False,
    )

    # Selections look like "20. Legislaturperiode", "19. Legislaturperiode", ...
    for selection in inputs:
        # Ignore "All" (or an empty entry) when selected alongside specific legislatures.
        if selection == "All" or selection is None:
            continue
        legislature = selection.split(".")[0]
        # NOTE(security): `selection` presumably comes from the UI and becomes part of a
        # file name that is unpickled below; consider validating it against the known
        # legislature indices.
        local_db = FAISS.load_local(
            folder_path=folder_path,
            index_name=f"{legislature}_legislature",
            embeddings=embeddings,
            allow_dangerous_deserialization=True,
        )
        db.merge_from(local_db)

    print('Successfully merged inputs')
    return db
def RAG(llm, prompt, db, question):
    """
    Answer ``question`` with Retrieval-Augmented Generation over ``db``.

    The documents retrieved from ``db`` for the question are stuffed into
    ``prompt`` (as ``{context}``, with the question as ``{input}``) and the
    formatted prompt is sent to ``llm``.

    Parameters
    ----------
    llm : LanguageModel
        Language model that generates the answer.
    prompt : ChatPromptTemplate
        Template containing ``{context}`` and ``{input}`` placeholders.
    db : VectorStore
        Vector store used (via ``as_retriever()``) to fetch context documents.
    question : str
        The question to answer.

    Returns
    -------
    dict
        The retrieval chain's output. The generated text is stored under the
        ``"answer"`` key; LangChain's retrieval chain also returns ``"input"``
        and ``"context"``.
    """
    # Chain that formats the retrieved documents into the prompt and calls the LLM.
    combine_docs = create_stuff_documents_chain(llm=llm, prompt=prompt)
    # Wire the retriever in front of it so documents are fetched for each question.
    chain = create_retrieval_chain(db.as_retriever(), combine_docs)
    return chain.invoke({"input": question})
def chatbot(message, history, db_inputs, prompt_language, llm=llm):
    """
    Answer a chat message with RAG over the selected speech indices.

    Parameters
    ----------
    message : str
        The question to be answered.
    history : list
        Previous chat turns. Not used here; presumably required by the chat UI's
        callback signature -- confirm before removing.
    db_inputs : list of str
        Index selection passed to ``get_vectorstore`` (e.g. "All" or
        "20. Legislaturperiode").
    prompt_language : str
        "DE" selects the German prompt; any other value selects the English prompt.
    llm : LLM, optional
        Language model used to generate the answer (default: module-level ``llm``).

    Returns
    -------
    str
        The model's answer. If the text contains the marker "Antwort: " (DE) or
        "Answer: " (EN), everything up to and including its first occurrence is
        removed; otherwise the answer is returned unchanged.
    """
    db = get_vectorstore(inputs=db_inputs, embeddings=embeddings)

    # Pick the prompt and the answer marker that matches its language.
    if prompt_language == "DE":
        prompt, marker = prompt_de, "Antwort: "
    else:
        prompt, marker = prompt_en, "Answer: "

    raw_response = RAG(llm=llm, prompt=prompt, db=db, question=message)
    answer = raw_response['answer']

    # The model may echo its input/prompt structure before the reply, so keep only
    # the text after the first marker. partition() keeps any later occurrences of
    # the marker intact, whereas split(marker)[1] silently dropped them. It also
    # needs no bare `except` to handle a missing marker.
    _, found, reply = answer.partition(marker)
    return reply if found else answer
def _format_speech_date(value):
    """Return a date-like ``value`` as a ``YYYY-MM-DD`` string; return other values unchanged."""
    return value.strftime("%Y-%m-%d") if hasattr(value, "strftime") else value


def keyword_search(query, n=10, embeddings=embeddings, method="ss", party_filter="All"):
    """
    Retrieve speech contents related to ``query`` from the index over all speeches.

    Parameters
    ----------
    query : str
        The keyword(s) to search for in the speech contents.
    n : int, optional
        Number of nearest speech chunks fetched from the index (default is 10).
        The party filter is applied afterwards, so fewer than ``n`` rows may be returned.
    embeddings : Embeddings, optional
        Embedding model used to embed the query (default: module-level ``embeddings``).
    method : {"ss", "mmr"}, optional
        Retrieval method: "ss" uses similarity search and "mmr" uses maximal
        marginal relevance (default is "ss").
    party_filter : str, optional
        Keep only speeches whose ``party`` metadata equals this value; "All"
        keeps speeches from every party (default is "All").

    Returns
    -------
    pandas.DataFrame
        Columns "Speech Content", "Date" (formatted ``YYYY-MM-DD``) and "Party".
        For "mmr" there is an additional "Relevance" column: the FAISS score,
        rounded to 2 digits. Rows are sorted ascending by it (with the L2 index
        presumably a distance, i.e. closest first -- confirm). The columns are
        present even when no rows match.

    Raises
    ------
    ValueError
        If ``method`` is neither "ss" nor "mmr".
    """
    if method not in ("ss", "mmr"):
        raise ValueError(f"Unknown search method {method!r}; expected 'ss' or 'mmr'.")

    db = get_vectorstore(inputs=["All"], embeddings=embeddings)
    query_embedding = embeddings.embed_query(query)

    def _party_matches(party):
        # "All" disables the filter.
        return party_filter == 'All' or party == party_filter

    if method == "mmr":
        # Maximal Marginal Relevance: results come back as (document, score) pairs.
        results = db.max_marginal_relevance_search_with_score_by_vector(query_embedding, k=n)
        rows = [
            {
                'Speech Content': doc.page_content,
                'Date': _format_speech_date(doc.metadata["date"]),
                'Party': doc.metadata["party"],
                'Relevance': round(score, ndigits=2),
            }
            for doc, score in results
            if _party_matches(doc.metadata["party"])
        ]
        # Build the frame once; concatenating inside the loop was quadratic.
        df_res = pd.DataFrame(rows, columns=['Speech Content', 'Date', 'Party', 'Relevance'])
        df_res.sort_values('Relevance', inplace=True, ascending=True)
    else:
        # Similarity Search: results are plain documents without scores.
        results = db.similarity_search_by_vector(query_embedding, k=n)
        rows = [
            {
                'Speech Content': doc.page_content,
                'Date': _format_speech_date(doc.metadata["date"]),
                'Party': doc.metadata["party"],
            }
            for doc in results
            if _party_matches(doc.metadata["party"])
        ]
        df_res = pd.DataFrame(rows, columns=['Speech Content', 'Date', 'Party'])

    return df_res