Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| from transformers import TextClassificationPipeline, DistilBertTokenizer, DistilBertForSequenceClassification | |
# Hugging Face access token taken from the environment (None when unset).
# The same token is handed to the flagging callback and to both model downloads.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Flagging callback: examples flagged in the UI are pushed to the
# "wfh-problematic" Hugging Face dataset.
hf_saver = gr.HuggingFaceDatasetSaver(HF_TOKEN, "wfh-problematic")

# Fine-tuned DistilBERT checkpoint on the Hugging Face Hub. The token is
# forwarded presumably because the repository may be private — confirm.
model_path = "yabramuvdi/distilbert-wfh"
tokenizer, model = (
    DistilBertTokenizer.from_pretrained(model_path, use_auth_token=HF_TOKEN),
    DistilBertForSequenceClassification.from_pretrained(model_path, use_auth_token=HF_TOKEN),
)

# Text-classification pipeline. return_all_scores=True requests a score for
# every label instead of only the top one; predict_wfh indexes into that
# per-label list, so keep the option in sync with that function.
classifier = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)
# Static text rendered on the Gradio page.
title = "Work From Home Predictor"
description = "Demo application that predicts the presence of work from home in any given sequence of text."
article = ""  # text shown at the end of the app (currently empty)

# Clickable examples. Each row is [input text, classification threshold],
# in the same order as the Interface inputs (text box, threshold slider).
examples = [
    ["This is a work from home position.", 0.9],
    ["This position does not allow remote work.", 0.5],
]
| #%% | |
def predict_wfh(input_text, input_slider):
    """Classify *input_text* as work-from-home (WFH) or not.

    The model's WFH probability is compared against the user-chosen
    threshold; only a probability strictly above the threshold counts as WFH.

    Args:
        input_text: Text to classify (from the input text box).
        input_slider: Classification threshold (from the slider).

    Returns:
        A pair ``(confidences, summary)``. ``confidences`` maps both class
        names to 1 (chosen) or 0 (not chosen). ``summary`` is a string with
        the WFH probability rounded to three decimals.
    """
    # The classifier is called on a single string; [0] picks the per-label
    # score list for that one input.
    label_scores = classifier(input_text)[0]
    # NOTE(review): entry 1 is assumed to be the WFH label — confirm against
    # the model's id2label config.
    prob_wfh = label_scores[1]["score"]

    is_wfh = int(prob_wfh > input_slider)
    confidences = {"Not work from home": 1 - is_wfh, "Work from home": is_wfh}
    summary = f"Probability of WFH: {np.round(prob_wfh, 3)}"
    return confidences, summary
# Output widgets: the binary decision as a label, plus a text box with the
# model's raw WFH probability (the two values returned by predict_wfh).
label = gr.outputs.Label(num_top_classes=1, type="confidences", label="Binary classification")
text_output = gr.outputs.Textbox(type="auto", label="Predicted probability")

# NOTE(review): gr.inputs / gr.outputs and theme="huggingface" are the legacy
# (2.x-era) Gradio API; confirm the pinned gradio version before migrating.
app = gr.Interface(
    # Pass the callable itself, not a list: a list of functions was the old
    # multi-model comparison mode and newer Gradio releases reject it, while a
    # single callable is accepted across versions.
    fn=predict_wfh,
    inputs=[
        gr.inputs.Textbox(lines=10, label="Input text"),
        gr.inputs.Slider(0, 1, 0.001, label="Classification threshold", default=0.998),
    ],
    outputs=[label, text_output],
    theme="huggingface",
    title=title,
    description=description,
    article=article,
    examples=examples,
    # Users flag wrong predictions manually; hf_saver stores them in a HF dataset.
    allow_flagging="manual",
    flagging_options=["mistake"],
    flagging_callback=hf_saver,
)
app.launch()