wfh-app-v2 / app.py
yabramuvdi's picture
added one more option for tagging
bfd8f63
raw
history blame
2.64 kB
import os
import numpy as np
import torch
import gradio as gr
from transformers import TextClassificationPipeline, DistilBertTokenizer, DistilBertForSequenceClassification
# Flagged examples are pushed to a HuggingFace dataset; the access token is
# taken from the environment (it is None when the variable is unset).
HF_TOKEN = os.environ.get('HF_TOKEN')
hf_saver = gr.HuggingFaceDatasetSaver(HF_TOKEN, "wfh-problematic")

# Identifier of the fine-tuned model on the HuggingFace Hub.
model_path = "yabramuvdi/distilbert-wfh"
tokenizer = DistilBertTokenizer.from_pretrained(
    model_path,
    use_auth_token=HF_TOKEN,
)
model = DistilBertForSequenceClassification.from_pretrained(
    model_path,
    use_auth_token=HF_TOKEN,
)

# Prediction pipeline; return_all_scores=True requests a score for every
# label instead of only the top-scoring one.
classifier = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)
# Static text shown on the page.
title = "Remote Work Detection Application"
description = (
    "Online tool accompanying the paper "
    "_Remote Work across Jobs, Employers, and Countries_ "
    "(Hansen et al. 2022). "
    "The application allows the user to input arbitrary text and receive "
    "a predicted probability of the text exhibiting the possibility of "
    "remote work."
)
article = ""  # text at the end of the app

# Each example row is [input text, classification threshold].
examples = [
    [sample_text, 0.5]
    for sample_text in (
        "This is a work from home position.",
        "This position does not allow remote work.",
    )
]
#%%
def predict_wfh(input_text, input_slider):
    """Classify a text as work from home or not, using a chosen threshold.

    Args:
        input_text: Text passed to the module-level ``classifier`` pipeline.
        input_slider: Threshold; the text is tagged as work from home only
            when the predicted probability is strictly greater than it.

    Returns:
        A tuple ``(confidences, message)``: ``confidences`` maps each of the
        two labels to 0 or 1 (exactly one label gets 1), and ``message``
        reports the predicted probability rounded to three decimals.
    """
    # The pipeline yields one list of per-label scores per input text.
    label_scores = classifier(input_text)[0]
    # The entry at index 1 is taken as the work-from-home score.
    prob_wfh = label_scores[1]["score"]

    # 1 when the probability clears the threshold, 0 otherwise.
    is_wfh = int(prob_wfh > input_slider)
    confidences = {"Not work from home": 1 - is_wfh, "Work from home": is_wfh}
    return confidences, f"Probability of WFH: {np.round(prob_wfh, 3)}"
# Output components: the binary label and the predicted probability text.
label = gr.outputs.Label(num_top_classes=1, type="confidences", label="Binary classification")
text_output = gr.outputs.Textbox(type="auto", label="Predicted probability")

app = gr.Interface(
    # Pass the callable directly: wrapping a single function in a list adds
    # nothing, and newer Gradio releases deprecate list-valued `fn`.
    fn=predict_wfh,
    inputs=[
        gr.inputs.Textbox(lines=10, label="Input text"),
        gr.inputs.Slider(0, 1, 0.001, label="Classification threshold", default=0.5),
    ],
    outputs=[label, text_output],
    theme="huggingface",
    title=title,
    description=description,
    article=article,
    examples=examples,
    # Users can manually flag a prediction with one of these tags; flagged
    # samples are handed to the hf_saver callback.
    allow_flagging="manual",
    flagging_options=["mistake", "problematic"],
    flagging_callback=hf_saver,
)

# To password-protect the app, pass auth=(user, password) to launch() with
# credentials read from environment variables; never hard-code them here.
app.launch()