|
import streamlit as st |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch.nn.functional as F |
|
|
|
|
|
# Hugging Face Hub checkpoint for the fine-tuned tweet emotion classifier.
model_ckpt = "AbhishekBhavnani/TweetClassification"


@st.cache_resource
def load_model(ckpt: str = model_ckpt):
    """Load the tokenizer and sequence-classification model for *ckpt*.

    Streamlit re-runs this whole script on every widget interaction, so
    loading at module level would re-instantiate the model on each rerun.
    ``st.cache_resource`` keeps one shared (tokenizer, model) pair per
    distinct ``ckpt`` across reruns and sessions.

    Args:
        ckpt: Model identifier or local path accepted by ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is already in eval mode.
    """
    tok = AutoTokenizer.from_pretrained(ckpt)
    clf = AutoModelForSequenceClassification.from_pretrained(ckpt)
    # Switch layers such as dropout to inference behavior.
    clf.eval()
    return tok, clf


# Keep the module-level names the rest of the script relies on.
tokenizer, model = load_model(model_ckpt)
|
|
|
|
|
st.title("Tweet Emotion Classifier")

# Current contents of the text box, re-read on every Streamlit rerun.
text = st.text_area("Enter your tweet here")

if st.button("Predict"):
    if not text.strip():
        # Blank or whitespace-only input: nothing to classify.
        st.warning("Please enter a tweet.")
    else:
        # Inference only, so no gradient tracking is needed.
        with torch.no_grad():
            encoded = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
            )
            class_probs = F.softmax(model(**encoded).logits, dim=-1)
            # A single string gives a batch of one. Pick its most likely class.
            best_idx = torch.argmax(class_probs, dim=1).item()
            predicted = model.config.id2label[best_idx]
            confidence = class_probs[0][best_idx].item()

        st.success(f"**Prediction**: {predicted} ({confidence:.4f})")
|
|