# app.py — Hugging Face Space by logasanjeev
# Commit 4df5a3d (verified): "Update app.py" · 11.8 kB
import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import shutil
import os
import torch
from huggingface_hub import hf_hub_download
from importlib import import_module
# Bootstrap: fetch the inference helper from the model repository, import it,
# and expose the pieces the rest of this script uses.
repo_id = "logasanjeev/goemotions-bert"
local_file = hf_hub_download(filename="inference.py", repo_id=repo_id)
print("Downloaded inference.py successfully!")

# The helper is copied into the working directory before being imported by
# name (presumably so that import_module("inference") can resolve it —
# confirm the working directory is on sys.path where this is deployed).
current_dir = os.getcwd()
destination = os.path.join(current_dir, "inference.py")
shutil.copy(src=local_file, dst=destination)
print("Copied inference.py to current directory!")

inference_module = import_module("inference")
predict_emotions = inference_module.predict_emotions
print("Imported predict_emotions successfully!")

# One throwaway prediction runs before the label and threshold tables are
# read. NOTE(review): presumably the module populates those attributes on
# first use — verify against inference.py before reordering these lines.
_, _ = predict_emotions("dummy text")
emotion_labels, default_thresholds = (
    inference_module.EMOTION_LABELS,
    inference_module.THRESHOLDS,
)
# Prediction function with export capability
def _build_confidence_chart(df, chart_type):
    """Return a Plotly figure for ``df`` (columns "Emotion" and "Confidence").

    A bar chart is drawn when ``chart_type`` is "bar"; any other value
    produces a pie chart.
    """
    if chart_type == "bar":
        fig = px.bar(
            df,
            x="Emotion",
            y="Confidence",
            color="Emotion",
            text="Confidence",
            title="Emotion Confidence Levels (Above Threshold)",
            height=400,
            color_discrete_sequence=px.colors.qualitative.Plotly,
        )
        fig.update_traces(texttemplate='%{text:.2f}', textposition='auto')
        fig.update_layout(
            showlegend=False,
            margin=dict(t=40, b=40),
            xaxis_title="",
            yaxis_title="Confidence",
        )
        return fig

    fig = px.pie(
        df,
        names="Emotion",
        values="Confidence",
        title="Emotion Confidence Distribution (Above Threshold)",
        height=400,
        color_discrete_sequence=px.colors.qualitative.Plotly,
    )
    # Pull the first slice out of the pie; the remaining slices stay in place.
    fig.update_traces(textinfo='percent+label', pull=[0.1] + [0] * (len(df) - 1))
    fig.update_layout(margin=dict(t=40, b=40))
    return fig


def _write_predictions_csv(df):
    """Write ``df`` to a new temporary CSV file and return the file's path."""
    import tempfile

    # delete=False: the file must outlive this function so the UI can serve
    # it. Nothing in this function removes it afterwards.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".csv",
        prefix="emotion_predictions_",
        delete=False,
        newline="",
        encoding="utf-8",
    ) as handle:
        df.to_csv(handle, index=False)
        return handle.name


def predict_emotions_with_details(text, confidence_threshold=0.0, chart_type="bar"):
    """Classify ``text`` and build every value the results panel displays.

    Args:
        text: Raw user input. ``None``, empty, or whitespace-only input
            short-circuits with a prompt message.
        confidence_threshold: Additional minimum confidence. For each emotion
            the effective cut-off is the larger of this value and that
            emotion's entry in ``default_thresholds``.
        chart_type: "bar" for a bar chart; any other value yields a pie chart.

    Returns:
        A 6-tuple ``(processed_text, thresholded_output, top_5_output, fig,
        csv_path, loading)``:

        * ``processed_text`` - the preprocessed text returned by
          ``predict_emotions`` (or the prompt message for blank input).
        * ``thresholded_output`` - one "emotion: confidence" line per
          prediction that passed the cut-off, or a "none" message.
        * ``top_5_output`` - the five highest-scoring emotions, unfiltered.
        * ``fig`` - Plotly figure of the filtered predictions, or ``None``.
        * ``csv_path`` - path of a temporary CSV holding the filtered
          predictions, or ``None``. A path (not a DataFrame) is returned
          because this value is bound to a file-download component.
        * ``loading`` - always ``False``: the work is finished on return.
    """
    if not text or not text.strip():
        return "Please enter some text.", "", "", None, None, False

    predictions_str, processed_text = predict_emotions(text)

    # Parse the "emotion: confidence" lines. rsplit with maxsplit=1 keeps the
    # unpacking safe even if a label itself contained ": ".
    predictions = []
    if predictions_str != "No emotions predicted.":
        for line in predictions_str.split("\n"):
            if not line.strip():
                continue
            emotion, confidence = line.rsplit(": ", 1)
            predictions.append((emotion, float(confidence)))

    # predict_emotions only hands back a formatted string, so the model is
    # run again here to obtain a score for every label (needed for the
    # unfiltered top-5 list).
    encodings = inference_module.TOKENIZER(
        processed_text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    input_ids = encodings['input_ids'].to(inference_module.DEVICE)
    attention_mask = encodings['attention_mask'].to(inference_module.DEVICE)
    with torch.no_grad():
        outputs = inference_module.MODEL(input_ids, attention_mask=attention_mask)
    # Sigmoid turns the raw logits into independent per-label probabilities.
    probabilities = torch.sigmoid(outputs.logits).cpu().numpy()[0]

    # All emotions, highest score first, for the top-5 list.
    all_emotions = [
        (emotion_labels[i], round(float(prob), 4))
        for i, prob in enumerate(probabilities)
    ]
    all_emotions.sort(key=lambda pair: pair[1], reverse=True)
    top_5_output = "\n".join(
        f"{emotion}: {confidence:.4f}" for emotion, confidence in all_emotions[:5]
    )

    # Keep a prediction only if it clears both its own default threshold and
    # the user-selected minimum.
    filtered_predictions = []
    for emotion, confidence in predictions:
        thresh = default_thresholds[emotion_labels.index(emotion)]
        if confidence >= max(thresh, confidence_threshold):
            filtered_predictions.append((emotion, confidence))

    if not filtered_predictions:
        return (
            processed_text,
            "No emotions predicted above thresholds.",
            top_5_output,
            None,
            None,
            False,
        )

    thresholded_output = "\n".join(
        f"{emotion}: {confidence:.4f}" for emotion, confidence in filtered_predictions
    )
    df = pd.DataFrame(filtered_predictions, columns=["Emotion", "Confidence"])
    fig = _build_confidence_chart(df, chart_type)
    csv_path = _write_predictions_csv(df)
    return processed_text, thresholded_output, top_5_output, fig, csv_path, False
# Stylesheet handed to gr.Blocks(css=...) when the UI is built.
# Sections, in order: page body, panels, buttons, the #title / #description
# header, the fixed-position #theme-toggle button, `.dark-mode` overrides
# (the class the theme-toggle script flips on <body>), the #loading notice
# (hidden unless it also carries the `visible` class), the examples heading,
# and the footer.
# NOTE(review): the `.gr-panel` / `.gr-button*` selectors only take effect if
# the installed Gradio version emits those class names — verify.
custom_css = """
body {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
background-color: #f5f7fa;
}
.gr-panel {
border-radius: 16px;
box-shadow: 0 6px 20px rgba(0,0,0,0.08);
background: white;
padding: 20px;
margin-bottom: 20px;
}
.gr-button {
border-radius: 8px;
padding: 12px 24px;
font-weight: 600;
transition: all 0.3s ease;
}
.gr-button-primary {
background: #4a90e2;
color: white;
}
.gr-button-primary:hover {
background: #357abd;
}
.gr-button-secondary {
background: #e4e7eb;
color: #333;
}
.gr-button-secondary:hover {
background: #d1d5db;
}
#title {
font-size: 2.8em;
font-weight: 700;
color: #1a3c6e;
text-align: center;
margin-bottom: 10px;
}
#description {
font-size: 1.2em;
color: #555;
text-align: center;
max-width: 800px;
margin: 0 auto 30px auto;
}
#theme-toggle {
position: fixed;
top: 20px;
right: 20px;
background: none;
border: none;
font-size: 1.5em;
cursor: pointer;
transition: transform 0.3s;
}
#theme-toggle:hover {
transform: scale(1.2);
}
.dark-mode {
background: #1e2a44;
color: #e0e0e0;
}
.dark-mode .gr-panel {
background: #2a3a5a;
box-shadow: 0 6px 20px rgba(0,0,0,0.2);
}
.dark-mode #title {
color: #66b3ff;
}
.dark-mode #description {
color: #b0b0b0;
}
.dark-mode .gr-button-secondary {
background: #3a4a6a;
color: #e0e0e0;
}
.dark-mode .gr-button-secondary:hover {
background: #4a5a7a;
}
#loading {
font-style: italic;
color: #888;
text-align: center;
display: none;
}
#loading.visible {
display: block;
}
#examples-title {
font-size: 1.5em;
font-weight: 600;
color: #1a3c6e;
margin-bottom: 10px;
}
.dark-mode #examples-title {
color: #66b3ff;
}
footer {
text-align: center;
margin-top: 40px;
padding: 20px;
font-size: 0.9em;
color: #666;
}
footer a {
color: #4a90e2;
text-decoration: none;
}
footer a:hover {
text-decoration: underline;
}
.dark-mode footer {
color: #b0b0b0;
}
"""
# Browser-side helpers, kept as a string and interpolated into a <script>
# tag inside the gr.HTML element at the top of the UI.
#   toggleTheme()  - toggles the `dark-mode` class on <body> and swaps the
#                    #theme-toggle button glyph between sun and moon.
#   showLoading()  - adds the `visible` class to the #loading element.
#   hideLoading()  - removes the `visible` class from the #loading element.
# showLoading/hideLoading are not called from any code visible in this file.
theme_js = """
function toggleTheme() {
document.body.classList.toggle('dark-mode');
const toggleBtn = document.getElementById('theme-toggle');
toggleBtn.innerHTML = document.body.classList.contains('dark-mode') ? '☀️' : '🌙';
}
function showLoading() {
document.getElementById('loading').classList.add('visible');
}
function hideLoading() {
document.getElementById('loading').classList.remove('visible');
}
"""
# Gradio Blocks UI: header, input panel, results panel, examples, footer,
# followed by the event wiring and the launch guard.
# NOTE(review): the container nesting below is the conventional reading of
# the layout calls; confirm it against the rendered page.
with gr.Blocks(css=custom_css) as demo:
    # Theme toggle button plus the script that defines its click handler.
    # NOTE(review): browsers do not run <script> elements that are inserted
    # through innerHTML; if gr.HTML renders its value that way, toggleTheme
    # will be undefined on click — verify in the deployed app.
    gr.HTML(
        """
<button id='theme-toggle' onclick='toggleTheme()'>🌙</button>
<script>{}</script>
""".format(theme_js)
    )

    # Header
    gr.Markdown("<div id='title'>GoEmotions BERT Classifier</div>", elem_id="title")
    gr.Markdown(
        """
<div id='description'>
Predict emotions from text using a fine-tuned BERT-base model on the GoEmotions dataset.
Detect 28 emotions with optimized thresholds (Micro F1: 0.6006).
View preprocessed text, top 5 emotions, and thresholded predictions with interactive visualizations!
</div>
""",
        elem_id="description"
    )

    # Main content
    with gr.Row():
        with gr.Column(scale=1):
            # Input Section
            with gr.Group():
                gr.Markdown("### Input Text")
                text_input = gr.Textbox(
                    label="Enter Your Text",
                    placeholder="Type something like 'I’m just chilling today'...",
                    lines=3,
                    show_label=False
                )
                confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=0.9,
                    value=0.0,
                    step=0.05,
                    label="Minimum Confidence Threshold",
                    info="Filter predictions below this confidence level (default thresholds still apply)"
                )
                chart_type = gr.Radio(
                    choices=["bar", "pie"],
                    value="bar",
                    label="Chart Type",
                    info="Choose how to visualize the emotion confidences"
                )
                with gr.Row():
                    submit_btn = gr.Button("Predict Emotions", variant="primary")
                    reset_btn = gr.Button("Reset", variant="secondary")
                # Loading indicator (hidden by the stylesheet unless it has
                # the `visible` class).
                loading_indicator = gr.HTML("<div id='loading'>Predicting emotions, please wait...</div>")

    # Output Section
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Results")
                processed_text_output = gr.Textbox(label="Preprocessed Text", lines=2, interactive=False)
                thresholded_output = gr.Textbox(label="Predicted Emotions (Above Threshold)", lines=5, interactive=False)
                top_5_output = gr.Textbox(label="Top 5 Emotions (Regardless of Threshold)", lines=5, interactive=False)
                output_plot = gr.Plot(label="Emotion Confidence Visualization (Above Threshold)")
                # Export predictions. The component is created hidden and no
                # handler in this file changes its visibility.
                export_btn = gr.File(label="Download Predictions as CSV", visible=False)

    # Example carousel
    with gr.Group():
        gr.Markdown("<div id='examples-title'>Example Texts</div>", elem_id="examples-title")
        # Each row carries a text and a descriptive label, but only one input
        # component is bound. NOTE(review): confirm the installed Gradio
        # ignores the surplus second column.
        examples = gr.Examples(
            examples=[
                ["I’m just chilling today.", "Neutral Example"],
                ["Thank you for saving my life!", "Gratitude Example"],
                ["I’m nervous about my exam tomorrow.", "Nervousness Example"],
                ["I love my new puppy so much!", "Love Example"],
                ["I’m so relieved the storm passed.", "Relief Example"]
            ],
            inputs=[text_input],
            label="",
            examples_per_page=3
        )

    # Footer
    gr.HTML(
        """
<footer>
Built with ❤️ by logasanjeev |
<a href="https://huggingface.co/logasanjeev/goemotions-bert">Model Card</a> |
<a href="https://www.kaggle.com/code/ravindranlogasanjeev/evaluation-logasanjeev-goemotions-bert/notebook">Kaggle Notebook</a> |
<a href="https://github.com/logasanjeev">GitHub</a>
</footer>
"""
    )

    # State to manage loading visibility
    loading_state = gr.State(value=False)

    def start_loading():
        """Mark the prediction as in progress."""
        return True

    def stop_loading(processed_text, thresholded_output, top_5_output, fig, df_export, loading_state):
        """Pass the five result values through and clear the loading flag.

        Not attached to any event in this file.
        """
        return processed_text, thresholded_output, top_5_output, fig, df_export, False

    def reset_all():
        """Return one cleared value for each of the seven reset outputs.

        Order: input text, preprocessed text, thresholded predictions,
        top-5 list, plot, export file, loading flag.
        """
        return "", "", "", "", None, None, False

    # Predict: raise the loading flag first, then run the model and fill in
    # every result component (the last output resets the flag).
    submit_btn.click(
        fn=start_loading,
        inputs=[],
        outputs=[loading_state]
    ).then(
        fn=predict_emotions_with_details,
        inputs=[text_input, confidence_slider, chart_type],
        outputs=[processed_text_output, thresholded_output, top_5_output, output_plot, export_btn, loading_state]
    )

    # Reset functionality. The handler must return exactly as many values as
    # there are output components listed here (seven).
    reset_btn.click(
        fn=reset_all,
        inputs=[],
        outputs=[text_input, processed_text_output, thresholded_output, top_5_output, output_plot, export_btn, loading_state]
    )

# Launch only when executed as a script.
if __name__ == "__main__":
    demo.launch()