logasanjeev's picture
Update app.py
b5fc4a7 verified
raw
history blame
6.99 kB
import os
import shutil
import sys
from importlib import import_module, invalidate_caches

import gradio as gr
import pandas as pd
import plotly.express as px
import torch
from huggingface_hub import hf_hub_download
# Load inference.py and model.
# The model repo ships its own `inference.py`; download it from the Hub, copy it
# into the current working directory and import it as the module "inference".
repo_id = "logasanjeev/goemotions-bert"
local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")
print("Downloaded inference.py successfully!")
current_dir = os.getcwd()
destination = os.path.join(current_dir, "inference.py")
shutil.copy(local_file, destination)
print("Copied inference.py to current directory!")
# import_module() searches sys.path, whose first entry is normally the script's
# directory rather than the cwd; make sure the copy we just wrote is reachable
# even when the app is launched from another directory.
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
# The file was created after the interpreter started, so the import system's
# finder caches may not know about it yet (see importlib.import_module docs).
invalidate_caches()
inference_module = import_module("inference")
predict_emotions = inference_module.predict_emotions
print("Imported predict_emotions successfully!")
# One throw-away prediction at startup. NOTE(review): presumably this warms up
# the model so the first user request is not slow -- confirm in inference.py.
_, _ = predict_emotions("dummy text")
# Label names and per-label default thresholds. They are read positionally
# (THRESHOLDS[EMOTION_LABELS.index(label)]) by the prediction code, so they are
# assumed to be index-aligned sequences -- TODO confirm in inference.py.
emotion_labels = inference_module.EMOTION_LABELS
default_thresholds = inference_module.THRESHOLDS
# Prediction function (simplified, no export)
def _parse_predictions(predictions_str):
    """Turn the text returned by ``predict_emotions`` into (emotion, confidence) pairs.

    ``predictions_str`` is parsed as one ``"<emotion>: <confidence>"`` entry per
    line; the literal ``"No emotions predicted."`` yields an empty list.
    """
    if predictions_str == "No emotions predicted.":
        return []
    predictions = []
    for line in predictions_str.split("\n"):
        if not line.strip():
            # Skip blank lines (e.g. a trailing newline) instead of failing to
            # unpack the split result.
            continue
        # Split on the last ": " so the confidence is always the final field.
        emotion, confidence = line.rsplit(": ", 1)
        predictions.append((emotion, float(confidence)))
    return predictions


def _apply_thresholds(predictions, confidence_threshold):
    """Keep predictions reaching max(label's default threshold, confidence_threshold)."""
    kept = []
    for emotion, confidence in predictions:
        label_threshold = default_thresholds[emotion_labels.index(emotion)]
        if confidence >= max(label_threshold, confidence_threshold):
            kept.append((emotion, confidence))
    return kept


def _build_confidence_chart(predictions):
    """Return a Plotly bar chart of (emotion, confidence) pairs, or None if empty."""
    if not predictions:
        return None
    df = pd.DataFrame(predictions, columns=["Emotion", "Confidence"])
    fig = px.bar(
        df,
        x="Emotion",
        y="Confidence",
        color="Emotion",
        text="Confidence",
        title="Emotion Confidence Levels",
        height=300,
        color_discrete_sequence=px.colors.qualitative.Pastel,
    )
    fig.update_traces(texttemplate='%{text:.2f}', textposition='auto')
    fig.update_layout(showlegend=False, margin=dict(t=40, b=40), xaxis_title="", yaxis_title="Confidence")
    return fig


def predict_emotions_with_details(text, confidence_threshold=0.0):
    """Classify ``text`` and format the result for the three Gradio outputs.

    Args:
        text: Raw textbox input. ``None`` or whitespace-only input returns the
            prompt ``("Please enter some text.", "", None)`` without running
            the model.
        confidence_threshold: Extra minimum confidence from the slider; a label
            is kept only if its confidence reaches
            ``max(default threshold for that label, confidence_threshold)``.

    Returns:
        ``(processed_text, thresholded_output, fig)``: the preprocessed text
        returned by ``predict_emotions``, the kept predictions as
        ``"emotion: 0.1234"`` lines (or ``"No emotions predicted above
        thresholds."``), and a Plotly bar chart (``None`` when nothing is kept).
    """
    # A cleared Gradio textbox can arrive as None; treat it like empty input.
    if text is None or not text.strip():
        return "Please enter some text.", "", None
    if confidence_threshold is None:
        confidence_threshold = 0.0
    # A single model call; the previous version also ran a second, unused
    # tokenizer + forward pass here whose logits were never read.
    predictions_str, processed_text = predict_emotions(text)
    filtered_predictions = _apply_thresholds(
        _parse_predictions(predictions_str), confidence_threshold
    )
    if filtered_predictions:
        thresholded_output = "\n".join(
            f"{emotion}: {confidence:.4f}" for emotion, confidence in filtered_predictions
        )
    else:
        thresholded_output = "No emotions predicted above thresholds."
    return processed_text, thresholded_output, _build_confidence_chart(filtered_predictions)
# Simplified CSS
# Stylesheet handed to gr.Blocks(css=custom_css) below. The #title,
# #description and #examples-title rules target the ids used by the header and
# examples Markdown components; the footer rules style the HTML footer.
# NOTE(review): .gr-panel / .gr-button are class names from older Gradio
# releases -- confirm they still exist in the installed version, otherwise
# those two rules silently match nothing.
custom_css = """
body {
font-family: 'Arial', sans-serif;
background-color: #f9f9f9;
}
.gr-panel {
border-radius: 8px;
box-shadow: 0 2px 10px rgba(0,0,0,0.05);
background: white;
padding: 15px;
margin-bottom: 15px;
}
.gr-button {
border-radius: 6px;
padding: 10px 20px;
font-weight: 500;
background: #4a90e2;
color: white;
transition: background 0.3s ease;
}
.gr-button:hover {
background: #357abd;
}
#title {
font-size: 2.2em;
font-weight: 600;
color: #333;
text-align: center;
margin-bottom: 10px;
}
#description {
font-size: 1.1em;
color: #666;
text-align: center;
max-width: 600px;
margin: 0 auto 20px auto;
}
#examples-title {
font-size: 1.3em;
font-weight: 500;
color: #333;
margin-bottom: 10px;
}
footer {
text-align: center;
margin-top: 30px;
padding: 15px;
font-size: 0.9em;
color: #666;
}
footer a {
color: #4a90e2;
text-decoration: none;
}
footer a:hover {
text-decoration: underline;
}
"""
# Gradio Blocks UI (Simplified)
# Layout: header -> input group (text, slider, button) -> output group
# (preprocessed text, predictions, chart) -> examples -> footer.
with gr.Blocks(css=custom_css) as demo:
    # Header
    # NOTE(review): each header Markdown wraps its text in <div id=...> AND sets
    # the same elem_id on the component, so the id occurs twice in the DOM and
    # the matching custom_css rule applies to both elements -- consider keeping
    # only one of the two.
    gr.Markdown("<div id='title'>GoEmotions BERT Classifier</div>", elem_id="title")
    gr.Markdown(
        """
<div id='description'>
Predict emotions from text using a fine-tuned BERT model.
Enter your text below to see the detected emotions and their confidence scores.
</div>
""",
        elem_id="description",
    )

    # Input Section
    with gr.Group():
        text_input = gr.Textbox(
            label="Enter Your Text",
            placeholder="Type something like 'I’m just chilling today'...",
            lines=2,
            show_label=False,
        )
        # Extra minimum confidence; per-label default thresholds still apply
        # (the effective cut-off is the larger of the two).
        confidence_slider = gr.Slider(
            minimum=0.0,
            maximum=0.9,
            value=0.0,
            step=0.05,
            label="Minimum Confidence Threshold",
            info="Filter predictions below this confidence level (default thresholds still apply)",
        )
        submit_btn = gr.Button("Predict Emotions")

    # Output Section
    with gr.Group():
        processed_text_output = gr.Textbox(label="Preprocessed Text", lines=1, interactive=False)
        thresholded_output = gr.Textbox(label="Predicted Emotions", lines=3, interactive=False)
        output_plot = gr.Plot(label="Emotion Confidence Chart")

    # Example carousel
    with gr.Group():
        gr.Markdown("<div id='examples-title'>Try These Examples</div>", elem_id="examples-title")
        # Emoji restored from mis-encoded bytes (e.g. "πŸ˜„" was 😄).
        # NOTE(review): each row carries a second value (a caption) but only one
        # input component is listed; Gradio appears to ignore the extra column
        # -- confirm, or move the captions to `example_labels` if supported.
        examples = gr.Examples(
            examples=[
                ["I’m thrilled to win this award! 😄", "Joy Example"],
                ["This is so frustrating, nothing works. 😣", "Annoyance Example"],
                ["I feel so sorry for what happened. 😢", "Sadness Example"],
                ["What a beautiful day to be alive! 🌞", "Admiration Example"],
                ["Feeling nervous about the exam tomorrow 😓 u/student r/study", "Nervousness Example"],
            ],
            inputs=[text_input],
            label="",
        )

    # Footer
    gr.HTML(
        """
<footer>
Built by logasanjeev |
<a href="https://huggingface.co/logasanjeev/goemotions-bert">Model Card</a> |
<a href="https://www.kaggle.com/code/ravindranlogasanjeev/evaluation-logasanjeev-goemotions-bert/notebook">Kaggle Notebook</a>
</footer>
"""
    )

    # Bind predictions: the three outputs match the tuple returned by
    # predict_emotions_with_details (processed text, predictions, chart).
    submit_btn.click(
        fn=predict_emotions_with_details,
        inputs=[text_input, confidence_slider],
        outputs=[processed_text_output, thresholded_output, output_plot],
    )

# Launch
if __name__ == "__main__":
    demo.launch()