|
import importlib
import os
import shutil
import sys
from importlib import import_module

import gradio as gr
import pandas as pd
import plotly.express as px
import torch
from huggingface_hub import hf_hub_download
|
|
|
|
|
# Fetch the model repository's inference helper and import it as a local
# module named ``inference``.
repo_id = "logasanjeev/goemotions-bert"

local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")
print("Downloaded inference.py successfully!")

# Place a copy in the working directory so it can be imported by name.
current_dir = os.getcwd()
destination = os.path.join(current_dir, "inference.py")
shutil.copy(local_file, destination)
print("Copied inference.py to current directory!")

# The working directory is not necessarily on sys.path (for a script,
# sys.path[0] is the script's own directory), so add it explicitly.
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
# inference.py was created after the interpreter started; the path finders
# cache directory listings, so they must be invalidated for the import to
# reliably see the new file.
importlib.invalidate_caches()

inference_module = import_module("inference")
predict_emotions = inference_module.predict_emotions
print("Imported predict_emotions successfully!")

# Warm-up prediction on a dummy input. NOTE(review): presumably this makes
# inference.py load its model/tokenizer up front; confirm whether
# EMOTION_LABELS / THRESHOLDS below depend on it having run.
_, _ = predict_emotions("dummy text")

# Label names and per-label default thresholds exposed by inference.py;
# predict_emotions_with_details looks thresholds up by the label's index.
emotion_labels = inference_module.EMOTION_LABELS
default_thresholds = inference_module.THRESHOLDS
|
|
|
|
|
def _parse_predictions(predictions_str):
    """Parse the string returned by ``predict_emotions`` into (emotion, confidence) pairs.

    Each non-blank line is expected to look like ``"<emotion>: <confidence>"``;
    the sentinel ``"No emotions predicted."`` yields an empty list.

    Args:
        predictions_str: First element of the tuple returned by ``predict_emotions``.

    Returns:
        A list of ``(emotion, confidence)`` tuples in the order they appear.

    Raises:
        ValueError: If a non-blank line lacks ``": "`` or its confidence is not a float.
    """
    if predictions_str == "No emotions predicted.":
        return []
    predictions = []
    for line in predictions_str.splitlines():
        if not line.strip():
            # A trailing/blank line would otherwise break the unpack below.
            continue
        # Split on the last separator: the numeric confidence never contains ": ".
        emotion, confidence = line.rsplit(": ", 1)
        predictions.append((emotion, float(confidence)))
    return predictions


def _apply_thresholds(predictions, confidence_threshold):
    """Keep predictions reaching max(label default threshold, user threshold).

    Raises:
        ValueError: If an emotion is not present in ``emotion_labels``.
    """
    kept = []
    for emotion, confidence in predictions:
        label_threshold = default_thresholds[emotion_labels.index(emotion)]
        if confidence >= max(label_threshold, confidence_threshold):
            kept.append((emotion, confidence))
    return kept


def _build_confidence_chart(filtered_predictions):
    """Return a Plotly bar chart of (emotion, confidence) pairs, or None if empty."""
    if not filtered_predictions:
        return None
    df = pd.DataFrame(filtered_predictions, columns=["Emotion", "Confidence"])
    fig = px.bar(
        df,
        x="Emotion",
        y="Confidence",
        color="Emotion",
        text="Confidence",
        title="Emotion Confidence Levels",
        height=300,
        color_discrete_sequence=px.colors.qualitative.Pastel,
    )
    # Label each bar with its confidence to two decimals.
    fig.update_traces(texttemplate='%{text:.2f}', textposition='auto')
    # Bars are already coloured and labelled per emotion, so a legend is redundant.
    fig.update_layout(showlegend=False, margin=dict(t=40, b=40), xaxis_title="", yaxis_title="Confidence")
    return fig


def predict_emotions_with_details(text, confidence_threshold=0.0):
    """Predict emotions for *text* and format the results for the Gradio UI.

    Args:
        text: Raw user input. Empty, whitespace-only or None input short-circuits
            with a prompt message.
        confidence_threshold: Extra minimum confidence from the UI slider. A
            prediction is kept only if it reaches the larger of this value and
            its label's default threshold. None is treated as 0.0.

    Returns:
        A 3-tuple ``(processed_text, predictions_text, figure)``: the text as
        preprocessed by ``predict_emotions``; one ``"emotion: 0.1234"`` line per
        kept prediction (or a message when none qualify); and a Plotly bar
        chart, or None when nothing qualifies.

    Raises:
        ValueError: If ``predict_emotions`` returns a malformed line or an
            emotion not present in ``emotion_labels``.
    """
    if not text or not text.strip():
        return "Please enter some text.", "", None
    if confidence_threshold is None:
        confidence_threshold = 0.0

    # All scores come from predict_emotions' formatted output, so no separate
    # tokenization / forward pass is performed here.
    predictions_str, processed_text = predict_emotions(text)
    filtered_predictions = _apply_thresholds(
        _parse_predictions(predictions_str), confidence_threshold
    )

    if filtered_predictions:
        thresholded_output = "\n".join(
            f"{emotion}: {confidence:.4f}" for emotion, confidence in filtered_predictions
        )
    else:
        thresholded_output = "No emotions predicted above thresholds."

    return processed_text, thresholded_output, _build_confidence_chart(filtered_predictions)
|
|
|
|
|
# Stylesheet passed to gr.Blocks(css=...). The #title, #description and
# #examples-title rules target the ids given to the header and examples
# Markdown components; the footer rules style the credits rendered via gr.HTML.
# NOTE(review): .gr-panel / .gr-button look like class names from older Gradio
# releases; confirm they still match elements in the installed Gradio version.
custom_css = """
body {
    font-family: 'Arial', sans-serif;
    background-color: #f9f9f9;
}
.gr-panel {
    border-radius: 8px;
    box-shadow: 0 2px 10px rgba(0,0,0,0.05);
    background: white;
    padding: 15px;
    margin-bottom: 15px;
}
.gr-button {
    border-radius: 6px;
    padding: 10px 20px;
    font-weight: 500;
    background: #4a90e2;
    color: white;
    transition: background 0.3s ease;
}
.gr-button:hover {
    background: #357abd;
}
#title {
    font-size: 2.2em;
    font-weight: 600;
    color: #333;
    text-align: center;
    margin-bottom: 10px;
}
#description {
    font-size: 1.1em;
    color: #666;
    text-align: center;
    max-width: 600px;
    margin: 0 auto 20px auto;
}
#examples-title {
    font-size: 1.3em;
    font-weight: 500;
    color: #333;
    margin-bottom: 10px;
}
footer {
    text-align: center;
    margin-top: 30px;
    padding: 15px;
    font-size: 0.9em;
    color: #666;
}
footer a {
    color: #4a90e2;
    text-decoration: none;
}
footer a:hover {
    text-decoration: underline;
}
"""
|
|
|
|
|
# Gradio UI: text input + confidence slider -> preprocessed text, thresholded
# predictions and a confidence bar chart.
with gr.Blocks(css=custom_css) as demo:

    # --- Header ------------------------------------------------------------
    # NOTE(review): elem_id and the inner <div id=...> carry the same id, so the
    # "#title"/"#description" CSS rules can match two nested elements (em-based
    # font sizes then compound). Consider keeping only one of the two ids.
    gr.Markdown("<div id='title'>GoEmotions BERT Classifier</div>", elem_id="title")
    gr.Markdown(
        """
<div id='description'>
Predict emotions from text using a fine-tuned BERT model.
Enter your text below to see the detected emotions and their confidence scores.
</div>
""",
        elem_id="description",
    )

    # --- Inputs ------------------------------------------------------------
    with gr.Group():
        text_input = gr.Textbox(
            label="Enter Your Text",
            placeholder="Type something like 'I’m just chilling today'...",
            lines=2,
            show_label=False,
        )
        confidence_slider = gr.Slider(
            minimum=0.0,
            maximum=0.9,
            value=0.0,
            step=0.05,
            label="Minimum Confidence Threshold",
            info="Filter predictions below this confidence level (default thresholds still apply)",
        )
        submit_btn = gr.Button("Predict Emotions")

    # --- Outputs -----------------------------------------------------------
    with gr.Group():
        processed_text_output = gr.Textbox(label="Preprocessed Text", lines=1, interactive=False)
        thresholded_output = gr.Textbox(label="Predicted Emotions", lines=3, interactive=False)
        output_plot = gr.Plot(label="Emotion Confidence Chart")

    # --- Examples ----------------------------------------------------------
    with gr.Group():
        gr.Markdown("<div id='examples-title'>Try These Examples</div>", elem_id="examples-title")
        # NOTE(review): these strings were mis-encoded (UTF-8 read as ISO-8859-7).
        # The apostrophes, 😣 and 😢 are exact reconstructions; 🎉, 🌞 and 😟 are
        # best guesses for bytes that did not survive. Confirm against the original.
        # Each row has one value per input component; gr.Examples ignores extra
        # columns, so the example names are kept as comments instead.
        examples = gr.Examples(
            examples=[
                ["I’m thrilled to win this award! 🎉"],  # Joy Example
                ["This is so frustrating, nothing works. 😣"],  # Annoyance Example
                ["I feel so sorry for what happened. 😢"],  # Sadness Example
                ["What a beautiful day to be alive! 🌞"],  # Admiration Example
                ["Feeling nervous about the exam tomorrow 😟 u/student r/study"],  # Nervousness Example
            ],
            inputs=[text_input],
            label="",
        )

    # --- Footer ------------------------------------------------------------
    gr.HTML(
        """
        <footer>
            Built by logasanjeev |
            <a href="https://huggingface.co/logasanjeev/goemotions-bert">Model Card</a> |
            <a href="https://www.kaggle.com/code/ravindranlogasanjeev/evaluation-logasanjeev-goemotions-bert/notebook">Kaggle Notebook</a>
        </footer>
        """
    )

    # Outputs order must match predict_emotions_with_details' return tuple.
    submit_btn.click(
        fn=predict_emotions_with_details,
        inputs=[text_input, confidence_slider],
        outputs=[processed_text_output, thresholded_output, output_plot],
    )


if __name__ == "__main__":
    demo.launch()