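# NOTE: the original capture began with page-header text ("Spaces: Running",
# presumably the Hugging Face Spaces status banner); it is not part of the program.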
import os

import numpy as np
import torch
import gradio as gr
from transformers import (
    TextClassificationPipeline,
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
)
# Hugging Face access token, read from the environment so it never lives in
# the source. os.getenv returns None when HF_TOKEN is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# Flagging callback: flagged examples are saved to the "wfh-problematic"
# Hugging Face dataset.
hf_saver = gr.HuggingFaceDatasetSaver(HF_TOKEN, "wfh-problematic")

# Model repository on the Hugging Face Hub.
model_path = "yabramuvdi/distilbert-wfh"

# The token is forwarded to both downloads.
# NOTE(review): presumably the repository is private or gated — confirm.
tokenizer = DistilBertTokenizer.from_pretrained(model_path, use_auth_token=HF_TOKEN)
model = DistilBertForSequenceClassification.from_pretrained(model_path, use_auth_token=HF_TOKEN)

# Prediction pipeline. return_all_scores=True requests a score for every
# label instead of only the top-scoring one.
classifier = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)
# Basic elements of the page.
title = "Remote Work Detection Application"

# Markdown shown under the title. Written as adjacent string literals purely
# for readability; they concatenate into a single string.
description = (
    "This page lets users test the __Work-from-Home Algorithmic Measurement (WHAM)__ model "
    "used in the paper _“Remote Work across Jobs, Companies and Countries” "
    "(Hansen, Lambert, Bloom, Davis, Sadun & Taska, 2022)_. "
    "It is maintained by Yabra Muvdi, who works as a pre-doctoral researcher for Professor Hansen. "
    "Yabra was instrumental in developing WHAM. \n\n "
    "The application allows users to input any arbitrary text and computes the predicted "
    "probability of the text exhibiting the possibility of remote work. "
    "Users can also flag any examples that are incorrectly classified by the model. "
    "This is simply done by clicking on the _“Flag”_ button and then selecting _“mistake”_."
)

# Text at the end of the app (currently empty).
article = ""

# Each example row is [input text, classification threshold].
examples = [
    ["This is a work from home position.", 0.5],
    ["This position does not allow remote work.", 0.5],
]
#%%
def predict_wfh(input_text: str, input_slider: float) -> tuple:
    """Classify a text as work-from-home or not at the chosen threshold.

    Args:
        input_text: Text passed to the module-level ``classifier`` pipeline.
        input_slider: Classification threshold. The text is labelled
            "Work from home" only when its score is strictly greater than
            this value.

    Returns:
        A 2-tuple. The first item is a dict that gives 1 to the selected
        label and 0 to the other ("Not work from home" / "Work from home").
        The second item is a string reporting the score rounded to three
        decimals.
    """
    # The pipeline returns one result per input text; take the one for ours.
    predictions = classifier(input_text)[0]

    # NOTE(review): assumes the entry at index 1 is the work-from-home label —
    # confirm against the model's label order.
    prob_wfh = predictions[1]["score"]

    # Apply the selected threshold (strict comparison) to get a hard label.
    wfh = 1 if prob_wfh > input_slider else 0
    no_wfh = 1 - wfh

    return (
        {"Not work from home": no_wfh, "Work from home": wfh},
        f"Probability of WFH: {np.round(prob_wfh, 3)}",
    )
# Output widgets: the thresholded label and the probability text, in the
# order predict_wfh returns them.
label = gr.outputs.Label(num_top_classes=1, type="confidences", label="Binary classification")
text_output = gr.outputs.Textbox(type="auto", label="Predicted probability")

# The callable is passed directly rather than wrapped in a one-element list.
# NOTE(review): Gradio releases that accept a list treat a single callable the
# same way, and later releases reject lists — confirm against the pinned
# Gradio version.
app = gr.Interface(
    fn=predict_wfh,
    # Input widgets, in the order predict_wfh receives its arguments.
    inputs=[
        gr.inputs.Textbox(lines=10, label="Input text"),
        gr.inputs.Slider(0, 1, 0.001, label="Classification threshold", default=0.5),
    ],
    outputs=[label, text_output],
    theme="huggingface",
    title=title,
    description=description,
    article=article,
    examples=examples,
    # Users flag misclassified examples by hand; hf_saver stores them.
    allow_flagging="manual",
    flagging_options=["mistake"],
    flagging_callback=hf_saver,
)

# If the app ever needs a login, pass auth=(user, password) to launch() with
# values read from the environment instead of hard-coding credentials here.
app.launch()