import os
import time
import gradio as gr
import pandas as pd
from classifier import classify
from statistics import mean

# Hugging Face API token, read from the environment (e.g. a Space secret).
HFTOKEN = os.environ["HF_TOKEN"]
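
# classify() comes from classifier.py, which is not shown on this page. From its use in
# load_and_analyze_csv() below, it is assumed to take (post_text, model_id, hf_token) and
# to return a dict shaped like {"event": "flood" | "fire" | <other>, "score": <float>}.
#
# The uploaded file is read with pd.read_table, i.e. it is expected to be tab-separated.
# A hypothetical example using the default text column name "tweet_text":
#
#   id<TAB>tweet_text
#   1<TAB>Heavy rain has flooded the underpass near the station
#   2<TAB>Smoke from the hillside fire is visible across town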
def load_and_analyze_csv(file, text_field, event_model):
    # Read the uploaded tab-separated file and check that the text column exists.
    df = pd.read_table(file.name)
    if text_field not in df.columns:
        raise gr.Error(f"Text column '{text_field}' not found in the uploaded file.")

    # Classify every post and bucket it by predicted event type.
    floods, fires, nones, scores = [], [], [], []
    for post in df[text_field].to_list():
        res = classify(post, event_model, HFTOKEN)
        if res["event"] == "flood":
            floods.append(post)
        elif res["event"] == "fire":
            fires.append(post)
        else:
            nones.append(post)
        scores.append(res["score"])

    # Mean prediction probability across all posts, shown as "Model Confidence".
    model_confidence = round(mean(scores), 5)

    # Returning CheckboxGroup components updates the choices shown in the UI.
    fire_related = gr.CheckboxGroup(choices=fires)
    flood_related = gr.CheckboxGroup(choices=floods)
    not_related = gr.CheckboxGroup(choices=nones)
    return flood_related, fire_related, not_related, model_confidence


def analyze_selected_texts(selections):
    # Compute a simple word count for each selected post and return it as a table.
    selected_texts = selections
    analysis_results = [f"Word Count: {len(text.split())}" for text in selected_texts]
    result_df = pd.DataFrame({"Selected Text": selected_texts, "Analysis": analysis_results})
    return result_df
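
# For example, analyze_selected_texts(["roads flooded near the station"]) returns a
# one-row DataFrame whose Analysis column reads "Word Count: 5".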


with gr.Blocks() as demo:
    event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier"]

    with gr.Tab("Event Type Classification"):
        gr.Markdown(
            """
# T4.5 Relevance Classifier Demo
This is a demo created to explore flood and wildfire classification in social media posts.\n
Usage:\n
\tUpload a .tsv data file (must contain a text column with social media posts).\n
\tNext, type the name of the text column.\n
\tThen, choose a BERT classifier model from the dropdown.\n
\tFinally, click the 'Start Prediction' button.\n
Evaluation:\n
\tTo evaluate the model's accuracy, select the INCORRECT classifications using the checkboxes in front of each post.\n
\tThen, click the 'Calculate Accuracy' button.\n
\tFinally, click 'Download data as CSV' to get the classifications and evaluation data as a .csv file.
""")
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                file_input = gr.File(label="Upload TSV file")
            with gr.Column(scale=6):
                text_field = gr.Textbox(label="Text field name", value="tweet_text")
                event_model = gr.Dropdown(event_models, label="Select classification model")
        predict_button = gr.Button("Start Prediction")

        with gr.Row():  # XXX confirm this is not a problem later --equal_height=True
            with gr.Column():
                gr.Markdown("""### Flood-related""")
                flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
            with gr.Column():
                gr.Markdown("""### Fire-related""")
                fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
            with gr.Column():
                gr.Markdown("""### None""")
                none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

        model_confidence = gr.Number(label="Model Confidence")
        predict_button.click(load_and_analyze_csv, inputs=[file_input, text_field, event_model],
                             outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence])
        with gr.Row(equal_height=True):
            with gr.Column(scale=6):
                gr.Markdown(r"""
Accuracy: the model's ability to make correct predictions,
i.e. the fraction of correct predictions out of all predictions.
$
\text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} \times 100
$
Model Confidence: the mean probability of each post
belonging to its assigned class. A value of 1 is best.
""", latex_delimiters=[{"left": "$", "right": "$", "display": True}])
            with gr.Column(scale=4):
                correct = gr.Number(label="Number of correct classifications", value=0)
                incorrect = gr.Number(label="Number of incorrect classifications", value=0)
                accuracy = gr.Number(label="Model Accuracy", value=0)
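
        # A minimal sketch of the accuracy computation described in the panel above.
        # It assumes the total is the number of classified posts and that the checked
        # boxes mark the incorrect classifications; it is not wired to a
        # 'Calculate Accuracy' button in this version of the file.
        def compute_accuracy(num_incorrect, num_total):
            num_correct = num_total - num_incorrect
            pct = round(num_correct / num_total * 100, 2) if num_total else 0.0
            return num_correct, num_incorrect, pct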

    with gr.Tab("Question Answering"):
        # XXX Add some button disabling here, if the classification process is not completed first XXX
        analysis_button = gr.Button("Analyze Selected Texts")
        analysis_output = gr.DataFrame(headers=["Selected Text", "Analysis"])
        analysis_button.click(analyze_selected_texts, inputs=flood_checkbox_output, outputs=analysis_output)
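
        # One possible way to address the XXX note above (an assumption, not implemented
        # here): create analysis_button with interactive=False and have
        # load_and_analyze_csv additionally return gr.update(interactive=True) for it
        # once classification has finished.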

demo.launch()