wfh-app-v2 / app.py
yabramuvdi's picture
Update app.py
e89af88
raw
history blame
3.33 kB
import os
import numpy as np
import torch
import gradio as gr
from transformers import TextClassificationPipeline, DistilBertTokenizer, DistilBertForSequenceClassification
# Flagged examples are saved to a HuggingFace dataset; the access token is
# taken from the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")
hf_saver = gr.HuggingFaceDatasetSaver(HF_TOKEN, "wfh-problematic")

# Model repository on the HuggingFace Hub (loaded with the same token).
model_path = "yabramuvdi/distilbert-wfh"
tokenizer = DistilBertTokenizer.from_pretrained(
    model_path,
    use_auth_token=HF_TOKEN,
)
model = DistilBertForSequenceClassification.from_pretrained(
    model_path,
    use_auth_token=HF_TOKEN,
)

# Prediction pipeline; return_all_scores=True asks for a score per label
# instead of only the top one.
classifier = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)

# Static page elements shown by the Gradio interface.
title = "Remote Work Detection Application"
description = (
    "This page lets users test the __Work-from-Home Algorithmic Measurement (WHAM)__ model used in the paper _“Remote Work across Jobs, Companies and Countries” (Hansen, Lambert, Bloom, Davis, Sadun & Taska, 2022)_. "
    "It is maintained by Yabra Muvdi, who works as a pre-doctoral researcher for Professor Hansen. "
    "Yabra was instrumental in developing WHAM. \n\n "
    "The application allows users to input any arbitrary text and computes the predicted probability of the text exhibiting the possibility of remote work. "
    "Users can also flag any examples that are incorrectly classified by the model. "
    "This is simply done by clicking on the _“Flag”_ button and then selecting _“mistake”_."
)
article = ""  # text at the end of the app

# Each example row is [input text, classification threshold].
examples = [
    [sample_text, 0.5]
    for sample_text in (
        "This is a work from home position.",
        "This position does not allow remote work.",
    )
]
#%%
def predict_wfh(input_text: str, input_slider: float):
    """Classify a text as work-from-home or not using a probability threshold.

    Args:
        input_text: Text to score with the module-level ``classifier``
            pipeline.
        input_slider: Decision threshold. The text is labelled
            work-from-home only when the predicted probability is strictly
            greater than this value.

    Returns:
        A 2-tuple ``(confidences, message)``. ``confidences`` maps
        ``"Not work from home"`` and ``"Work from home"`` to 0 or 1, with
        exactly one of them set to 1. ``message`` is a string reporting the
        predicted probability rounded to three decimals.
    """
    # Truncate over-long inputs: without this, a text that tokenizes to more
    # than the model's maximum sequence length raises inside the model
    # instead of producing a prediction.
    predictions = classifier(input_text, truncation=True)[0]

    # The second entry of the per-label scores is used as the WFH probability.
    prob_wfh = predictions[1]["score"]

    # Apply the user-selected threshold to get a hard 0/1 decision.
    is_wfh = prob_wfh > input_slider
    confidences = {
        "Not work from home": int(not is_wfh),
        "Work from home": int(is_wfh),
    }
    return confidences, f"Probability of WFH: {np.round(prob_wfh, 3)}"
# Output components: the thresholded class label and the probability message
# returned by predict_wfh.
label = gr.outputs.Label(num_top_classes=1, type="confidences", label="Binary classification")
text_output = gr.outputs.Textbox(type="auto", label="Predicted probability")

# Build the interface. `fn` is passed as a single callable: a list of
# functions is a deprecated form that newer Gradio releases reject, while a
# plain callable is accepted everywhere.
app = gr.Interface(
    fn=predict_wfh,
    inputs=[
        gr.inputs.Textbox(lines=10, label="Input text"),
        gr.inputs.Slider(0, 1, 0.001, label="Classification threshold", default=0.5),
    ],
    outputs=[label, text_output],
    theme="huggingface",
    title=title,
    description=description,
    article=article,
    examples=examples,
    # Manual flagging with a single "mistake" option; flagged samples are
    # handed to hf_saver.
    allow_flagging="manual",
    flagging_options=["mistake"],
    flagging_callback=hf_saver,
)

# NOTE: never hard-code login credentials in this file. If the app needs
# authentication, read the username/password from environment variables and
# pass them to launch(auth=...).
app.launch()