Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
"""Improved Gradio app for waste classification using enhanced MAE ViT-Base model.""" | |
import os | |
import gradio as gr | |
from PIL import Image | |
from improved_mae_classifier import ImprovedMAEWasteClassifier | |
# Build the shared classifier once at import time; classify_waste reuses this
# module-level instance for every request.
# NOTE(review): the first character of each message below ('π', 'β') looks
# like a mis-encoded emoji -- confirm and restore the intended glyph.
print("π Initializing Improved MAE waste classifier...")
try:
    classifier = ImprovedMAEWasteClassifier(
        hf_model_id="ysfad/mae-waste-classifier",
        # Temperature scaling, intended to reduce overconfident predictions.
        temperature=2.5,
        # Penalty factor intended to reduce the model's bias toward cardboard.
        cardboard_penalty=0.8,
    )
    print("β Improved MAE Classifier ready!")
except Exception as load_error:
    # Report the failure, then re-raise so the app never starts without a model.
    print(f"β Error loading improved classifier: {load_error}")
    raise
# Number of Gradio output components classify_waste feeds: prediction,
# confidence, disposal instructions, predictions table, model info.
_NUM_OUTPUTS = 5


def _blank_outputs(message):
    """Return *message* in the first output slot and empty strings elsewhere.

    Every early-exit path of ``classify_waste`` goes through this helper so the
    number of returned values always matches the output components wired to it
    (Gradio raises if a handler returns too few values).
    """
    return (message,) + ("",) * (_NUM_OUTPUTS - 1)


def _format_predictions_table(top_predictions):
    """Render ranked predictions as a Markdown table, confidence in percent.

    Each item is read as ``pred['class']`` and ``pred['confidence']`` (a 0-1
    fraction, multiplied by 100 for display). Every line, including the last,
    ends with a newline.
    """
    rows = [
        "| Rank | Class | Confidence |",
        "|------|-------|------------|",
    ]
    rows.extend(
        f"| {rank} | {pred['class']} | {pred['confidence'] * 100:.1f}% |"
        for rank, pred in enumerate(top_predictions, 1)
    )
    return "\n".join(rows) + "\n"


def _format_model_info(model_info):
    """Render the classifier's model metadata dict as Markdown lines."""
    # Built line by line rather than with an indented triple-quoted f-string,
    # so source indentation can never leak into the rendered Markdown.
    return "\n".join(
        [
            f"**Model:** {model_info['model_name']}",
            f"**Architecture:** {model_info['architecture']}",
            f"**Classes:** {model_info['num_classes']}",
            f"**Device:** {model_info['device']}",
            "**Improvements:** Temperature scaling, bias correction, ensemble prediction",
        ]
    )


def classify_waste(image):
    """Classify a waste image and build the values for the five UI outputs.

    Args:
        image: The image from the Gradio ``Image`` component (configured with
            ``type="pil"``), or ``None`` when no image is present, e.g. after
            the user clears it.

    Returns:
        A 5-tuple ``(prediction_markdown, confidence_text, disposal_instructions,
        predictions_table_markdown, model_info_markdown)``. When there is no
        image, the classifier reports failure, or any exception occurs, the
        first element carries a message and the remaining four are ``""``.
    """
    if image is None:
        return _blank_outputs("Please upload an image.")

    try:
        # Ensemble prediction over the top-5 classes for better stability.
        result = classifier.classify_image(image, top_k=5, use_ensemble=True)
        if not result['success']:
            return _blank_outputs(f"Error: {result.get('error', 'Unknown error')}")

        predicted_class = result['predicted_class']
        confidence = result['confidence']
        top_predictions = result['top_predictions']

        # NOTE(review): the leading glyphs 'π€', 'π‘' and 'π―' below look like
        # mis-encoded emoji -- confirm and restore the intended characters.
        if predicted_class == "Uncertain":
            prediction_text = (
                "π€ **Uncertain Classification**\n\n"
                f"Confidence too low for reliable prediction ({confidence:.1%})\n\n"
                "π‘ **Suggestions:**\n"
                "- Try a clearer photo\n"
                "- Better lighting\n"
                "- Different angle\n"
                "- Remove background clutter"
            )
            confidence_text = f"Highest confidence: {confidence:.1%} (below threshold)"
        else:
            prediction_text = f"π― **{predicted_class}**\n\nConfidence: {confidence:.1%}"
            confidence_text = f"Confidence: {confidence:.1%}"

        instructions = classifier.get_disposal_instructions(predicted_class)
        table_md = _format_predictions_table(top_predictions)
        info_md = _format_model_info(classifier.get_model_info())
        return prediction_text, confidence_text, instructions, table_md, info_md
    except Exception as e:
        # UI boundary: surface any failure as a message instead of crashing the
        # event handler.
        return _blank_outputs(f"Error processing image: {e}")
# Create Gradio interface with improved design.
# NOTE(review): many titles, labels and headings below begin with characters
# such as 'π', 'β' or 'ποΈ' that look like mis-encoded emoji -- confirm and
# restore the intended glyphs.
with gr.Blocks(
    title="ποΈ Improved MAE Waste Classifier",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .header {
        text-align: center;
        padding: 20px;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        margin-bottom: 20px;
    }
    .improvement-box {
        background: #e8f5e8;
        border: 2px solid #4caf50;
        border-radius: 8px;
        padding: 15px;
        margin: 10px 0;
    }
    .warning-box {
        background: #fff3cd;
        border: 2px solid #ffc107;
        border-radius: 8px;
        padding: 15px;
        margin: 10px 0;
    }
    """,
) as demo:
    # Header banner
    gr.HTML("""
    <div class="header">
        <h1>ποΈ Improved MAE Waste Classifier</h1>
        <p>Enhanced AI-powered waste classification with bias correction and uncertainty handling</p>
        <p><strong>β¨ New Features:</strong> Temperature scaling β’ Cardboard bias reduction β’ Uncertainty detection β’ Ensemble predictions</p>
    </div>
    """)

    # Improvements notice
    gr.HTML("""
    <div class="improvement-box">
        <h3>π Recent Improvements</h3>
        <ul>
            <li><strong>β Reduced Cardboard Bias:</strong> From 83% to 17% false cardboard predictions</li>
            <li><strong>β Better Confidence:</strong> 39% reduction in overconfident predictions</li>
            <li><strong>β Uncertainty Handling:</strong> Shows "Uncertain" for low-confidence predictions</li>
            <li><strong>β Ensemble Predictions:</strong> Uses multiple augmentations for stability</li>
        </ul>
    </div>
    """)

    with gr.Row():
        # Left column: input image, trigger button and photo tips.
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="πΈ Upload Waste Image",
                type="pil",
                height=400,
            )
            classify_btn = gr.Button(
                "π Classify Waste",
                variant="primary",
                size="lg",
            )
            gr.HTML("""
            <div class="warning-box">
                <h4>π Tips for Better Results:</h4>
                <ul>
                    <li>Use clear, well-lit photos</li>
                    <li>Center the item in frame</li>
                    <li>Avoid cluttered backgrounds</li>
                    <li>Try different angles if uncertain</li>
                </ul>
            </div>
            """)

        # Right column: headline result, confidence and disposal advice.
        with gr.Column(scale=2):
            with gr.Group():
                gr.HTML("<h3>π― Classification Results</h3>")
                prediction_output = gr.Markdown(
                    label="Prediction",
                    value="Upload an image to get started!",
                )
                confidence_output = gr.Textbox(
                    label="π Confidence Score",
                    interactive=False,
                )
                instructions_output = gr.Textbox(
                    label="β»οΈ Disposal Instructions",
                    lines=3,
                    interactive=False,
                )

    # Detailed results section
    with gr.Row():
        with gr.Column():
            gr.HTML("<h3>π Detailed Predictions</h3>")
            predictions_table = gr.Markdown(
                label="Top 5 Predictions",
                value="| Rank | Class | Confidence |\n|------|-------|------------|\n| - | Upload image first | - |",
            )
        with gr.Column():
            gr.HTML("<h3>π€ Model Information</h3>")
            model_info_output = gr.Markdown(
                label="Model Details",
                value="Model information will appear here after classification.",
            )

    # About section
    with gr.Accordion("βΉοΈ About This Improved Model", open=False):
        gr.HTML("""
        <div style="padding: 20px;">
            <h4>π§ Model Architecture</h4>
            <p>This classifier uses a <strong>Vision Transformer (ViT-Base)</strong> pre-trained with <strong>Masked Autoencoder (MAE)</strong> and fine-tuned on the RealWaste dataset.</p>
            <h4>β¨ Key Improvements</h4>
            <ul>
                <li><strong>Temperature Scaling (T=2.5):</strong> Reduces overconfident predictions</li>
                <li><strong>Cardboard Bias Correction:</strong> Applies 0.8x penalty to cardboard predictions</li>
                <li><strong>Class-specific Thresholds:</strong> Higher threshold (0.8) for cardboard, lower (0.4) for textile</li>
                <li><strong>Ensemble Prediction:</strong> Averages 5 augmented predictions for stability</li>
                <li><strong>Uncertainty Detection:</strong> Shows "Uncertain" when confidence is too low</li>
            </ul>
            <h4>π Performance Metrics</h4>
            <ul>
                <li><strong>Original Validation Accuracy:</strong> 93.27%</li>
                <li><strong>Cardboard Bias Reduction:</strong> 66.6% improvement</li>
                <li><strong>Confidence Calibration:</strong> 38.7% reduction in overconfidence</li>
                <li><strong>Classes:</strong> 9 waste categories</li>
            </ul>
            <h4>ποΈ Waste Categories</h4>
            <p><strong>Cardboard, Food Organics, Glass, Metal, Miscellaneous Trash, Paper, Plastic, Textile Trash, Vegetation</strong></p>
        </div>
        """)

    # Event handlers.
    # Both triggers fill the same five components. The order must match the
    # 5-tuple returned by classify_waste (prediction, confidence, instructions,
    # table, model info). Defining the list once keeps the two wirings in sync.
    result_outputs = [
        prediction_output,
        confidence_output,
        instructions_output,
        predictions_table,
        model_info_output,
    ]

    classify_btn.click(
        fn=classify_waste,
        inputs=[image_input],
        outputs=result_outputs,
    )

    # Auto-classify whenever the image value changes. Gradio's change event
    # also fires when the image is cleared, so classify_waste may receive None.
    image_input.change(
        fn=classify_waste,
        inputs=[image_input],
        outputs=result_outputs,
    )
def resolve_server_port(default=7863):
    """Return the TCP port the app should listen on.

    Gradio documents ``GRADIO_SERVER_PORT`` as the environment variable for
    choosing the port, but an explicit ``server_port`` argument to
    ``launch()`` takes precedence over it. This helper reads the variable
    itself so a hosting environment can still override the port. It falls
    back to *default* (this app's previous hard-coded 7863) when the variable
    is unset or blank.

    Raises:
        ValueError: if ``GRADIO_SERVER_PORT`` is set to a non-integer value.
    """
    raw = os.environ.get("GRADIO_SERVER_PORT", "").strip()
    return int(raw) if raw else default


if __name__ == "__main__":
    demo.launch(
        # Same defaults as before, overridable via Gradio's env variables.
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=resolve_server_port(),
        share=False,
    )