from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_hub import HuggingFaceHub
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS


from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain

from langchain_community.docstore.in_memory import InMemoryDocstore
from faiss import IndexFlatL2

import pandas as pd

# Load environmental variables from .env-file
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
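# Note: the langchain_community HuggingFaceHub wrapper reads the
# HUGGINGFACEHUB_API_TOKEN environment variable, so the .env file is
# expected to define it.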

# Define shared resources: embedding model, LLM, and prompt templates
# ToDo: Remove the `embeddings` input parameter from the functions below?
embeddings = HuggingFaceEmbeddings(model_name="paraphrase-multilingual-MiniLM-L12-v2")
llm = HuggingFaceHub(
    # ToDo: Try different models here
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    # repo_id="CohereForAI/c4ai-command-r-v01", # too large 69gb
    # repo_id="CohereForAI/c4ai-command-r-v01-4bit", # too large 22gb
    # repo_id="meta-llama/Meta-Llama-3-8B", # too large 16 gb
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 512,
        "top_k": 30,
        "temperature": 0.1,
        "repetition_penalty": 1.03,
        }
)
# ToDo: Experiment with different templates
prompt_test = ChatPromptTemplate.from_template("""<s>[INST]
                    Instruction: Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:

                    Context: {context}

                    Question: {input}
                    [/INST]"""
)
prompt_de = ChatPromptTemplate.from_template("""Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:

        <context>
        {context}
        </context>

        Frage: {input}
        """
        # Returns the answer in German
)
prompt_en = ChatPromptTemplate.from_template("""Answer the following question in English and solely based on the provided context:

        <context>
        {context}
        </context>

        Question: {input}
        """
        # Returns the answer in English
)
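
# Illustrative check: all three templates expect the keys "context" and
# "input", e.g.
#   prompt_de.format_messages(context="<Redeauszug>", input="Worum ging es in der Debatte?")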

# Pre-load the combined vector store covering all speeches (index "speeches_1949_09_12")
db_all = FAISS.load_local(folder_path="./src/FAISS", index_name="speeches_1949_09_12",
                          embeddings=embeddings, allow_dangerous_deserialization=True)

def get_vectorstore(inputs, embeddings):
    """
    Combine multiple FAISS vector stores into a single vector store based on the specified inputs.

    Parameters
    ----------
    inputs : list of str
        A list of strings specifying which vector stores to combine. Each string represents
        a specific index or the special keyword "All". If "All" is the first entry in the
        list, the pre-defined vector store for all speeches is returned directly.
    embeddings : Embeddings
        An instance of embeddings used to load the vector stores.

    Returns
    -------
    FAISS
        A FAISS vector store that combines the specified indices into a single vector store.
    """

    # Default folder path
    folder_path = "./src/FAISS"

    # Fall back to the full corpus if nothing specific was selected
    if not inputs or inputs[0] in ("All", None):
        return db_all

    # Initialize an empty FAISS store with the embedding dimensionality
    embedding_function = embeddings
    dimensions = len(embedding_function.embed_query("dummy"))
    db = FAISS(
        embedding_function=embedding_function,
        index=IndexFlatL2(dimensions),
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
        normalize_L2=False
    )

    # Retrieve inputs, e.g. "20. Legislaturperiode", "19. Legislaturperiode", ...
    for selection in inputs:
        # Ignore "All" if the user selected it alongside specific legislatures
        if selection == "All":
            continue
        # Load the selected index and merge it into the combined vector store
        index = selection.split(".")[0]
        index_name = f'{index}_legislature'
        local_db = FAISS.load_local(folder_path=folder_path, index_name=index_name,
                                    embeddings=embeddings, allow_dangerous_deserialization=True)
        db.merge_from(local_db)
        print(f'Successfully merged index {index_name}')
    return db
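
# Usage sketch (illustrative; assumes indices such as "20_legislature" exist
# under ./src/FAISS, matching the naming scheme used in the loop above):
#   db = get_vectorstore(["20. Legislaturperiode"], embeddings=embeddings)
#   docs = db.similarity_search("Energiepolitik", k=3)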

def RAG(llm, prompt, db, question):
    """
    Apply Retrieval-Augmented Generation (RAG) by providing the retrieved context and the
    question to the language model using a predefined template.

    Parameters
    ----------
    llm : LanguageModel
        An instance of the language model to be used for generating responses.
    prompt : ChatPromptTemplate
        A predefined template that structures how the context and question are presented
        to the language model.
    db : VectorStore
        A vector store instance that supports retrieval of relevant documents based on the
        input question.
    question : str
        The question or query to be answered by the language model.

    Returns
    -------
    dict
        The response generated by the retrieval chain, based on the retrieved context and
        the provided question.
    """
    # Create a document chain using the provided language model and prompt template
    document_chain = create_stuff_documents_chain(llm=llm, prompt=prompt)
    # Convert the vector store into a retriever
    retriever = db.as_retriever()
    # Create a retrieval chain that integrates the retriever with the document chain
    retrieval_chain = create_retrieval_chain(retriever, document_chain)
    # Invoke the retrieval chain with the input question to get the final response
    response = retrieval_chain.invoke({"input": question})
    
    return response
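
# Usage sketch (illustrative): create_retrieval_chain returns a dict whose
# "answer" key holds the generated text:
#   response = RAG(llm=llm, prompt=prompt_de, db=db_all, question="Was wurde zur Inflation gesagt?")
#   print(response["answer"])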


def chatbot(message, history, db_inputs, prompt_language, llm=llm):
    """
    Generate a response from the chatbot based on the provided message, history, database
    inputs, prompt language, and LLM.

    Parameters
    ----------
    message : str
        The message or question to be answered by the chatbot.
    history : list
        The history of previous interactions or messages.
    db_inputs : list
        A list of strings specifying which vector stores to combine. Each string represents
        a specific index or the special keyword "All".
    prompt_language : str
        The language of the prompt used for generating the response. Should be either "DE"
        for German or "EN" for English.
    llm : LLM, optional
        An instance of the language model used for generating the response. Defaults to the
        global variable `llm`.

    Returns
    -------
    str
        The response generated by the chatbot.
    """
    
    db = get_vectorstore(inputs=db_inputs, embeddings=embeddings)

    # Select prompt based on user input
    if prompt_language == "DE":
        prompt = prompt_de
        answer_marker = "Antwort: "
    else:
        prompt = prompt_en
        answer_marker = "Answer: "

    raw_response = RAG(llm=llm, prompt=prompt, db=db, question=message)
    # Only necessary because Mistral echoes the prompt (including its input
    # content) in the output; keep only the text after the answer marker.
    try:
        response = raw_response['answer'].split(answer_marker)[1]
    except IndexError:
        response = raw_response['answer']

    return response
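
# Usage sketch (illustrative; the (message, history, ...) signature suggests a
# Gradio ChatInterface callback -- an assumption, since the UI code lives elsewhere):
#   chatbot("Was wurde zur Energiepolitik gesagt?", history=[],
#           db_inputs=["20. Legislaturperiode"], prompt_language="DE")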


def keyword_search(query, n=10, embeddings=embeddings, method="ss", party_filter="All"):
    """
    Retrieve speech contents based on keywords using a specified method.

    Parameters
    ----------
    query : str
        The keyword(s) to search for in the speech contents.
    n : int, optional
        The number of speech contents to retrieve (default is 10).
    embeddings : Embeddings, optional
        An instance of embeddings used for embedding the query (default is the global
        `embeddings`).
    method : str, optional
        The method used for retrieving speech contents. Options are 'ss' (semantic search)
        and 'mmr' (maximal marginal relevance) (default is 'ss').
    party_filter : str, optional
        A filter for retrieving speech contents by party affiliation. Specify 'All' to
        retrieve speeches from all parties (default is 'All').

    Returns
    -------
    pandas.DataFrame
        A DataFrame containing the speech contents, dates, and party affiliations.

    Notes
    -----
    The search always runs against the combined vector store of all speeches, which is
    loaded internally via `get_vectorstore(["All"], ...)`.
    """
    
    db = get_vectorstore(inputs=["All"], embeddings=embeddings)
    query_embedding = embeddings.embed_query(query)

    # Maximal Marginal Relevance
    if method == "mmr":
        df_res = pd.DataFrame(columns=['Speech Content', 'Date', 'Party', 'Relevance'])
        results = db.max_marginal_relevance_search_with_score_by_vector(query_embedding, k=n)
        for doc in results:
            party = doc[0].metadata["party"]
            if party != party_filter and party_filter != 'All':
                continue
            speech_content = doc[0].page_content
            speech_date = doc[0].metadata["date"]
            score = round(doc[1], ndigits=2)
            df_res = pd.concat([df_res, pd.DataFrame({'Speech Content': [speech_content],
                                                      'Date': [speech_date],
                                                      'Party': [party],
                                                      'Relevance': [score]})], ignore_index=True)
        df_res.sort_values('Relevance', inplace=True, ascending=True)

    # Similarity Search
    elif method == "ss":
        kws_data = [] 
        results = db.similarity_search_by_vector(query_embedding, k=n)
        for doc in results:
            party = doc.metadata["party"]
            if party != party_filter and party_filter != 'All':
                continue
            speech_content = doc.page_content
            speech_date = doc.metadata["date"]
            speech_date = speech_date.strftime("%Y-%m-%d")
            print(speech_date)
            # Error here?
            kws_entry = {'Speech Content': speech_content,
                        'Date': speech_date,
                        'Party': party}
            
            kws_data.append(kws_entry)
    
    df_res = pd.DataFrame(kws_data)

    return df_res
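

if __name__ == "__main__":
    # Minimal smoke test (illustrative; assumes the FAISS indices exist under
    # ./src/FAISS and a valid HUGGINGFACEHUB_API_TOKEN is set):
    print(keyword_search("Klimaschutz", n=3))
    print(chatbot("Was wurde zum Klimaschutz gesagt?", history=[],
                  db_inputs=["All"], prompt_language="DE"))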