"""Gradio demo for the ``logasanjeev/goemotions-bert`` GoEmotions classifier.

Downloads the model repo's ``inference.py``, imports it, and serves a UI that
shows the preprocessed text, the top-5 emotion scores, the thresholded
predictions, a bar/pie chart, and a CSV export of the thresholded predictions.
"""
import os
import shutil
import sys
import tempfile
from importlib import import_module, invalidate_caches

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from huggingface_hub import hf_hub_download

# Load inference.py and model
repo_id = "logasanjeev/goemotions-bert"
local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")
print("Downloaded inference.py successfully!")

current_dir = os.getcwd()
destination = os.path.join(current_dir, "inference.py")
shutil.copy(local_file, destination)
print("Copied inference.py to current directory!")

# import_module() searches sys.path, which holds the *script's* directory, not
# necessarily the cwd; make sure the copied file is importable.
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
# The file was created after the interpreter started, so the import system's
# directory caches may not know about it yet.
invalidate_caches()

inference_module = import_module("inference")
predict_emotions = inference_module.predict_emotions
print("Imported predict_emotions successfully!")

# Warm-up call so the first real request does not pay model start-up cost.
_, _ = predict_emotions("dummy text")

emotion_labels = inference_module.EMOTION_LABELS
default_thresholds = inference_module.THRESHOLDS
# Label -> position, used to look up each label's default threshold.
# NOTE(review): assumes THRESHOLDS is indexed in the same order as EMOTION_LABELS.
_LABEL_TO_INDEX = {label: idx for idx, label in enumerate(emotion_labels)}


def _parse_predictions(predictions_str):
    """Parse the ``"emotion: confidence"`` lines returned by ``predict_emotions``.

    Returns a list of ``(emotion, confidence)`` tuples; empty when the model
    reported ``"No emotions predicted."``.
    """
    if predictions_str.strip() == "No emotions predicted.":
        return []
    parsed = []
    for line in predictions_str.splitlines():
        line = line.strip()
        if not line:
            continue
        # rsplit: only the final ": " separates the label from the number.
        emotion, confidence = line.rsplit(": ", 1)
        parsed.append((emotion, float(confidence)))
    return parsed


def _score_all_emotions(processed_text):
    """Return ``(label, sigmoid score rounded to 4 dp)`` for every label, highest first."""
    encodings = inference_module.TOKENIZER(
        processed_text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    input_ids = encodings["input_ids"].to(inference_module.DEVICE)
    attention_mask = encodings["attention_mask"].to(inference_module.DEVICE)
    with torch.no_grad():
        outputs = inference_module.MODEL(input_ids, attention_mask=attention_mask)
    # Multi-label model: each emotion gets an independent sigmoid probability.
    probabilities = torch.sigmoid(outputs.logits).cpu().numpy()[0]
    scored = [
        (label, round(float(prob), 4))
        for label, prob in zip(emotion_labels, probabilities)
    ]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored


def _filter_by_threshold(predictions, confidence_threshold):
    """Keep predictions at or above max(per-label default threshold, user minimum)."""
    kept = []
    for emotion, confidence in predictions:
        per_label = default_thresholds[_LABEL_TO_INDEX[emotion]]
        if confidence >= max(per_label, confidence_threshold):
            kept.append((emotion, confidence))
    return kept


def _build_figure(df, chart_type):
    """Build a Plotly bar chart (``chart_type == "bar"``) or pie chart from *df*."""
    if chart_type == "bar":
        fig = px.bar(
            df,
            x="Emotion",
            y="Confidence",
            color="Emotion",
            text="Confidence",
            title="Emotion Confidence Levels (Above Threshold)",
            height=400,
            color_discrete_sequence=px.colors.qualitative.Plotly,
        )
        fig.update_traces(texttemplate='%{text:.2f}', textposition='auto')
        fig.update_layout(
            showlegend=False,
            margin=dict(t=40, b=40),
            xaxis_title="",
            yaxis_title="Confidence",
        )
        return fig

    # Any other value falls back to a pie chart.
    fig = px.pie(
        df,
        names="Emotion",
        values="Confidence",
        title="Emotion Confidence Distribution (Above Threshold)",
        height=400,
        color_discrete_sequence=px.colors.qualitative.Plotly,
    )
    # Pull the first slice out slightly; df is guaranteed non-empty by the caller.
    fig.update_traces(textinfo='percent+label', pull=[0.1] + [0] * (len(df) - 1))
    fig.update_layout(margin=dict(t=40, b=40))
    return fig


# Prediction function with export capability
def predict_emotions_with_details(text, confidence_threshold=0.0, chart_type="bar"):
    """Run the classifier on *text* and prepare everything the UI displays.

    Args:
        text: Raw user input. ``None`` or whitespace-only text is rejected.
        confidence_threshold: Extra minimum confidence; a prediction must clear
            both this and its label's default threshold. ``None`` means 0.0.
        chart_type: ``"bar"`` for a bar chart, anything else for a pie chart.

    Returns:
        ``(processed_text, thresholded_output, top_5_output, fig, df_export)``.
        ``fig`` and ``df_export`` (a DataFrame with ``Emotion``/``Confidence``
        columns) are ``None`` when no prediction clears the thresholds.
    """
    if not text or not text.strip():
        return "Please enter some text.", "", "", None, None
    confidence_threshold = 0.0 if confidence_threshold is None else float(confidence_threshold)

    predictions_str, processed_text = predict_emotions(text)
    predictions = _parse_predictions(predictions_str)

    # Top 5 over all labels, ignoring every threshold.
    top_5_emotions = _score_all_emotions(processed_text)[:5]
    top_5_output = "\n".join(
        f"{emotion}: {confidence:.4f}" for emotion, confidence in top_5_emotions
    )

    filtered_predictions = _filter_by_threshold(predictions, confidence_threshold)
    if not filtered_predictions:
        return processed_text, "No emotions predicted above thresholds.", top_5_output, None, None

    thresholded_output = "\n".join(
        f"{emotion}: {confidence:.4f}" for emotion, confidence in filtered_predictions
    )
    df = pd.DataFrame(filtered_predictions, columns=["Emotion", "Confidence"])
    return processed_text, thresholded_output, top_5_output, _build_figure(df, chart_type), df.copy()


def _predict_for_ui(text, confidence_threshold=0.0, chart_type="bar"):
    """UI adapter: like ``predict_emotions_with_details`` but turns the DataFrame
    into a downloadable CSV for the ``gr.File`` output (which needs a file path).
    """
    processed_text, thresholded, top_5, fig, df_export = predict_emotions_with_details(
        text, confidence_threshold, chart_type
    )
    if df_export is None:
        return processed_text, thresholded, top_5, fig, gr.update(value=None, visible=False)
    # delete=False: Gradio reads the file after this function returns.
    # newline="" avoids doubled line endings when pandas writes to a handle.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", prefix="emotion_predictions_",
        delete=False, encoding="utf-8", newline="",
    ) as handle:
        df_export.to_csv(handle, index=False)
    return processed_text, thresholded, top_5, fig, gr.update(value=handle.name, visible=True)


# Custom CSS for enhanced styling
custom_css = """
body { font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif; background-color: #f5f7fa; }
.gr-panel { border-radius: 16px; box-shadow: 0 6px 20px rgba(0,0,0,0.08); background: white; padding: 20px; margin-bottom: 20px; }
.gr-button { border-radius: 8px; padding: 12px 24px; font-weight: 600; transition: all 0.3s ease; }
.gr-button-primary { background: #4a90e2; color: white; }
.gr-button-primary:hover { background: #357abd; }
.gr-button-secondary { background: #e4e7eb; color: #333; }
.gr-button-secondary:hover { background: #d1d5db; }
#title { font-size: 2.8em; font-weight: 700; color: #1a3c6e; text-align: center; margin-bottom: 10px; }
#description { font-size: 1.2em; color: #555; text-align: center; max-width: 800px; margin: 0 auto 30px auto; }
#theme-toggle { position: fixed; top: 20px; right: 20px; background: none; border: none; font-size: 1.5em; cursor: pointer; transition: transform 0.3s; }
#theme-toggle:hover { transform: scale(1.2); }
.dark-mode { background: #1e2a44; color: #e0e0e0; }
.dark-mode .gr-panel { background: #2a3a5a; box-shadow: 0 6px 20px rgba(0,0,0,0.2); }
.dark-mode #title { color: #66b3ff; }
.dark-mode #description { color: #b0b0b0; }
.dark-mode .gr-button-secondary { background: #3a4a6a; color: #e0e0e0; }
.dark-mode .gr-button-secondary:hover { background: #4a5a7a; }
#loading { font-style: italic; color: #888; text-align: center; }
#examples-title { font-size: 1.5em; font-weight: 600; color: #1a3c6e; margin-bottom: 10px; }
.dark-mode #examples-title { color: #66b3ff; }
footer { text-align: center; margin-top: 40px; padding: 20px; font-size: 0.9em; color: #666; }
footer a { color: #4a90e2; text-decoration: none; }
footer a:hover { text-decoration: underline; }
.dark-mode footer { color: #b0b0b0; }
"""

# JavaScript for theme toggle and loading spinner.
# Run once on page load via demo.load(_js=...): <script> tags inside gr.HTML are
# inserted with innerHTML and never execute, so the helpers are attached to
# `window` here to make them callable from onclick handlers and event `_js`.
theme_js = """
() => {
    window.toggleTheme = () => {
        document.body.classList.toggle('dark-mode');
        const toggleBtn = document.getElementById('theme-toggle');
        if (toggleBtn) {
            toggleBtn.innerHTML = document.body.classList.contains('dark-mode') ? '☀️' : '🌙';
        }
    };
    window.showLoading = () => {
        const el = document.getElementById('loading');
        if (el) { el.style.display = 'block'; }
    };
    window.hideLoading = () => {
        const el = document.getElementById('loading');
        if (el) { el.style.display = 'none'; }
    };
}
"""

# (example text, descriptive label). Only the text is fed to gr.Examples,
# which has a single input component.
_EXAMPLE_TEXTS = [
    ("I’m just chilling today.", "Neutral Example"),
    ("Thank you for saving my life!", "Gratitude Example"),
    ("I’m nervous about my exam tomorrow.", "Nervousness Example"),
    ("I love my new puppy so much!", "Love Example"),
    ("I’m so relieved the storm passed.", "Relief Example"),
]

# Gradio Blocks UI
# NOTE(review): the original inline HTML markup was not recoverable; the theme
# button, loading indicator and footer below are minimal reconstructions based
# on the ids/selectors used by custom_css and theme_js — confirm against intent.
with gr.Blocks(css=custom_css) as demo:
    # Theme toggle button
    gr.HTML('<button id="theme-toggle" onclick="toggleTheme()">🌙</button>')

    # Header
    gr.Markdown("GoEmotions BERT Classifier", elem_id="title")
    gr.Markdown(
        """
Predict emotions from text using a fine-tuned BERT-base model on the GoEmotions dataset. Detect 28 emotions with optimized thresholds (Micro F1: 0.6006). View preprocessed text, top 5 emotions, and thresholded predictions with interactive visualizations!
""",
        elem_id="description",
    )

    # Main content
    with gr.Row():
        with gr.Column(scale=1):
            # Input Section
            with gr.Group():
                gr.Markdown("### Input Text")
                text_input = gr.Textbox(
                    label="Enter Your Text",
                    placeholder="Type something like 'I’m just chilling today'...",
                    lines=3,
                    show_label=False,
                )
                confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=0.9,
                    value=0.0,
                    step=0.05,
                    label="Minimum Confidence Threshold",
                    info="Filter predictions below this confidence level (default thresholds still apply)",
                )
                chart_type = gr.Radio(
                    choices=["bar", "pie"],
                    value="bar",
                    label="Chart Type",
                    info="Choose how to visualize the emotion confidences",
                )
                with gr.Row():
                    submit_btn = gr.Button("Predict Emotions", variant="primary")
                    reset_btn = gr.Button("Reset", variant="secondary")

    # Loading indicator (hidden until showLoading() runs)
    gr.HTML("<div id='loading' style='display:none'>Loading...</div>")

    # Output Section
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Results")
                processed_text_output = gr.Textbox(label="Preprocessed Text", lines=2, interactive=False)
                thresholded_output = gr.Textbox(label="Predicted Emotions (Above Threshold)", lines=5, interactive=False)
                top_5_output = gr.Textbox(label="Top 5 Emotions (Regardless of Threshold)", lines=5, interactive=False)
                output_plot = gr.Plot(label="Emotion Confidence Visualization (Above Threshold)")

    # Export predictions (revealed once a CSV is available)
    export_btn = gr.File(label="Download Predictions as CSV", visible=False)

    # Example carousel
    with gr.Group():
        gr.Markdown("Example Texts", elem_id="examples-title")
        examples = gr.Examples(
            examples=[[example_text] for example_text, _label in _EXAMPLE_TEXTS],
            inputs=[text_input],
            label="",
            examples_per_page=3,
        )

    # Footer
    gr.HTML(
        f"<footer>Model: <a href='https://huggingface.co/{repo_id}' target='_blank'>{repo_id}</a></footer>"
    )

    # Register the JS helpers once the page has loaded.
    demo.load(fn=None, inputs=None, outputs=None, _js=theme_js)

    # Bind predictions with loading spinner. The `_js` preprocessor must be a
    # function expression that returns the (unchanged) input values.
    submit_btn.click(
        fn=_predict_for_ui,
        inputs=[text_input, confidence_slider, chart_type],
        outputs=[processed_text_output, thresholded_output, top_5_output, output_plot, export_btn],
        _js="(text, threshold, chart) => { if (window.showLoading) { showLoading(); } return [text, threshold, chart]; }",
    ).then(
        fn=None,
        inputs=None,
        outputs=None,
        _js="() => { if (window.hideLoading) { hideLoading(); } }",
    )

    # Reset functionality: one value per output component (6 outputs).
    reset_btn.click(
        fn=lambda: ("", "", "", "", None, gr.update(value=None, visible=False)),
        inputs=[],
        outputs=[text_input, processed_text_output, thresholded_output, top_5_output, output_plot, export_btn],
    )

# Launch
if __name__ == "__main__":
    demo.launch()