|
import importlib
import os
import shutil
import sys
import tempfile
from importlib import import_module

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from huggingface_hub import hf_hub_download
|
|
|
|
|
# Fetch the model repository's ``inference.py`` helper and import it as the
# top-level module ``inference``.
repo_id = "logasanjeev/goemotions-bert"

local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")
print("Downloaded inference.py successfully!")

current_dir = os.getcwd()
destination = os.path.join(current_dir, "inference.py")
# shutil.copy raises SameFileError if the downloaded file already *is* the
# destination, so only copy when the two paths differ.
if os.path.abspath(local_file) != os.path.abspath(destination):
    shutil.copy(local_file, destination)
    print("Copied inference.py to current directory!")

# import_module searches sys.path, whose first entry is the *script's*
# directory rather than the current working directory (e.g. when launched as
# ``python path/to/app.py`` from elsewhere). Put the copy's directory first.
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
# The file was created after interpreter start-up; the path finders may have
# cached stale directory listings and would not see it without this call.
importlib.invalidate_caches()
inference_module = import_module("inference")

predict_emotions = inference_module.predict_emotions
print("Imported predict_emotions successfully!")

# Warm-up call; the result is discarded. NOTE(review): presumably this makes
# inference.py load its model/tokenizer so MODEL, TOKENIZER and DEVICE are
# ready before the first real request — confirm against inference.py.
_, _ = predict_emotions("dummy text")

# Label order and per-label default thresholds, indexed in parallel.
emotion_labels = inference_module.EMOTION_LABELS
default_thresholds = inference_module.THRESHOLDS
|
|
|
|
|
def _parse_predictions(predictions_str):
    """Parse the text returned by ``predict_emotions`` into ``(emotion, confidence)`` pairs.

    ``predict_emotions`` reports one ``"<emotion>: <confidence>"`` line per
    predicted emotion, or the sentinel ``"No emotions predicted."``.
    """
    if predictions_str == "No emotions predicted.":
        return []
    parsed = []
    for line in predictions_str.splitlines():
        if not line.strip():
            # Tolerate blank/trailing lines instead of failing to unpack them.
            continue
        # Split on the last separator so the confidence is always the final field.
        emotion, confidence = line.rsplit(": ", 1)
        parsed.append((emotion.strip(), float(confidence)))
    return parsed


def _build_chart(df, chart_type):
    """Return a Plotly figure of the confidences in *df*.

    ``chart_type == "bar"`` gives a bar chart; any other value gives a pie
    chart with the first slice pulled out.
    """
    if chart_type == "bar":
        fig = px.bar(
            df,
            x="Emotion",
            y="Confidence",
            color="Emotion",
            text="Confidence",
            title="Emotion Confidence Levels (Above Threshold)",
            height=400,
            color_discrete_sequence=px.colors.qualitative.Plotly,
        )
        fig.update_traces(texttemplate='%{text:.2f}', textposition='auto')
        fig.update_layout(showlegend=False, margin=dict(t=40, b=40), xaxis_title="", yaxis_title="Confidence")
        return fig

    fig = px.pie(
        df,
        names="Emotion",
        values="Confidence",
        title="Emotion Confidence Distribution (Above Threshold)",
        height=400,
        color_discrete_sequence=px.colors.qualitative.Plotly,
    )
    fig.update_traces(textinfo='percent+label', pull=[0.1] + [0] * (len(df) - 1))
    fig.update_layout(margin=dict(t=40, b=40))
    return fig


def _write_csv(df):
    """Write *df* to a new temporary ``.csv`` file and return its path.

    A ``gr.File`` output takes a file path; handing it the DataFrame itself
    makes Gradio fail while post-processing the event outputs.
    """
    # newline="" as pandas recommends when to_csv writes to a text file handle.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".csv",
        prefix="emotion_predictions_",
        delete=False,
        newline="",
        encoding="utf-8",
    ) as handle:
        df.to_csv(handle, index=False)
    return handle.name


def predict_emotions_with_details(text, confidence_threshold=0.0, chart_type="bar"):
    """Run the GoEmotions model on *text* and build every output shown in the UI.

    Args:
        text: Raw user text.
        confidence_threshold: Extra minimum confidence. A predicted emotion is
            kept only if it reaches both this value and its default threshold.
        chart_type: ``"bar"`` for a bar chart; anything else gives a pie chart.

    Returns:
        A 6-tuple ``(processed_text, thresholded_output, top_5_output, fig,
        csv_path, loading)``:
        - the preprocessed text;
        - the above-threshold predictions, one ``"emotion: conf"`` per line;
        - the five highest-scoring emotions regardless of threshold;
        - a Plotly figure;
        - the path of a CSV with the above-threshold predictions;
        - ``False``, which clears the UI loading flag.
        ``fig`` and ``csv_path`` are ``None`` when nothing passes.
        For blank input, the first element is a prompt message.
    """
    if not text or not text.strip():
        return "Please enter some text.", "", "", None, None, False

    predictions_str, processed_text = predict_emotions(text)
    predictions = _parse_predictions(predictions_str)

    # Re-run the model on the preprocessed text to score *every* label, since
    # predict_emotions only returns text for the emotions it predicts.
    # NOTE(review): this doubles the model work per request.
    encodings = inference_module.TOKENIZER(
        processed_text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    input_ids = encodings['input_ids'].to(inference_module.DEVICE)
    attention_mask = encodings['attention_mask'].to(inference_module.DEVICE)
    with torch.no_grad():
        outputs = inference_module.MODEL(input_ids, attention_mask=attention_mask)
    # Sigmoid (not softmax): each label gets an independent probability.
    probabilities = torch.sigmoid(outputs.logits).cpu().numpy()[0]

    # Stable sort, highest probability first.
    all_emotions = sorted(
        ((label, round(float(prob), 4)) for label, prob in zip(emotion_labels, probabilities)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    top_5_output = "\n".join(f"{emotion}: {confidence:.4f}" for emotion, confidence in all_emotions[:5])

    # Keep a prediction only if it clears both its per-label default threshold
    # and the user's slider value.
    label_thresholds = dict(zip(emotion_labels, default_thresholds))
    filtered_predictions = [
        (emotion, confidence)
        for emotion, confidence in predictions
        if confidence >= max(label_thresholds[emotion], confidence_threshold)
    ]

    if not filtered_predictions:
        return processed_text, "No emotions predicted above thresholds.", top_5_output, None, None, False

    thresholded_output = "\n".join(f"{emotion}: {confidence:.4f}" for emotion, confidence in filtered_predictions)
    df = pd.DataFrame(filtered_predictions, columns=["Emotion", "Confidence"])
    csv_path = _write_csv(df)
    fig = _build_chart(df, chart_type)
    # The prediction is finished, so report loading=False (the original
    # returned True here, leaving the flag stuck "on").
    return processed_text, thresholded_output, top_5_output, fig, csv_path, False
|
|
|
|
|
# Stylesheet passed to gr.Blocks(css=...). The #title, #description,
# #theme-toggle, #loading and #examples-title rules target elements created in
# the layout below. The .dark-mode rules apply once the toggle button adds the
# 'dark-mode' class to <body>. #loading is hidden unless it also has the
# 'visible' class.
custom_css = """

body {

font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;

background-color: #f5f7fa;

}

.gr-panel {

border-radius: 16px;

box-shadow: 0 6px 20px rgba(0,0,0,0.08);

background: white;

padding: 20px;

margin-bottom: 20px;

}

.gr-button {

border-radius: 8px;

padding: 12px 24px;

font-weight: 600;

transition: all 0.3s ease;

}

.gr-button-primary {

background: #4a90e2;

color: white;

}

.gr-button-primary:hover {

background: #357abd;

}

.gr-button-secondary {

background: #e4e7eb;

color: #333;

}

.gr-button-secondary:hover {

background: #d1d5db;

}

#title {

font-size: 2.8em;

font-weight: 700;

color: #1a3c6e;

text-align: center;

margin-bottom: 10px;

}

#description {

font-size: 1.2em;

color: #555;

text-align: center;

max-width: 800px;

margin: 0 auto 30px auto;

}

#theme-toggle {

position: fixed;

top: 20px;

right: 20px;

background: none;

border: none;

font-size: 1.5em;

cursor: pointer;

transition: transform 0.3s;

}

#theme-toggle:hover {

transform: scale(1.2);

}

.dark-mode {

background: #1e2a44;

color: #e0e0e0;

}

.dark-mode .gr-panel {

background: #2a3a5a;

box-shadow: 0 6px 20px rgba(0,0,0,0.2);

}

.dark-mode #title {

color: #66b3ff;

}

.dark-mode #description {

color: #b0b0b0;

}

.dark-mode .gr-button-secondary {

background: #3a4a6a;

color: #e0e0e0;

}

.dark-mode .gr-button-secondary:hover {

background: #4a5a7a;

}

#loading {

font-style: italic;

color: #888;

text-align: center;

display: none;

}

#loading.visible {

display: block;

}

#examples-title {

font-size: 1.5em;

font-weight: 600;

color: #1a3c6e;

margin-bottom: 10px;

}

.dark-mode #examples-title {

color: #66b3ff;

}

footer {

text-align: center;

margin-top: 40px;

padding: 20px;

font-size: 0.9em;

color: #666;

}

footer a {

color: #4a90e2;

text-decoration: none;

}

footer a:hover {

text-decoration: underline;

}

.dark-mode footer {

color: #b0b0b0;

}

"""
|
|
|
|
|
# Client-side helpers:
# - toggleTheme() toggles the 'dark-mode' class on <body> and swaps the
#   #theme-toggle button's icon.
# - showLoading()/hideLoading() add/remove the 'visible' class on #loading.
# NOTE(review): browsers do not execute <script> elements that are inserted as
# HTML markup (e.g. through gr.HTML), so if this string is embedded that way
# these functions may never be defined. Consider gr.Blocks(js=...)/head=...
# instead. No caller of showLoading()/hideLoading() is visible in this file.
theme_js = """

function toggleTheme() {

document.body.classList.toggle('dark-mode');

const toggleBtn = document.getElementById('theme-toggle');

toggleBtn.innerHTML = document.body.classList.contains('dark-mode') ? '☀️' : '🌙';

}

function showLoading() {

document.getElementById('loading').classList.add('visible');

}

function hideLoading() {

document.getElementById('loading').classList.remove('visible');

}

"""
|
|
|
|
|
with gr.Blocks(css=custom_css) as demo:

    def _loading_html(visible):
        """Return the loading-indicator markup.

        The 'visible' class un-hides #loading via the stylesheet.
        """
        css_class = " class='visible'" if visible else ""
        return f"<div id='loading'{css_class}>Predicting emotions, please wait...</div>"

    # Theme toggle. gr.HTML inserts its value as markup, and browsers do not run
    # <script> elements inserted that way. A toggleTheme() defined in such a
    # <script> would never exist, and clicking would throw a ReferenceError.
    # Inline event-handler attributes are honoured, so the logic lives in onclick.
    gr.HTML(
        """
        <button id='theme-toggle'
                onclick="document.body.classList.toggle('dark-mode'); this.innerHTML = document.body.classList.contains('dark-mode') ? '☀️' : '🌙';">🌙</button>
        """
    )

    # The inner <div> already carries the id targeted by custom_css. Also
    # passing it via elem_id would nest two elements with the same id, and the
    # relative (em) font sizes would compound.
    gr.Markdown("<div id='title'>GoEmotions BERT Classifier</div>")
    gr.Markdown(
        """

<div id='description'>

Predict emotions from text using a fine-tuned BERT-base model on the GoEmotions dataset.

Detect 28 emotions with optimized thresholds (Micro F1: 0.6006).

View preprocessed text, top 5 emotions, and thresholded predictions with interactive visualizations!

</div>

"""
    )

    # ---- Input panel -------------------------------------------------------
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Input Text")
                text_input = gr.Textbox(
                    label="Enter Your Text",
                    placeholder="Type something like 'I’m just chilling today'...",
                    lines=3,
                    show_label=False,
                )
                confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=0.9,
                    value=0.0,
                    step=0.05,
                    label="Minimum Confidence Threshold",
                    info="Filter predictions below this confidence level (default thresholds still apply)",
                )
                chart_type = gr.Radio(
                    choices=["bar", "pie"],
                    value="bar",
                    label="Chart Type",
                    info="Choose how to visualize the emotion confidences",
                )
                with gr.Row():
                    submit_btn = gr.Button("Predict Emotions", variant="primary")
                    reset_btn = gr.Button("Reset", variant="secondary")

    # Hidden by default; shown/hidden by the submit event chain below.
    loading_indicator = gr.HTML(_loading_html(False))

    # ---- Results panel -----------------------------------------------------
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Results")
                processed_text_output = gr.Textbox(label="Preprocessed Text", lines=2, interactive=False)
                thresholded_output = gr.Textbox(label="Predicted Emotions (Above Threshold)", lines=5, interactive=False)
                top_5_output = gr.Textbox(label="Top 5 Emotions (Regardless of Threshold)", lines=5, interactive=False)
                output_plot = gr.Plot(label="Emotion Confidence Visualization (Above Threshold)")

            # Left visible. Nothing ever un-hides a visible=False component, so
            # the CSV download would otherwise be unreachable.
            export_btn = gr.File(label="Download Predictions as CSV")

            with gr.Group():
                gr.Markdown("<div id='examples-title'>Example Texts</div>")
                # One value per input component; the old descriptive second
                # column had no matching input.
                examples = gr.Examples(
                    examples=[
                        ["I’m just chilling today."],
                        ["Thank you for saving my life!"],
                        ["I’m nervous about my exam tomorrow."],
                        ["I love my new puppy so much!"],
                        ["I’m so relieved the storm passed."],
                    ],
                    inputs=[text_input],
                    label="",
                    examples_per_page=3,
                )

    gr.HTML(
        """
        <footer>
            Built with ❤️ by logasanjeev |
            <a href="https://huggingface.co/logasanjeev/goemotions-bert">Model Card</a> |
            <a href="https://www.kaggle.com/code/ravindranlogasanjeev/evaluation-logasanjeev-goemotions-bert/notebook">Kaggle Notebook</a> |
            <a href="https://github.com/logasanjeev">GitHub</a>
        </footer>
        """
    )

    # ---- Event wiring ------------------------------------------------------
    loading_state = gr.State(value=False)

    def start_loading():
        """Return the flag value marking a prediction as in progress."""
        return True

    def stop_loading(processed_text, thresholded_output, top_5_output, fig, df_export, loading_state):
        """Pass prediction outputs through unchanged and clear the loading flag.

        Not wired to any event in this layout.
        """
        return processed_text, thresholded_output, top_5_output, fig, df_export, False

    submit_btn.click(
        fn=lambda: (start_loading(), _loading_html(True)),
        inputs=[],
        outputs=[loading_state, loading_indicator],
    ).then(
        fn=predict_emotions_with_details,
        inputs=[text_input, confidence_slider, chart_type],
        outputs=[processed_text_output, thresholded_output, top_5_output, output_plot, export_btn, loading_state],
    ).then(
        # .then() (unlike .success()) runs even if the prediction step raised,
        # so the indicator is always hidden again.
        fn=lambda: (False, _loading_html(False)),
        inputs=[],
        outputs=[loading_state, loading_indicator],
    )

    # Exactly one value per output component (seven); the original returned
    # six, which Gradio rejects.
    reset_btn.click(
        fn=lambda: ("", "", "", "", None, None, False),
        inputs=[],
        outputs=[text_input, processed_text_output, thresholded_output, top_5_output, output_plot, export_btn, loading_state],
    )
|
|
|
|
|
def _main():
    """Start the Gradio server for the app defined above."""
    demo.launch()


# Only launch when executed as a script, not when imported.
if __name__ == "__main__":
    _main()