#!/usr/bin/env python3
"""Improved Gradio app for waste classification using enhanced MAE ViT-Base model.

The classifier is loaded once at import time; ``demo`` is the Gradio Blocks UI,
launched on 0.0.0.0:7863 when this file is run as a script.
"""

import os

import gradio as gr
from PIL import Image

from improved_mae_classifier import ImprovedMAEWasteClassifier

print("šŸš€ Initializing Improved MAE waste classifier...")

try:
    # Load the improved classifier with optimized settings
    classifier = ImprovedMAEWasteClassifier(
        hf_model_id="ysfad/mae-waste-classifier",
        temperature=2.5,  # Reduced overconfidence
        cardboard_penalty=0.8,  # Reduced cardboard bias
    )
    print("āœ… Improved MAE Classifier ready!")
except Exception as e:
    # Report the failure, then re-raise: the app cannot run without a model.
    print(f"āŒ Error loading improved classifier: {e}")
    raise


# Number of UI components classify_waste() fills: prediction, confidence,
# disposal instructions, predictions table, model info. Every return path must
# produce exactly this many values, otherwise Gradio rejects the event output.
_NUM_OUTPUTS = 5

# Header shared by the placeholder table and the rendered predictions table.
_TABLE_HEADER = "| Rank | Class | Confidence |\n|------|-------|------------|\n"


def _message_outputs(message):
    """Return a full output tuple: *message* in the first panel, the rest blank."""
    return (message,) + ("",) * (_NUM_OUTPUTS - 1)


def _format_predictions_table(top_predictions):
    """Render ranked predictions (dicts with 'class' and 'confidence') as Markdown."""
    rows = "".join(
        f"| {rank} | {pred['class']} | {pred['confidence'] * 100:.1f}% |\n"
        for rank, pred in enumerate(top_predictions, 1)
    )
    return _TABLE_HEADER + rows


def _format_model_info(model_info):
    """Render the dict returned by ``classifier.get_model_info()`` as Markdown."""
    return (
        f"**Model:** {model_info['model_name']}\n"
        f"**Architecture:** {model_info['architecture']}\n"
        f"**Classes:** {model_info['num_classes']}\n"
        f"**Device:** {model_info['device']}\n"
        "**Improvements:** Temperature scaling, bias correction, ensemble prediction"
    )


def classify_waste(image):
    """Classify a waste image and build the text shown in each result panel.

    Args:
        image: The image delivered by the ``gr.Image(type="pil")`` input, or
            ``None`` when no image is present (clearing the input also fires
            its ``change`` event with ``None``).

    Returns:
        A 5-tuple ``(prediction_markdown, confidence_text, disposal_instructions,
        predictions_table_markdown, model_info_markdown)`` in the same order as
        the ``outputs`` list wired to the events below. On any failure the
        first element carries the message and the remaining four are empty.
    """
    if image is None:
        return _message_outputs("Please upload an image.")

    try:
        # Classify the image using ensemble prediction for better accuracy
        result = classifier.classify_image(image, top_k=5, use_ensemble=True)

        if not result['success']:
            return _message_outputs(f"Error: {result['error']}")

        predicted_class = result['predicted_class']
        confidence = result['confidence']
        top_predictions = result['top_predictions']

        # "Uncertain" is the label the classifier reports when its best score
        # falls below its confidence threshold; guide the user instead.
        if predicted_class == "Uncertain":
            prediction_text = (
                "šŸ¤” **Uncertain Classification**\n\n"
                f"Confidence too low for reliable prediction ({confidence:.1%})\n\n"
                "šŸ’” **Suggestions:**\n"
                "- Try a clearer photo\n"
                "- Better lighting\n"
                "- Different angle\n"
                "- Remove background clutter"
            )
            confidence_text = f"Highest confidence: {confidence:.1%} (below threshold)"
        else:
            prediction_text = f"šŸŽÆ **{predicted_class}**\n\nConfidence: {confidence:.1%}"
            confidence_text = f"Confidence: {confidence:.1%}"

        instructions = classifier.get_disposal_instructions(predicted_class)
        table_markdown = _format_predictions_table(top_predictions)
        info_text = _format_model_info(classifier.get_model_info())

        return prediction_text, confidence_text, instructions, table_markdown, info_text
    except Exception as e:
        # UI boundary: show the failure in the prediction panel rather than
        # letting the event handler crash.
        return _message_outputs(f"Error processing image: {e}")


# Create Gradio interface with improved design
with gr.Blocks(
    title="šŸ—‚ļø Improved MAE Waste Classifier",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container { max-width: 1200px !important; }
    .header { text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px; }
    .improvement-box { background: #e8f5e8; border: 2px solid #4caf50; border-radius: 8px; padding: 15px; margin: 10px 0; }
    .warning-box { background: #fff3cd; border: 2px solid #ffc107; border-radius: 8px; padding: 15px; margin: 10px 0; }
    """,
) as demo:
    # Header
    gr.HTML(
        "\n\nšŸ—‚ļø Improved MAE Waste Classifier\n\n"
        "Enhanced AI-powered waste classification with bias correction and uncertainty handling\n\n"
        "✨ New Features: Temperature scaling • Cardboard bias reduction • Uncertainty detection • Ensemble predictions\n\n"
    )

    # Improvements notice
    gr.HTML("\n\nšŸŽ‰ Recent Improvements\n\n")

    with gr.Row():
        with gr.Column(scale=1):
            # Image input
            image_input = gr.Image(
                label="šŸ“ø Upload Waste Image",
                type="pil",
                height=400,
            )

            # Classification button
            classify_btn = gr.Button(
                "šŸ” Classify Waste",
                variant="primary",
                size="lg",
            )

            # Quick tips
            gr.HTML("\n\nšŸ“‹ Tips for Better Results:\n\n")

        with gr.Column(scale=2):
            # Results section
            with gr.Group():
                gr.HTML("\n\nšŸŽÆ Classification Results\n\n")
                prediction_output = gr.Markdown(
                    label="Prediction",
                    value="Upload an image to get started!",
                )
                confidence_output = gr.Textbox(
                    label="šŸ“Š Confidence Score",
                    interactive=False,
                )
                instructions_output = gr.Textbox(
                    label="ā™»ļø Disposal Instructions",
                    lines=3,
                    interactive=False,
                )

            # Detailed results section
            with gr.Row():
                with gr.Column():
                    gr.HTML("\n\nšŸ“Š Detailed Predictions\n\n")
                    predictions_table = gr.Markdown(
                        label="Top 5 Predictions",
                        value=_TABLE_HEADER + "| - | Upload image first | - |",
                    )

                with gr.Column():
                    gr.HTML("\n\nšŸ¤– Model Information\n\n")
                    model_info_output = gr.Markdown(
                        label="Model Details",
                        value="Model information will appear here after classification.",
                    )

    # About section
    with gr.Accordion("ā„¹ļø About This Improved Model", open=False):
        gr.HTML(
            "\n\n🧠 Model Architecture\n\n"
            "This classifier uses a Vision Transformer (ViT-Base) pre-trained with Masked Autoencoder (MAE) and fine-tuned on the RealWaste dataset.\n\n"
            "✨ Key Improvements\n\n"
            "šŸ“Š Performance Metrics\n\n"
            "šŸ—‚ļø Waste Categories\n\n"
            "Cardboard, Food Organics, Glass, Metal, Miscellaneous Trash, Paper, Plastic, Textile Trash, Vegetation\n\n"
        )

    # Order must match the tuple returned by classify_waste().
    result_outputs = [
        prediction_output,
        confidence_output,
        instructions_output,
        predictions_table,
        model_info_output,
    ]

    # Event handlers: explicit button click, plus auto-classify on image upload.
    classify_btn.click(fn=classify_waste, inputs=[image_input], outputs=result_outputs)
    image_input.change(fn=classify_waste, inputs=[image_input], outputs=result_outputs)


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7863,
        share=False,
    )