jayebaku committed on
Commit
780faa4
verified
1 Parent(s): f91dfc4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +404 -0
app.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import pandas as pd
4
+
5
+ from classifier import classify
6
+ from statistics import mean
7
+ from qa_summary import generate_answer
8
+
9
+
10
# Hugging Face access token, read from the environment at import time.
# A missing HF_TOKEN variable raises KeyError here. The token is forwarded
# to classify() by the classification callbacks below.
HFTOKEN = os.environ["HF_TOKEN"]



# JavaScript handed to demo.load(..., js=js). It injects Twitter's widgets.js
# into the page and defines a global reloadTwitterWidgets() helper that asks
# the widgets library to (re)scan the DOM for tweet blockquotes.
js = """
async () => {
// Load Twitter Widgets script
const script = document.createElement("script");
script.onload = () => console.log("Twitter Widgets.js loaded");
script.src = "https://platform.twitter.com/widgets.js";
document.head.appendChild(script);

// Define a global function to reload Twitter widgets
globalThis.reloadTwitterWidgets = () => {
if (window.twttr && twttr.widgets) {
twttr.widgets.load();
}
};
}
"""
30
+
31
def T_on_select(evt: gr.SelectData):
    """Build the HTML shown beside the results table for the selected cell.

    Args:
        evt: Gradio selection event. ``evt.index[1]`` is the column index of
            the selected cell and ``evt.value`` is its content.

    Returns:
        A ``gr.update`` carrying the new HTML: a Twitter embed blockquote when
        the cell is in column 3 (the ``tweet_id`` column of the results
        table), otherwise the cell content rendered as a heading.
    """
    # Function-scope import: the file's import block does not provide html.
    from html import escape

    # Cell content comes from a user-uploaded file, so it must be escaped
    # before being interpolated into markup (both attribute and text context).
    cell = escape(str(evt.value), quote=True)

    if evt.index[1] == 3:
        markup = (
            """<blockquote class="twitter-tweet" data-dnt="true" data-theme="dark">"""
            f"""\n<a href="https://twitter.com/anyuser/status/{cell}"></a></blockquote>"""
        )
    else:
        markup = f"""<h2>{cell}</h2>"""
    return gr.update(value=markup)
39
+
40
def single_classification(text, event_model, threshold):
    """Classify a single text and return its label and score.

    Args:
        text: The text to classify.
        event_model: Identifier of the classification model to use.
        threshold: Prediction threshold forwarded to ``classify``.

    Returns:
        A ``(event, score)`` tuple taken from the ``classify`` result.
    """
    prediction = classify(text, event_model, HFTOKEN, threshold)
    label = prediction["event"]
    score = prediction["score"]
    return label, score
43
+
44
def load_and_classify_csv(file, text_field, event_model, threshold):
    """Load an uploaded CSV/TSV file, classify every post and group the results.

    Args:
        file: Uploaded file object from ``gr.File``; its ``name`` attribute is
            used as the path to read.
        text_field: Name of the column holding the post texts.
        event_model: Identifier of the classification model to use.
        threshold: Prediction threshold forwarded to ``classify``.

    Returns:
        An 8-tuple matching the outputs wired to the "Start Prediction"
        button of the evaluation tab: the flood, fire and none checkbox
        groups, the mean model score, the number of posts, the annotated
        dataframe, and two updates that enable the QA buttons.

    Raises:
        gr.Error: If no file was uploaded, the text column is missing, or the
            file has no rows.
    """
    if file is None:
        raise gr.Error("Please upload a CSV or TSV file first.")

    filepath = file.name
    # Choose the parser from the actual extension; a substring test would
    # also match ".csv" appearing elsewhere in the path and miss ".CSV".
    if os.path.splitext(filepath)[1].lower() == ".csv":
        df = pd.read_csv(filepath)
    else:
        df = pd.read_table(filepath)

    if text_field not in df.columns:
        raise gr.Error(f"Text column '{text_field}' not found in the uploaded file.")

    posts = df[text_field].to_list()
    if not posts:
        # statistics.mean() raises on an empty sequence; report it cleanly.
        raise gr.Error("The uploaded file contains no rows to classify.")

    labels, scores = [], []
    for post in posts:
        res = classify(post, event_model, HFTOKEN, threshold)
        labels.append(res["event"])
        scores.append(res["score"])

    df["model_label"] = labels
    df["model_score"] = scores

    model_confidence = mean(scores)

    def _group(label):
        # Checkbox group listing the posts that received the given label.
        return gr.CheckboxGroup(choices=df[df["model_label"] == label][text_field].to_list())

    fire_related = _group("fire")
    flood_related = _group("flood")
    not_related = _group("none")

    return (
        flood_related,
        fire_related,
        not_related,
        model_confidence,
        len(posts),
        df,
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
70
+
71
def load_and_classify_csv_dataframe(file, text_field, event_model, threshold): #, filter
    """Load an uploaded CSV/TSV file, classify every post and build the results table.

    Args:
        file: Uploaded file object from ``gr.File``; its ``name`` attribute is
            used as the path to read.
        text_field: Name of the column holding the post texts.
        event_model: Identifier of the classification model to use.
        threshold: Prediction threshold forwarded to ``classify``.

    Returns:
        A 3-tuple: an update setting the results table, the results dataframe
        itself (kept as session state for later filtering), and an update
        that populates and reveals the label-filter dropdown.

    Raises:
        gr.Error: If no file was uploaded or a required column is missing.
    """
    if file is None:
        raise gr.Error("Please upload a CSV or TSV file first.")

    filepath = file.name
    # Choose the parser from the actual extension; a substring test would
    # also match ".csv" appearing elsewhere in the path and miss ".CSV".
    if os.path.splitext(filepath)[1].lower() == ".csv":
        df = pd.read_csv(filepath)
    else:
        df = pd.read_table(filepath)

    if text_field not in df.columns:
        raise gr.Error(f"Text column '{text_field}' not found in the uploaded file.")
    if "tweet_id" not in df.columns:
        # The results table and the tweet embed both rely on this column.
        raise gr.Error("The uploaded file must contain a 'tweet_id' column.")

    labels, scores = [], []
    for post in df[text_field].to_list():
        res = classify(post, event_model, HFTOKEN, threshold)
        labels.append(res["event"])
        scores.append(round(res["score"], 5))

    df["event_label"] = labels
    df["model_score"] = scores

    result_df = df[[text_field, "event_label", "model_score", "tweet_id"]].copy()
    # Tweet ids are shown and embedded as text, not as numbers.
    result_df["tweet_id"] = result_df["tweet_id"].astype(str)

    filters = list(result_df["event_label"].unique())
    extra_filters = ['Not-' + x for x in filters] + ['All']

    return (
        gr.update(value=result_df),
        result_df,
        gr.update(
            choices=sorted(filters + extra_filters),
            value='All',
            label="Filter data by label",
            visible=True,
        ),
    )
101
+
102
+
103
def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
    """Compute accuracy from the posts the user flagged as misclassified.

    Args:
        flood_selections: Posts flagged as incorrect in the flood group.
        fire_selections: Posts flagged as incorrect in the fire group.
        none_selections: Posts flagged as incorrect in the none group.
        num_posts: Total number of classified posts.
        text_field: Name of the column holding the post texts.
        data_df: Dataframe produced by the prediction step.

    Returns:
        A 5-tuple: number of incorrect predictions, number of correct
        predictions, accuracy in percent, the dataframe with an added
        ``model_eval`` column, and a visible download button pointing at the
        ``output.csv`` file written by this call.

    Raises:
        gr.Error: If no prediction has been run yet.
    """
    # Before a prediction the hidden inputs are empty, which would otherwise
    # surface as a TypeError, KeyError or ZeroDivisionError.
    if not num_posts or data_df is None or text_field not in data_df.columns:
        raise gr.Error("Run a prediction before calculating accuracy.")

    selections = (flood_selections or []) + (fire_selections or []) + (none_selections or [])
    flagged = set(selections)  # O(1) membership tests in the loop below

    data_df["model_eval"] = [
        "incorrect" if post in flagged else "correct"
        for post in data_df[text_field].to_list()
    ]

    incorrect = len(selections)
    correct = num_posts - incorrect
    accuracy = (correct / num_posts) * 100

    # NOTE(review): fixed file name in the working directory; concurrent
    # sessions would presumably overwrite each other's export - confirm.
    data_df.to_csv("output.csv")
    return incorrect, correct, accuracy, data_df, gr.DownloadButton(label="Download CSV", value="output.csv", visible=True)
120
+
121
# Predefined questions offered in the QA tab. Every entry must be followed by
# a comma: adjacent string literals are silently concatenated by Python.
DEFAULT_QUERIES = (
    "What areas are being evacuated?",
    "What areas are predicted to be impacted?",
    "What areas are without power?",
    "What barriers are hindering response efforts?",
    "What events have been canceled?",
    "What preparations are being made?",
    "What regions have announced a state of emergency?",
    "What roads are blocked / closed?",
    "What services have been closed?",
    "What warnings are currently in effect?",
    "Where are emergency services deployed?",
    "Where are emergency services needed?",
    "Where are evacuations needed?",
    "Where are people needing rescued?",
    "Where are recovery efforts taking place?",
    "Where has building or infrastructure damage occurred?",
    "Where has flooding occured?",
    "Where are volunteers being requested?",
    "Where has road damage occured?",
    "What area has the wildfire burned?",
    "Where have homes been damaged or destroyed?",
)


def init_queries(history):
    """Populate the query checkbox group, seeding it with the default queries.

    Args:
        history: Current list of queries held in session state, or ``None`` /
            empty when nothing has been stored yet.

    Returns:
        A ``(checkbox_group, history)`` tuple: a ``gr.CheckboxGroup`` whose
        choices are the queries, and the query list to store back in state.
    """
    if not history:
        # Fresh list per call so later additions never touch the shared tuple.
        history = list(DEFAULT_QUERIES)

    return gr.CheckboxGroup(choices=history), history
148
+
149
def add_query(to_add, history):
    """Append a custom query to the query list if it is new and non-empty.

    Args:
        to_add: Text of the query entered by the user.
        history: Current list of queries held in session state; may be
            ``None`` when the state has not been initialised yet.

    Returns:
        A ``(checkbox_group, history)`` tuple: a ``gr.CheckboxGroup`` whose
        choices are the updated queries, and the updated query list.
    """
    # Work on a copy so a None state or a shared list is never mutated.
    queries = list(history or [])
    query = (to_add or "").strip()
    # Blank submissions would otherwise show up as an empty checkbox.
    if query and query not in queries:
        queries.append(query)
    return gr.CheckboxGroup(choices=queries), queries
153
+
154
def qa_summarise(selected_queries, qa_llm_model, text_field, data_df):
    """Summarise answers to the selected queries over the relevant posts.

    Args:
        selected_queries: Queries ticked by the user.
        qa_llm_model: Name of the QA model selected in the UI.
        text_field: Name of the column holding the post texts.
        data_df: Classified dataframe; rows labelled ``"none"`` are excluded.

    Returns:
        A ``(summary, doc_df)`` tuple: the value returned by
        ``generate_answer`` and a dataframe listing the posts that were used,
        numbered from 1.

    Raises:
        gr.Error: If no query is selected or the data has not been classified.
    """
    if not selected_queries:
        # Indexing the first query below would otherwise raise IndexError.
        raise gr.Error("Select at least one query before starting QA.")
    if data_df is None or "model_label" not in data_df.columns or text_field not in data_df.columns:
        raise gr.Error("Run the relevance classification before starting QA.")

    qa_input_df = data_df[data_df["model_label"] != "none"].reset_index()
    texts = qa_input_df[text_field].to_list()

    summary = generate_answer(qa_llm_model, texts, selected_queries[0], selected_queries, mode="multi_summarize")

    doc_df = pd.DataFrame()
    doc_df["number"] = list(range(1, len(texts) + 1))
    doc_df["text"] = texts

    return summary, doc_df
166
+
167
+
168
# Gradio UI: four tabs (table-based classification, classification with
# manual evaluation, question answering, single-text classification) followed
# by the event wiring that connects the widgets to the callbacks above.
with gr.Blocks(fill_width=True) as demo:

    # Inject the Twitter widgets script once the page has loaded.
    demo.load(None, None, None, js=js)

    event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier",
                    "jayebaku/distilbert-base-multilingual-cased-weather-classifier-2",
                    "jayebaku/twitter-xlm-roberta-base-crexdata-relevance-classifier",
                    "jayebaku/twhin-bert-base-crexdata-relevance-classifier"]

    # Help text shared by the three prediction-threshold sliders.
    threshold_info = ("This value sets a threshold by which texts classified flood or fire are accepted, "
                      "higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)")

    # Unfiltered results of the last table-based prediction (per session).
    T_data_ss_state = gr.State(value=pd.DataFrame())


    with gr.Tab("Event Type Classification"):
        gr.Markdown(
            """
            # T4.5 Relevance Classifier Demo
            This is a demo created to explore floods and wildfire classification in social media posts.\n
            Usage:\n
            - Upload .tsv or .csv data file (must contain a text column with social media posts).\n
            - Next, type the name of the text column.\n
            - Then, choose a BERT classifier model from the drop down.\n
            - Finally, click the 'start prediction' button.\n
            """)
        with gr.Row():
            with gr.Column(scale=4):
                T_file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])

            with gr.Column(scale=6):
                T_text_field = gr.Textbox(label="Text field name", value="tweet_text")
                T_event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
                T_predict_button = gr.Button("Start Prediction")
                with gr.Accordion("Prediction threshold", open=False):
                    T_threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False,
                                            info=threshold_info, interactive=True)

        with gr.Row():
            with gr.Column(scale=8):
                T_data = gr.DataFrame(wrap=True,
                                      show_fullscreen_button=True,
                                      show_copy_button=True,
                                      show_row_numbers=True,
                                      show_search="filter",
                                      column_widths=["49%", "17%", "17%", "17%"])

            with gr.Column(scale=2):
                T_data_filter = gr.Dropdown(visible=False)
                T_tweet_embed = gr.HTML("<h1>Select a Tweet ID to view Tweet</h1>")



    with gr.Tab("Event Type Classification Eval"):
        gr.Markdown(
            """
            # T4.5 Relevance Classifier Demo
            This is a demo created to explore floods and wildfire classification in social media posts.\n
            Usage:\n
            - Upload .tsv or .csv data file (must contain a text column with social media posts).\n
            - Next, type the name of the text column.\n
            - Then, choose a BERT classifier model from the drop down.\n
            - Finally, click the 'start prediction' button.\n
            Evaluation:\n
            - To evaluate the model's accuracy select the INCORRECT classifications using the checkboxes in front of each post.\n
            - Then, click on the 'Calculate Accuracy' button.\n
            - Then, click on the 'Download data as CSV' to get the classifications and evaluation data as a .csv file.
            """)
        with gr.Row():
            with gr.Column(scale=4):
                file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])

            with gr.Column(scale=6):
                text_field = gr.Textbox(label="Text field name", value="tweet_text")
                event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
                ETCE_predict_button = gr.Button("Start Prediction")
                with gr.Accordion("Prediction threshold", open=False):
                    threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False,
                                          info=threshold_info, interactive=True)

        with gr.Row(): # XXX confirm this is not a problem later --equal_height=True
            with gr.Column():
                gr.Markdown("""### Flood-related""")
                flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

            with gr.Column():
                gr.Markdown("""### Fire-related""")
                fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

            with gr.Column():
                gr.Markdown("""### None""")
                none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

        with gr.Row():
            with gr.Column(scale=5):
                gr.Markdown(r"""
                Accuracy: is the model's ability to make correct predictions.
                It is the fraction of correct prediction out of the total predictions.

                $$
                \text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} * 100
                $$

                Model Confidence: is the mean probability of each case
                belonging to their assigned classes. A value of 1 is best.
                """, latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }])
                gr.Markdown("\n\n\n")
                model_confidence = gr.Number(label="Model Confidence")

            with gr.Column(scale=5):
                correct = gr.Number(label="Number of correct classifications")
                incorrect = gr.Number(label="Number of incorrect classifications")
                accuracy = gr.Number(label="Model Accuracy (%)")

        ETCE_accuracy_button = gr.Button("Calculate Accuracy")
        download_csv = gr.DownloadButton(visible=False)
        # Hidden components used to carry data between callbacks.
        num_posts = gr.Number(visible=False)
        data = gr.DataFrame(visible=False)
        data_eval = gr.DataFrame(visible=False)


    qa_tab = gr.Tab("Question Answering")
    with qa_tab:
        gr.Markdown(
            """
            # Question Answering Demo
            This section uses RAG to answer questions about the relevant social media posts identified by the relevance classifier\n
            Usage:\n
            - Select queries from predefined\n
            - Parameters for QA can be edited in sidebar\n

            Note: QA process is disabled until after the relevance classification is done
            """)

        with gr.Accordion("Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    qa_llm_model = gr.Dropdown(["mistral", "solar", "phi3mini"], label="QA model", value="phi3mini", interactive=True)
                    aggregator = gr.Dropdown(["linear", "outrank"], label="Aggregation method", value="linear", interactive=True)
                with gr.Column():
                    batch_size = gr.Slider(50, 500, value=150, step=1, label="Batch size", info="Choose between 50 and 500", interactive=True)
                    topk = gr.Slider(1, 10, value=5, step=1, label="Number of results to retrieve", info="Choose between 1 and 10", interactive=True)

        selected_queries = gr.CheckboxGroup(label="Select at least one query using the checkboxes", interactive=True)
        queries_state = gr.State()
        # Fill the query list the first time the tab is opened.
        qa_tab.select(init_queries, inputs=queries_state, outputs=[selected_queries, queries_state])

        query_inp = gr.Textbox(label="Add custom queries like the one above, one at a time")
        # Both buttons stay disabled until the evaluation-tab prediction has run.
        QA_addqry_button = gr.Button("Add to queries", interactive=False)
        QA_run_button = gr.Button("Start QA", interactive=False)
        hsummary = gr.Textbox(label="Summary")

        qa_df = gr.DataFrame()


    with gr.Tab("Single Text Classification"):
        gr.Markdown(
            """
            # Event Type Prediction Demo
            In this section you test the relevance classifier with written texts.\n
            Usage:\n
            - Type a tweet-like text in the textbox.\n
            - Then press Enter.\n
            """)
        with gr.Row():
            with gr.Column(scale=3):
                model_sing_classify = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
            with gr.Column(scale=7):
                threshold_sing_classify = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold",
                                                    info=threshold_info, interactive=True)

        text_to_classify = gr.Textbox(label="Text", info="Enter tweet-like text", submit_btn=True)
        text_to_classify_examples = gr.Examples([["The streets are flooded, I can't leave #BostonStorm"],
                                                 ["Controlado el incendio de Rodezno que ha obligado a desalojar a varias bodegas de la zona."],
                                                 ["Cambrils:estació Renfe inundada 19 persones dins d'un tren. FGC a Capellades, petit descarrilament 5 passatgers #Inuncat @emergenciescat"],
                                                 ["Anscheinend steht die komplette Neckarwiese unter Wasser! #Hochwasser"]], text_to_classify)

        with gr.Row():
            with gr.Column():
                classification = gr.Textbox(label="Classification")
            with gr.Column():
                classification_score = gr.Number(label="Classification Score")


    # Test event listeners
    T_predict_button.click(
        load_and_classify_csv_dataframe,
        inputs=[T_file_input, T_text_field, T_event_model, T_threshold],
        outputs=[T_data, T_data_ss_state, T_data_filter]
    )

    # The js hook must be a function expression (as with demo.load above);
    # it re-renders the tweet embed after the HTML has been swapped in.
    T_data.select(T_on_select, None, T_tweet_embed).then(fn=None, js="() => reloadTwitterWidgets()")

    @T_data_filter.input(inputs=[T_data_ss_state, T_data_filter], outputs=T_data)
    def filter_df(df, filter):
        """Return the stored results restricted by the selected label filter.

        ``filter`` is "All", a label, or "Not-<label>" as built by
        load_and_classify_csv_dataframe. An empty selection shows everything.
        """
        negation_prefix = "Not-"
        if not filter or filter == "All":
            result_df = df.copy()
        elif filter.startswith(negation_prefix):
            # Strip only the prefix so labels containing '-' stay intact.
            excluded = filter[len(negation_prefix):]
            result_df = df[df["event_label"] != excluded].copy()
        else:
            result_df = df[df["event_label"] == filter].copy()
        return gr.update(value=result_df)


    # Button clicks ETC Eval
    ETCE_predict_button.click(
        load_and_classify_csv,
        inputs=[file_input, text_field, event_model, threshold],
        outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence, num_posts, data, QA_addqry_button, QA_run_button])

    ETCE_accuracy_button.click(
        calculate_accuracy,
        inputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, num_posts, text_field, data],
        outputs=[incorrect, correct, accuracy, data_eval, download_csv])


    # Button clicks QA
    QA_addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])

    QA_run_button.click(qa_summarise,
                        inputs=[selected_queries, qa_llm_model, text_field, data], ## XXX fix text_field
                        outputs=[hsummary, qa_df])


    # Event listener for single text classification
    text_to_classify.submit(
        single_classification,
        inputs=[text_to_classify, model_sing_classify, threshold_sing_classify],
        outputs=[classification, classification_score])

demo.launch()