from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_hub import HuggingFaceHub
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS


from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain

from langchain_community.docstore.in_memory import InMemoryDocstore
from faiss import IndexFlatL2

import pandas as pd

# Load environmental variables from .env-file
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
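
# Note: HuggingFaceHub reads the HUGGINGFACEHUB_API_TOKEN environment variable,
# so the .env file is expected to provide it, e.g. (hypothetical token):
#   HUGGINGFACEHUB_API_TOKEN=hf_...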

# Define global resources: embedding model and language model
embeddings = HuggingFaceEmbeddings(model_name="paraphrase-multilingual-MiniLM-L12-v2") # ToDo: Remove embedding input parameter from functions?
llm = HuggingFaceHub(
    # ToDo: Try different models here
    # repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    repo_id = "mistralai/Ministral-8B-Instruct-2410",
    #repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    # repo_id="CohereForAI/c4ai-command-r-v01", # too large 69gb
    # repo_id="CohereForAI/c4ai-command-r-v01-4bit", # too large 22gb
    # repo_id="meta-llama/Meta-Llama-3-8B", # too large 16 gb
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 512,
        "top_k": 30,
        "temperature": 0.1,
        "repetition_penalty": 1.03,
        }
)
# ToDo: Experiment with different templates
prompt_test = ChatPromptTemplate.from_template("""<s>[INST]
                    Instruction: Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:

                    Context: {context}

                    Question: {input}
                    [/INST]"""
)
prompt_de = ChatPromptTemplate.from_template("""Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:

        <context>
        {context}
        </context>

        Frage: {input}
        """
        # Returns the answer in German
)
prompt_en = ChatPromptTemplate.from_template("""Answer the following question in English and solely based on the provided context:

        <context>
        {context}
        </context>

        Question: {input}
        """
        # Returns the answer in English
)

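# Pre-load the vector store covering all speeches (used whenever "All" is selected)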
db_all = FAISS.load_local(folder_path="./src/FAISS", index_name="speeches_1949_09_12",
                                           embeddings=embeddings, allow_dangerous_deserialization=True)

def get_vectorstore(inputs, embeddings):
    """
    Combine multiple FAISS vector stores into a single vector store based on the specified inputs.

    Parameters
    ----------
    inputs : list of str
        A list of strings specifying which vector stores to combine. Each string represents a specific
        index or the special keyword "All". If "All" is the first entry in the list, the function
        directly returns the pre-defined vector store covering all speeches.

    embeddings : Embeddings
        The embeddings instance used to load the vector stores and to embed queries.

    Returns
    -------
    FAISS
        A FAISS vector store that combines the specified indices into a single vector store.
    
    """

    # Default folder path
    folder_path = "./src/FAISS"

    # Fall back to the pre-loaded store covering all speeches
    if not inputs or inputs[0] in ("All", None):
        return db_all

    # Initialize an empty FAISS store; derive the index dimension from a dummy query
    embedding_function = embeddings
    dimensions = len(embedding_function.embed_query("dummy"))

    db = FAISS(
        embedding_function=embedding_function,
        index=IndexFlatL2(dimensions),
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
        normalize_L2=False
    )

    # Inputs look like: "20. Legislaturperiode", "19. Legislaturperiode", ...
    for selection in inputs:
        # Ignore "All" if the user selected it alongside specific legislatures
        if selection == "All":
            continue
        # Derive the index name from the legislature number and merge the stores
        index = selection.split(".")[0]
        index_name = f'{index}_legislature'
        local_db = FAISS.load_local(folder_path=folder_path, index_name=index_name,
                                    embeddings=embeddings, allow_dangerous_deserialization=True)
        db.merge_from(local_db)
        print(f'Successfully merged {index_name}')
    return db
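
# Usage sketch (hedged): assuming FAISS indexes such as "20_legislature" exist
# under ./src/FAISS, two legislative periods could be combined like this:
#
#   db = get_vectorstore(inputs=["20. Legislaturperiode", "19. Legislaturperiode"],
#                        embeddings=embeddings)
#   docs = db.similarity_search("Klimapolitik", k=3)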


def RAG(llm, prompt, db, question):
    """
    Apply Retrieval-Augmented Generation (RAG) by providing the context and the question to the 
    language model using a predefined template.

    Parameters:
    ----------
    llm : LanguageModel
        An instance of the language model to be used for generating responses.
        
    prompt : str
        A predefined template or prompt that structures how the context and question are presented to the language model.
        
    db : VectorStore
        A vector store instance that supports retrieval of relevant documents based on the input question.
        
    question : str
        The question or query to be answered by the language model.

    Returns:
    -------
    str
        The response generated by the language model, based on the retrieved context and provided question.
    """
    # Create a document chain using the provided language model and prompt template
    document_chain = create_stuff_documents_chain(llm=llm, prompt=prompt)
    # Convert the vector store into a retriever
    retriever = db.as_retriever()
    # Create a retrieval chain that integrates the retriever with the document chain
    retrieval_chain = create_retrieval_chain(retriever, document_chain)
    # Invoke the retrieval chain with the input question to get the final response
    response = retrieval_chain.invoke({"input": question})
    
    return response
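
# Note: create_retrieval_chain returns a dict with the generated text under
# "answer" and the retrieved documents under "context". A hedged sketch:
#
#   response = RAG(llm=llm, prompt=prompt_de, db=db_all,
#                  question="Was wurde zur Energiepolitik gesagt?")
#   print(response["answer"])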


def chatbot(message, history, db_inputs, prompt_language, llm=llm):
    """
    Generate a response from the chatbot based on the provided message, history, database inputs, prompt language, and LLM model.

    Parameters:
    -----------
    message : str
        The message or question to be answered by the chatbot.
        
    history : list
        The history of previous interactions or messages.
        
    db_inputs : list
        A list of strings specifying which vector stores to combine. Each string represents a specific index or a special keyword "All".
        
    prompt_language : str
        The language of the prompt to be used for generating the response. Should be either "DE" for German or "EN" for English.
        
    llm : LLM, optional
        An instance of the Language Model to be used for generating the response. Defaults to the global variable `llm`.

    Returns:
    --------
    str
        The response generated by the chatbot.
    """
    
    db = get_vectorstore(inputs=db_inputs, embeddings=embeddings)

    # Select prompt and answer marker based on the requested language
    if prompt_language == "DE":
        prompt, marker = prompt_de, "Antwort: "
    else:
        prompt, marker = prompt_en, "Answer: "

    raw_response = RAG(llm=llm, prompt=prompt, db=db, question=message)
    # Only necessary because Mistral echoes the prompt (including instruction and
    # context) in its output; strip everything up to the answer marker if present
    try:
        response = raw_response['answer'].split(marker)[1]
    except IndexError:
        response = raw_response['answer']

    return response
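
# Usage sketch (hedged; Gradio-style chat signature assumed, history is unused):
#
#   answer = chatbot(message="Was wurde zur Inflation gesagt?", history=[],
#                    db_inputs=["20. Legislaturperiode"], prompt_language="DE")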


def keyword_search(query, n=10, embeddings=embeddings, method="ss", party_filter="All"):
    """
    Retrieve speech contents based on keywords using a specified method.

    Parameters:
    ----------
    query : str
        The keyword(s) to search for in the speech contents.

    n : int, optional
        The number of speech contents to retrieve (default is 10).

    embeddings : Embeddings, optional
        The embeddings instance used to embed the query (defaults to the module-level `embeddings`).

    method : str, optional
        The method used for retrieving speech contents. Options are 'ss' (semantic search) and 'mmr'
        (maximal marginal relevance) (default is 'ss').

    party_filter : str, optional
        A filter for retrieving speech contents by party affiliation. Specify 'All' to retrieve
        speeches from all parties (default is 'All').

    Returns:
    -------
    pandas.DataFrame
        A DataFrame containing the speech contents, dates, and party affiliations
        (plus a relevance score for 'mmr').

    Notes:
    -----
    The FAISS vector store covering all speeches is loaded internally via
    `get_vectorstore(inputs=["All"], ...)`; there is no `db` parameter.
    """
    
    db = get_vectorstore(inputs=["All"], embeddings=embeddings)
    query_embedding = embeddings.embed_query(query)

    # Maximal Marginal Relevance
    if method == "mmr":
        rows = []
        results = db.max_marginal_relevance_search_with_score_by_vector(query_embedding, k=n)
        for doc, score in results:
            party = doc.metadata["party"]
            if party != party_filter and party_filter != 'All':
                continue
            rows.append({'Speech Content': doc.page_content,
                         'Date': doc.metadata["date"],
                         'Party': party,
                         'Relevance': round(score, ndigits=2)})
        df_res = pd.DataFrame(rows, columns=['Speech Content', 'Date', 'Party', 'Relevance'])
        # Lower L2 distance means more relevant, hence the ascending sort
        df_res.sort_values('Relevance', inplace=True, ascending=True)

    # Similarity Search
    elif method == "ss":
        rows = []
        results = db.similarity_search_by_vector(query_embedding, k=n)
        for doc in results:
            party = doc.metadata["party"]
            if party != party_filter and party_filter != 'All':
                continue
            rows.append({'Speech Content': doc.page_content,
                         'Date': doc.metadata["date"].strftime("%Y-%m-%d"),
                         'Party': party})
        # Build the DataFrame inside this branch so it does not clobber the MMR result
        df_res = pd.DataFrame(rows, columns=['Speech Content', 'Date', 'Party'])

    else:
        raise ValueError(f"Unknown method: {method!r}; expected 'ss' or 'mmr'")

    return df_res
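

# Minimal smoke test (hedged): only meaningful if the FAISS indexes exist under
# ./src/FAISS and a valid HUGGINGFACEHUB_API_TOKEN is set in the environment.
if __name__ == "__main__":
    df = keyword_search(query="Energiewende", n=5, method="ss")
    print(df.head())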