# Hugging Face Spaces page status at capture time: "Runtime error".
from typing import Dict, Optional

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
class TwitterEmotionClassifier:
    """Classify the emotion of a tweet with transformers text-classification pipelines.

    The model given to the constructor is loaded eagerly and served under
    ``model_type``; the DeBERTa model is loaded on its first request and then
    cached on the instance.
    """

    # Hub id of the DeBERTa model that get_model("deberta") loads lazily.
    DEBERTA_MODEL_NAME = "Emanuel/twitter-emotion-deberta-v3-base"

    # Output keys of predict(), in label-id order (id 0 .. id 5).
    # NOTE(review): the emoji in these strings look mis-encoded (UTF-8 bytes
    # decoded with a single-byte codepage); restore the intended emoji.
    DISPLAY_LABELS = (
        "Sadness ๐ข",
        "Joy ๐",
        "Love ๐",
        "Anger ๐ ",
        "Fear ๐ฑ",
        "Surprise ๐ฎ",
    )

    def __init__(self, model_name: str, model_type: str, use_gpu: bool = False):
        """Load the primary model.

        Args:
            model_name: Hub id or local path of the primary sequence-classification model.
            model_type: Name under which get_model()/predict() serve the primary model.
            use_gpu: Run the pipelines on the first CUDA device instead of the CPU.
        """
        self.is_gpu = use_gpu
        self.model_type = model_type
        # The attribute name is historical: it holds the primary pipeline whatever model_type is.
        self.bertweet = self._load_pipeline(model_name)
        self.deberta = None  # created on first get_model("deberta")
        # Raw label -> emotion name for the primary model (not used by predict()).
        self.emotions = {
            "LABEL_0": "sadness",
            "LABEL_1": "joy",
            "LABEL_2": "love",
            "LABEL_3": "anger",
            "LABEL_4": "fear",
            "LABEL_5": "surprise",
        }

    def _load_pipeline(self, model_name: str):
        """Build a text-classification pipeline for ``model_name`` on the configured device."""
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model.eval()
        return pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            # Pipelines take a device index: -1 is the CPU, 0 the first CUDA device.
            device=0 if self.is_gpu else -1,
        )

    def get_model(self, model_type: str):
        """Return the pipeline that serves ``model_type``.

        Raises:
            ValueError: if ``model_type`` is neither the primary type nor "deberta"
                (this includes ``None``, e.g. when no model was selected in the UI).
        """
        if model_type == self.model_type:
            return self.bertweet
        if model_type == "deberta":
            if self.deberta is None:
                self.deberta = self._load_pipeline(self.DEBERTA_MODEL_NAME)
            return self.deberta
        raise ValueError(
            f"Unsupported model type {model_type!r}; "
            f"expected {self.model_type!r} or 'deberta'."
        )

    def predict(self, twitter: str, model_type: str) -> Optional[Dict[str, float]]:
        """Score the tweet text ``twitter`` for every emotion with the chosen model.

        Returns:
            A dict mapping each entry of DISPLAY_LABELS to its score, or None when
            the pipeline returned no scores.

        Raises:
            ValueError: propagated from get_model() for an unsupported ``model_type``.
        """
        classifier = self.get_model(model_type)
        results = classifier(twitter, top_k=None)
        # Older transformers (return_all_scores style) nest a single input's scores one level deeper.
        if results and isinstance(results[0], list):
            results = results[0]
        if not results:
            return None
        # top_k=None sorts scores by value, so look them up by label rather than position.
        score_by_label = {item["label"]: item["score"] for item in results}
        id2label = classifier.model.config.id2label
        return {
            display: score_by_label[id2label[label_id]]
            for label_id, display in enumerate(self.DISPLAY_LABELS)
        }
def main():
    """Build the Gradio demo around TwitterEmotionClassifier and start serving it."""
    model = TwitterEmotionClassifier("Emanuel/bertweet-emotion-base", "bertweet")
    # NOTE(review): several emoji in the strings below look mis-encoded
    # (UTF-8 decoded with a single-byte codepage); restore the intended emoji.
    examples = [
        [
            "Tesla Bot is truly amazing. It's the early steps of a revolution in the role that AI & robots play in human civilization.\n"
            "What the Tesla team was been able to accomplish in the last few months is just incredible.\n"
            "As someone who loves AI and robotics, I'm inspired beyond words.",
            "bertweet",
        ],
        [
            "I got food poisoning. It sucks ๐ฅต but it makes me appreciate:\n"
            "1. the days when I'm not sick and\n"
            "2. just how damn incredible the human body is at fighting off all the things that try to kill it.\n"
            "Biology is awesome. Life is awesome.",
            "bertweet",
        ],
        [
            "I'm adding human-created captions to many podcasts soon. (It's expensive ๐) These identify the speaker, are timed to the audio, and so make for good training data.\n"
            "When you and I do a podcast, we too will become immortalized as training data.",
            "bertweet",
        ],
        [
            "We live inside a simulation and are ourselves creating progressively more realistic and interesting simulations. Existence is a recursive simulation generator.",
            "bertweet",
        ],
        [
            "Here's my conversation with Will Sasso, one of the funniest people on the planet and someone who I've been a fan of for over 20 years. https://youtube.com/watch?v=xewD1apJNhw\n"
            "PS: His\n"
            "account\n"
            "@WillSasso\n"
            "got hacked yesterday.\n"
            "@TwitterSupport\n"
            "please help him out!",
            "deberta",
        ],
    ]
    # Top-level components replace the gr.inputs / gr.outputs namespaces, which
    # Gradio removed in 4.0; `verbose` and the legacy "huggingface" theme name
    # are no longer Interface options, so they are dropped.
    interface = gr.Interface(
        fn=model.predict,
        inputs=[
            gr.Textbox(placeholder="What's happening?", label="Tweet content", lines=5),
            # A default keeps predict() from receiving None when no model is picked.
            gr.Radio(["bertweet", "deberta"], value="bertweet", label="Model"),
        ],
        outputs=gr.Label(num_top_classes=6, label="Emotions of this tweet"),
        examples=examples,
        title="Emotion classification ๐ค",
        description="",
    )
    interface.launch()
if __name__ == "__main__":
    # Start the demo only when run as a script, not when imported.
    main()