BHO committed on
Commit e556db9
·
1 Parent(s): decbdf3

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +416 -0
app.py ADDED
@@ -0,0 +1,416 @@
+ import gradio as gr
+ from haystack.document_stores import FAISSDocumentStore
+ from haystack.nodes import EmbeddingRetriever
+ import openai
+ import pandas as pd
+ import os
+ from utils import (
+     make_pairs,
+     set_openai_api_key,
+     create_user_id,
+     to_completion,
+ )
+
+ from datetime import datetime
+
+ # from azure.storage.fileshare import ShareServiceClient
+
+ # Load variables from a local .env file when python-dotenv is available.
+ try:
+     from dotenv import load_dotenv
+
+     load_dotenv()
+ except ImportError:
+     pass
+
+ theme = gr.themes.Soft(
+     primary_hue="sky",
+     font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
+ )
+
+ init_prompt = "TKOQA, an AI Assistant for Tikehau. "
+ sources_prompt = (
+     "When relevant, use facts and numbers from the following documents in your answer. "
+ )
+
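+
+ # Few-shot prompt: rewrites the raw user message into a short standalone
+ # English question so the embedding retriever receives a clean search query.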
+ def get_reformulation_prompt(query: str) -> str:
+     return f"""Reformulate the following user message to be a short standalone question in English, in the context of the Universal Registration Document of Tikehau.
+ ---
+ query: what is the AUM of Tikehau in 2022?
+ standalone question: What is the AUM of Tikehau in 2022?
+ language: English
+ ---
+ query: what is T2?
+ standalone question: what is the transition energy fund at Tikehau?
+ language: English
+ ---
+ query: what is the business of Tikehau?
+ standalone question: What are the main business units of Tikehau?
+ language: English
+ ---
+ query: {query}
+ standalone question:"""
+
+
+ system_template = {
+     "role": "system",
+     "content": init_prompt,
+ }
+
+ # openai.api_type = "azure"
+ # The API key must come from the environment (a .env file locally, or a
+ # HuggingFace Space secret); it must never be hardcoded in the source.
+ openai.api_key = os.environ["OPENAI_API_KEY"]
+
+ # BHO
+ # openai.api_base = os.environ["ressource_endpoint"]
+ # openai.api_version = "2022-12-01"
+
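+ # Load the prebuilt FAISS index of Universal Registration Document passages;
+ # the .faiss file holds the vectors and the .json file the store configuration.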
+ ds = FAISSDocumentStore.load(index_path="./tko_urd.faiss", config_path="./tko_urd.json")
+
+ retriever = EmbeddingRetriever(
+     document_store=ds,
+     embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
+     model_format="sentence_transformers",
+     progress_bar=False,
+ )
+
+ # retrieve_giec = EmbeddingRetriever(
+ #     document_store=FAISSDocumentStore.load(
+ #         index_path="./documents/climate_gpt_v2_only_giec.faiss",
+ #         config_path="./documents/climate_gpt_v2_only_giec.json",
+ #     ),
+ #     embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
+ #     model_format="sentence_transformers",
+ # )
+
+ # BHO
+ # For Azure connection in secrets in HuggingFace
+ # credential = {
+ #     "account_key": os.environ["account_key"],
+ #     "account_name": os.environ["account_name"],
+ # }
+
+ # BHO
+ # account_url = os.environ["account_url"]
+ # file_share_name = "climategpt"
+ # service = ShareServiceClient(account_url=account_url, credential=credential)
+ # share_client = service.get_share_client(file_share_name)
+ user_id = create_user_id(10)
+
+
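+ # Filter retrieved passages by source and keep the top k; the summary vs.
+ # full-report split is currently disabled (kept below as commented-out lines).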
+ def filter_sources(df, k_summary=3, k_total=10, source="ipcc"):
+     assert source in ["ipcc", "ipbes", "all"]
+
+     # Filter by source
+     if source == "ipcc":
+         df = df.loc[df["source"] == "IPCC"]
+     elif source == "ipbes":
+         df = df.loc[df["source"] == "IPBES"]
+     else:
+         pass
+
+     # Prepare summaries
+     df_summaries = df
+     # Separate summaries and full reports
+     # df_summaries = df.loc[df["report_type"].isin(["SPM", "TS"])]
+     # df_full = df.loc[~df["report_type"].isin(["SPM", "TS"])]
+
+     # Find passages from summaries dataset
+     passages_summaries = df_summaries.head(k_summary)
+
+     # Find passages from full reports dataset
+     # passages_fullreports = df_full.head(k_total - len(passages_summaries))
+
+     # Concatenate passages
+     # passages = pd.concat([passages_summaries, passages_fullreports], axis=0, ignore_index=True)
+     passages = passages_summaries
+     return passages
+
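+
+ # Retrieve up to max_k candidate passages, drop those below the similarity
+ # threshold, then narrow the rest down with filter_sources().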
+ def retrieve_with_summaries(
+     query, retriever, k_summary=3, k_total=10, source="ipcc", max_k=100, threshold=0.555, as_dict=True
+ ):
+     assert max_k > k_total
+     docs = retriever.retrieve(query, top_k=max_k)
+     docs = [{**x.meta, "score": x.score, "content": x.content} for x in docs if x.score > threshold]
+     if len(docs) == 0:
+         return []
+     res = pd.DataFrame(docs)
+     passages_df = filter_sources(res, k_summary, k_total, source)
+     if as_dict:
+         contents = passages_df["content"].tolist()
+         meta = passages_df.drop(columns=["content"]).to_dict(orient="records")
+         passages = [{"content": c, "meta": m} for c, m in zip(contents, meta)]
+         return passages
+     else:
+         return passages_df
+
+
+ def make_html_source(source, i):
+     meta = source["meta"]
+     return f"""
+ <div class="card">
+     <div class="card-content">
+         <h2>Doc {i} - {meta['file_name']} - Page {meta['page_number']}</h2>
+         <p>{source['content']}</p>
+     </div>
+ </div>
+ """
+
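+
+ # Main chat handler: reformulate the question, retrieve supporting passages,
+ # then stream a completion that is asked to answer from those passages.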
+ def chat(
+     user_id: str,
+     query: str,
+     history: list = [system_template],
+     report_type: str = "All available",
+     threshold: float = 0.555,
+ ) -> tuple:
+     """Retrieve relevant documents from the document store, then query the completion model.
+
+     Args:
+         query (str): user message.
+         history (list, optional): history of the conversation. Defaults to [system_template].
+         report_type (str, optional): should be "All available" or "IPCC only". Defaults to "All available".
+         threshold (float, optional): similarity threshold; don't increase beyond 0.568. Defaults to 0.555.
+
+     Yields:
+         tuple: chat in gradio format, chat in openai format, sources used.
+     """
+
+     if report_type not in ["IPCC", "IPBES"]:
+         report_type = "all"
+     print(f"Searching in {report_type} reports")
+
+     reformulated_query = openai.Completion.create(
+         engine="text-davinci-003",
+         prompt=get_reformulation_prompt(query),
+         temperature=0,
+         max_tokens=128,
+         stop=["\n---\n", "<|im_end|>"],
+     )
+
+     # The completion is expected to contain the standalone question on the
+     # first line and "language: <name>" on the second (see the few-shot prompt).
+     reformulated_query = reformulated_query["choices"][0]["text"]
+     reformulated_query, language = reformulated_query.split("\n")
+     language = language.split(":")[1].strip()
+
+     sources = retrieve_with_summaries(
+         reformulated_query,
+         retriever,
+         k_total=10,
+         k_summary=3,
+         as_dict=True,
+         source=report_type.lower(),
+         threshold=threshold,
+     )
+     response_retriever = {
+         "language": language,
+         "reformulated_query": reformulated_query,
+         "query": query,
+         "sources": sources,
+     }
+
+     # docs = [d for d in retriever.retrieve(query=reformulated_query, top_k=10) if d.score > threshold]
+     messages = history + [{"role": "user", "content": query}]
+
+     if len(sources) > 0:
+         docs_string = []
+         docs_html = []
+         for i, d in enumerate(sources, 1):
+             # docs_string.append(f"📃 Doc {i}: {d['meta']['short_name']} page {d['meta']['page_number']}\n{d['content']}")
+             docs_string.append(f"📃 Doc {i}: {d['meta']['file_name']} page {d['meta']['page_number']}\n{d['content']}")
+             docs_html.append(make_html_source(d, i))
+         docs_string = "\n\n".join([f"Query used for retrieval:\n{reformulated_query}"] + docs_string)
+         docs_html = "\n\n".join([f"Query used for retrieval:\n{reformulated_query}"] + docs_html)
+         messages.append({"role": "system", "content": f"{sources_prompt}\n\n{docs_string}\n\nAnswer in {language}:"})
+
+         response = openai.Completion.create(
+             # engine="climateGPT",
+             engine="text-davinci-003",
+             prompt=to_completion(messages),
+             temperature=0,  # deterministic
+             stream=True,
+             max_tokens=1024,
+         )
+
+         complete_response = ""
+         # Drop the temporary system message holding the sources, then append a
+         # placeholder assistant message that the stream below fills in.
+         messages.pop()
+
+         messages.append({"role": "assistant", "content": complete_response})
+         timestamp = str(datetime.now().timestamp())
+         file = user_id[0] + timestamp + ".json"
+         logs = {
+             "user_id": user_id[0],
+             "prompt": query,
+             "retrieved": sources,
+             "report_type": report_type,
+             "prompt_eng": messages[0],
+             "answer": messages[-1]["content"],
+             "time": timestamp,
+         }
+         # log_on_azure(file, logs, share_client)
+         print(logs)
+
+         for chunk in response:
+             if (chunk_message := chunk["choices"][0].get("text")) and chunk_message != "<|im_end|>":
+                 complete_response += chunk_message
+                 messages[-1]["content"] = complete_response
+                 gradio_format = make_pairs([a["content"] for a in messages[1:]])
+                 yield gradio_format, messages, docs_html
+
+     else:
+         docs_string = "⚠️ No relevant passages found in the URDs"
+         complete_response = "**⚠️ No relevant passages found in the URDs**"
+         messages.append({"role": "assistant", "content": complete_response})
+         gradio_format = make_pairs([a["content"] for a in messages[1:]])
+         yield gradio_format, messages, docs_string
+
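+
+ # Log free-text user feedback (the Azure upload is disabled; logs go to stdout).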
+ def save_feedback(feed: str, user_id):
+     if len(feed) > 1:
+         timestamp = str(datetime.now().timestamp())
+         file = user_id[0] + timestamp + ".json"
+         logs = {
+             "user_id": user_id[0],
+             "feedback": feed,
+             "time": timestamp,
+         }
+         # log_on_azure(file, logs, share_client)
+         print(logs)
+         return "Feedback submitted, thank you!"
+
+
+ def reset_textbox():
+     return gr.update(value="")
+
+
+ # def log_on_azure(file, logs, share_client):
+ #     file_client = share_client.get_file_client(file)
+ #     file_client.upload_file(str(logs))
+
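+
+ # Gradio UI: a chatbot column on the left, retrieved sources on the right,
+ # followed by usage notes and a carbon-footprint section.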
+ with gr.Blocks(title="TKO URD Q&A", css="style.css", theme=theme) as demo:
+     user_id_state = gr.State([user_id])
+
+     # Gradio
+     gr.Markdown("<h1><center>Tikehau Capital Q&A</center></h1>")
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             chatbot = gr.Chatbot(elem_id="chatbot", label="Tikehau Capital Q&A chatbot", show_label=False)
+             state = gr.State([system_template])
+
+             with gr.Row():
+                 ask = gr.Textbox(
+                     show_label=True,
+                     placeholder="Ask your Tikehau-related question here and press Enter",
+                 ).style(container=False)
+                 # ask_examples_hidden = gr.Textbox(elem_id="hidden-message")
+
+             # examples_questions = gr.Examples(
+             #     [
+             #         "What is the AUM of Tikehau in 2022?",
+             #     ],
+             #     [ask_examples_hidden],
+             #     examples_per_page=15,
+             # )
+
+         with gr.Column(scale=1, variant="panel"):
+             gr.Markdown("### Sources")
+             sources_textbox = gr.Markdown(show_label=False)
+
+     # dropdown_sources = gr.inputs.Dropdown(
+     #     ["IPCC", "IPBES", "ALL"],
+     #     default="ALL",
+     #     label="Select reports",
+     # )
+     dropdown_sources = gr.State(["All"])
+
+     ask.submit(
+         fn=chat,
+         inputs=[
+             user_id_state,
+             ask,
+             state,
+             dropdown_sources,
+         ],
+         outputs=[chatbot, state, sources_textbox],
+     )
+     ask.submit(reset_textbox, [], [ask])
+
+     # ask_examples_hidden.change(
+     #     fn=chat,
+     #     inputs=[
+     #         user_id_state,
+     #         ask_examples_hidden,
+     #         state,
+     #         dropdown_sources
+     #     ],
+     #     outputs=[chatbot, state, sources_textbox],
+     # )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown(
+                 """
+     <div class="warning-box">
+     Version 0.1-beta - This tool is under active development
+     </div>
+     """
+             )
+
+         with gr.Column(scale=1):
+             gr.Markdown("*Source: Tikehau Universal Registration Documents*")
+
+     gr.Markdown("## How to use TKO URD Q&A")
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown(
+                 """
+     ### 💪 Getting started
+     - In the chatbot section, simply type your Tikehau-related question; answers are provided with references to the relevant URDs.
+     """
+             )
+         with gr.Column(scale=1):
+             gr.Markdown(
+                 """
+     ### ⚠️ Limitations
+     <div class="warning-box">
+     <ul>
+     <li>Please note that, like any AI, the model may occasionally generate an inaccurate or imprecise answer.</li>
+     </ul>
+     </div>
+     """
+             )
+
+     gr.Markdown("## 🙏 Feedback and feature requests")
+     gr.Markdown(
+         """
+     ### Beta test
+     - Feedback welcome.
+     """
+     )
+
+     gr.Markdown(
+         """
+     ## 🛢️ Carbon Footprint
+
+     Carbon emissions were measured during the development and inference process using CodeCarbon: [https://github.com/mlco2/codecarbon](https://github.com/mlco2/codecarbon)
+
+     | Phase | Description | Emissions | Source |
+     | --- | --- | --- | --- |
+     | Inference | API call to turbo-GPT | ~0.38 gCO2e / call | https://medium.com/@chrispointon/the-carbon-footprint-of-chatgpt-e1bc14e4cc2a |
+
+     Carbon emissions are **relatively low but not negligible** compared to other usages: one question asked to ClimateQ&A is around 0.482 gCO2e, equivalent to 2.2 m by car (https://datagir.ademe.fr/apps/impact-co2/), or around 2 to 4 times more than a typical Google search.
+     """
+     )
+
+ # Queue requests so multiple generations can stream concurrently.
+ demo.queue(concurrency_count=16)
+
+ demo.launch()