import gradio as gr
import os
import torch

from transformers import DistilBertTokenizerFast
from timeit import default_timer as timer


# Class names indexed by the model's label id (0 = negative, 1 = positive)
class_names = ["Negative", "Positive"]

# Load the fine-tuned DistilBERT sentiment model onto the CPU
model = torch.load(f="BERT_sentiment_analysis.pth",
                   map_location=torch.device("cpu"))
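
# NOTE: torch.load() above assumes the checkpoint was saved as a full pickled model,
# e.g. torch.save(model, "BERT_sentiment_analysis.pth"). If only a state_dict was saved,
# a DistilBertForSequenceClassification model would need to be instantiated first and the
# weights loaded with model.load_state_dict(). On recent PyTorch versions, loading a full
# pickled model may also require passing weights_only=False to torch.load().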

# Load the matching tokenizer once at startup rather than on every prediction call
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")


def predict(text: str):
    """Runs sentiment prediction on the input text and returns the predicted
    label together with the time taken (in seconds).
    """
    start_time = timer()

    # Tokenize the input text and keep the tensors on the CPU to match the model
    inputs = tokenizer(text, return_tensors="pt").to("cpu")

    # Run the model in inference mode (no gradient tracking)
    model.eval()
    with torch.inference_mode():
        logits = model(**inputs).logits

    predicted_class_id = logits.argmax().item()

    # Map the predicted class id to a human-readable label
    result = class_names[predicted_class_id]

    # Time taken to make the prediction, rounded to 5 decimal places
    pred_time = round(timer() - start_time, 5)

    return result, pred_time
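

# Illustrative (untested) call; the actual label and timing depend on the fine-tuned model:
#   predict("I really enjoyed this movie!")
#   -> ("Positive", 0.12)   # (predicted label, prediction time in seconds)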


# Gradio page metadata
title = "Sentiment Classifier"
description = "A sentiment classifier trained by fine-tuning the [DistilBERT](https://huggingface.co/docs/transformers/v4.42.0/en/model_doc/distilbert#transformers.DistilBertForSequenceClassification) Transformer model with the Hugging Face [transformers](https://huggingface.co/docs/transformers/en/index) library."
article = "The model classifies the sentiment of an input text, i.e. whether the text expresses a positive or negative sentiment."

# Build the Gradio interface: one text box in, the predicted label and timing out
demo = gr.Interface(fn=predict,
                    inputs=[gr.Textbox(label="Input")],
                    outputs=[gr.Label(label="Prediction"),
                             gr.Number(label="Prediction time (s)")],
                    title=title,
                    description=description,
                    article=article)
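
# gr.Interface also accepts an `examples` argument; a few hypothetical sample inputs could
# be pre-filled in the UI, e.g. examples=[["I loved this film!"], ["The service was terrible."]].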

demo.launch()
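
# demo.launch(share=True) would create a temporary public link when running locally
# (not needed on Hugging Face Spaces, where the app is hosted automatically).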