Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import time | |
| import gradio as gr | |
| import pandas as pd | |
| from classifier import classify | |
| from statistics import mean | |
# Access token read from the HF_TOKEN environment variable and handed to
# classify() for every post. Indexing os.environ directly means a missing
# variable raises KeyError as soon as this module is loaded.
HFTOKEN = os.environ["HF_TOKEN"]
def load_and_analyze_csv(file, text_field, event_model):
    """Classify every post in an uploaded tab-separated file.

    Args:
        file: Value of the ``gr.File`` input. Either an object exposing the
            path as ``.name`` or the path string itself.
        text_field: Name of the column that holds the post text.
        event_model: Model identifier forwarded unchanged to ``classify``.

    Returns:
        A 6-tuple in the order wired to the "Start Prediction" click handler:
        ``gr.CheckboxGroup`` updates for the posts labelled "flood", "fire"
        and "none", the mean model score rounded to 5 decimals (``0.0`` when
        the file has no rows), the number of posts, and the DataFrame with
        ``model_label`` and ``model_score`` columns added.

    Raises:
        gr.Error: If no file was uploaded or ``text_field`` is not a column.
    """
    if file is None:
        raise gr.Error("Error: No file uploaded.")

    # Accept both tempfile-like uploads (path in ``.name``) and plain paths.
    df = pd.read_table(getattr(file, "name", file))

    if text_field not in df.columns:
        raise gr.Error(f"Error: Enter text column'{text_field}' not in CSV file.")

    posts = df[text_field].to_list()
    results = [classify(post, event_model, HFTOKEN) for post in posts]
    scores = [res["score"] for res in results]
    df["model_label"] = [res["event"] for res in results]
    df["model_score"] = scores

    # statistics.mean raises on an empty sequence, so guard the no-rows case.
    model_confidence = round(mean(scores), 5) if scores else 0.0

    def posts_labelled(label):
        # Checkbox choices are the raw texts the model assigned to `label`.
        matching = df.loc[df["model_label"] == label, text_field].to_list()
        return gr.CheckboxGroup(choices=matching)

    return (
        posts_labelled("flood"),
        posts_labelled("fire"),
        posts_labelled("none"),
        model_confidence,
        len(posts),
        df,
    )
def qa_process(selections):
    """Build a two-column table with a word count for each selected text.

    Args:
        selections: Sequence of selected strings.

    Returns:
        A DataFrame with the texts under "Selected Text" and a
        "Word Count: <n>" string for each under "Analysis".
    """
    word_counts = []
    for entry in selections:
        # str.split() with no argument splits on any run of whitespace.
        n_words = len(entry.split())
        word_counts.append(f"Word Count: {n_words}")

    table = {"Selected Text": selections, "Analysis": word_counts}
    return pd.DataFrame(table)
def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
    """Score the model from the posts the user flagged as misclassified.

    Every post whose text appears in one of the three selection lists is
    marked "incorrect" in a new ``model_eval`` column; all others are marked
    "correct". The annotated table is written to ``output.csv``.

    Args:
        flood_selections: Texts flagged in the flood checkbox group.
        fire_selections: Texts flagged in the fire checkbox group.
        none_selections: Texts flagged in the "none" checkbox group.
        num_posts: Total number of classified posts.
        text_field: Name of the column that holds the post text.
        data_df: DataFrame of classified posts; modified in place.

    Returns:
        A 5-tuple: number of incorrect posts, number of correct posts,
        accuracy in percent (``0.0`` when there are no posts), the annotated
        DataFrame, and a visible ``gr.DownloadButton`` pointing at
        ``output.csv``.

    Raises:
        gr.Error: If ``data_df`` has no ``text_field`` column, e.g. because
            the prediction step has not been run yet.
    """
    if data_df is None or text_field not in data_df.columns:
        raise gr.Error("Error: Run the prediction before calculating accuracy.")

    # Treat a missing selection list as "nothing flagged".
    selections = (flood_selections or []) + (fire_selections or []) + (none_selections or [])
    flagged = set(selections)  # O(1) membership test per post

    data_df["model_eval"] = [
        "incorrect" if post in flagged else "correct"
        for post in data_df[text_field].to_list()
    ]

    total = num_posts or 0
    incorrect = len(selections)
    correct = total - incorrect
    # Avoid dividing by zero when no posts were classified.
    accuracy = (correct / total) * 100 if total else 0.0

    data_df.to_csv("output.csv")
    download = gr.DownloadButton(label="Download CSV", value="output.csv", visible=True)
    return incorrect, correct, accuracy, data_df, download
def get_queries():
    """Return a ``gr.CheckboxGroup`` update listing the predefined queries.

    Each question is a separate choice. The trailing comma on every entry
    keeps adjacent string literals from being silently concatenated into a
    single choice.
    """
    queries = [
        "What areas are being evacuated?",
        "What areas are predicted to be impacted?",
        "What areas are without power?",
        "What barriers are hindering response efforts?",
        "What events have been canceled?",
        "What preparations are being made?",
        "What regions have announced a state of emergency?",
        "What roads are blocked / closed?",
        "What services have been closed?",
        "What warnings are currently in effect?",
        "Where are emergency services deployed?",
        "Where are emergency services needed?",
        "Where are evacuations needed?",
        "Where are people needing rescued?",
        "Where are recovery efforts taking place?",
        "Where has building or infrastructure damage occurred?",
        "Where has flooding occured?",
        "Where are volunteers being requested?",
        "Where has road damage occured?",
        "What area has the wildfire burned?",
        "Where have homes been damaged or destroyed?",
    ]
    return gr.CheckboxGroup(choices=queries)
# UI definition: one tab for classification and evaluation, one for QA.
# NOTE(review): the source this was recovered from had lost its indentation,
# so the widget nesting below is a reconstruction. Confirm the layout.
with gr.Blocks() as demo:
    event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier"]

    with gr.Tab("Event Type Classification"):
        gr.Markdown(
            """
            # T4.5 Relevance Classifier Demo
            This is a demo created to explore floods and wildfire classification in social media posts.\n
            Usage:\n
            -Upload .tsv data file (must contain a text column with social media posts).\n
            -Next, type the name of the text column.\n
            -Then, choose a BERT classifier model from the drop down.\n
            -Finally, click the 'start prediction' buttton.\n
            Evaluation:\n
            -To evaluate the model's accuracy select the INCORRECT classifications using the checkboxes in front of each post.\n
            -Then, click on the 'Calculate Accuracy' button.\n
            -Then, click on the 'Download data as CSV' to get the classifications and evaluation data as a .csv file.
            """)

        # Inputs: data file, text column name, model choice.
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                file_input = gr.File(label="Upload CSV File")
            with gr.Column(scale=6):
                text_field = gr.Textbox(label="Text field name", value="tweet_text")
                event_model = gr.Dropdown(event_models, label="Select classification model")
                predict_button = gr.Button("Start Prediction")

        # One checkbox group per predicted label; the user ticks the posts
        # that were classified incorrectly.
        with gr.Row():  # XXX confirm this is not a problem later --equal_height=True
            with gr.Column():
                gr.Markdown("""### Flood-related""")
                flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
            with gr.Column():
                gr.Markdown("""### Fire-related""")
                fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
            with gr.Column():
                gr.Markdown("""### None""")
                none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

        # Metric explanation and evaluation results.
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                gr.Markdown(r"""
                Accuracy: is the model's ability to make correct predicitons.
                It is the fraction of correct prediction out of the total predictions.
                $$
                \text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} * 100
                $$
                Model Confidence: is the mean probabilty of each case
                belonging to their assigned classes. A value of 1 is best.
                """, latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }])
                gr.Markdown("\n\n\n")
                model_confidence = gr.Number(label="Model Confidence")
            with gr.Column(scale=5):
                correct = gr.Number(label="Number of correct classifications")
                incorrect = gr.Number(label="Number of incorrect classifications")
                accuracy = gr.Number(label="Model Accuracy (%)")
                accuracy_button = gr.Button("Calculate Accuracy")
                download_csv = gr.DownloadButton(visible=False)

        # Hidden components that carry data from the prediction step to the
        # accuracy step.
        num_posts = gr.Number(visible=False)
        data = gr.DataFrame(visible=False)
        data_eval = gr.DataFrame(visible=False)

        predict_button.click(
            load_and_analyze_csv,
            inputs=[file_input, text_field, event_model],
            outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence, num_posts, data])

        accuracy_button.click(
            calculate_accuracy,
            inputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, num_posts, text_field, data],
            outputs=[incorrect, correct, accuracy, data_eval, download_csv])

    qa_tab = gr.Tab("Question Answering")
    with qa_tab:
        # XXX Add some button disabling here, if the classification process is not completed first XXX
        selected_queries = gr.CheckboxGroup(label="Select at least one query using the checkboxes", interactive=True)
        qa_button = gr.Button("Start QA")
        analysis_output = gr.DataFrame(headers=["Selected Text", "Analysis"])

        qa_button.click(qa_process, inputs=selected_queries, outputs=analysis_output)

    # Register the tab listener only after `selected_queries` exists;
    # referencing it before the component is created raises NameError.
    qa_tab.select(get_queries, None, selected_queries)

demo.launch()