File size: 9,674 Bytes
e15cf70
de63d9f
e15cf70
 
 
 
de63d9f
e15cf70
de63d9f
e15cf70
de63d9f
 
 
 
 
 
 
e15cf70
de63d9f
e15cf70
 
 
de63d9f
e15cf70
 
 
 
de63d9f
 
e15cf70
 
 
 
de63d9f
 
 
e15cf70
de63d9f
 
 
 
 
 
 
e15cf70
 
de63d9f
e15cf70
de63d9f
 
 
 
 
e15cf70
de63d9f
 
 
 
 
 
 
e15cf70
de63d9f
e15cf70
 
de63d9f
e15cf70
de63d9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e15cf70
de63d9f
 
 
 
 
 
 
 
e15cf70
de63d9f
 
 
 
 
 
 
 
 
 
 
e15cf70
 
 
 
de63d9f
e15cf70
de63d9f
e15cf70
de63d9f
e15cf70
 
de63d9f
e15cf70
 
 
 
 
 
de63d9f
 
 
 
 
 
 
 
 
 
 
 
e15cf70
de63d9f
e15cf70
de63d9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e15cf70
 
de63d9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e15cf70
 
 
 
de63d9f
 
 
 
 
 
 
 
e15cf70
 
de63d9f
e15cf70
 
de63d9f
 
 
 
 
 
 
 
e15cf70
 
 
de63d9f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
#!/usr/bin/env python3
"""Improved Gradio app for waste classification using enhanced MAE ViT-Base model."""

import os
import gradio as gr
from PIL import Image
from improved_mae_classifier import ImprovedMAEWasteClassifier

print("πŸš€ Initializing Improved MAE waste classifier...")
try:
    # One shared classifier instance, built at import time and used by the
    # Gradio callback(s) in this module.
    #   temperature=2.5       -> temperature scaling to curb overconfident scores
    #   cardboard_penalty=0.8 -> down-weights the over-predicted cardboard class
    classifier = ImprovedMAEWasteClassifier(
        cardboard_penalty=0.8,
        temperature=2.5,
        hf_model_id="ysfad/mae-waste-classifier",
    )
    print("βœ… Improved MAE Classifier ready!")
except Exception as load_err:
    # Report the failure, then re-raise so the app never starts without a model.
    print(f"❌ Error loading improved classifier: {load_err}")
    raise

def _blank_outputs(message):
    """Build a complete set of UI outputs that carries only a status message.

    ``classify_waste`` feeds five Gradio components (prediction, confidence,
    disposal instructions, predictions table, model info), so every return
    path must yield exactly five values. The trailing four are left empty.

    Args:
        message: Text shown in the prediction component.

    Returns:
        The 5-tuple ``(message, "", "", "", "")``.
    """
    return message, "", "", "", ""


def _format_prediction(predicted_class, confidence):
    """Return ``(prediction_markdown, confidence_text)`` for the top class.

    A ``predicted_class`` of ``"Uncertain"`` (confidence below the
    classifier's threshold, per the UI text) gets retake suggestions instead
    of a label. ``confidence`` is a fraction rendered as a percentage.
    """
    if predicted_class == "Uncertain":
        prediction_text = f"πŸ€” **Uncertain Classification**\n\nConfidence too low for reliable prediction ({confidence:.1%})\n\nπŸ’‘ **Suggestions:**\n- Try a clearer photo\n- Better lighting\n- Different angle\n- Remove background clutter"
        confidence_text = f"Highest confidence: {confidence:.1%} (below threshold)"
    else:
        prediction_text = f"🎯 **{predicted_class}**\n\nConfidence: {confidence:.1%}"
        confidence_text = f"Confidence: {confidence:.1%}"
    return prediction_text, confidence_text


def _format_predictions_table(top_predictions):
    """Render ranked predictions as a Markdown table, ranks starting at 1.

    Each item is read as a mapping with ``'class'`` and ``'confidence'`` keys.
    Confidence is a fraction, shown here as a percentage with one decimal.
    """
    lines = [
        "| Rank | Class | Confidence |",
        "|------|-------|------------|",
    ]
    lines.extend(
        f"| {rank} | {pred['class']} | {pred['confidence'] * 100:.1f}% |"
        for rank, pred in enumerate(top_predictions, 1)
    )
    # Every row, including the last, is newline-terminated.
    return "\n".join(lines) + "\n"


def _format_model_info(model_info):
    """Render the dict returned by ``classifier.get_model_info()`` as Markdown."""
    return f"""**Model:** {model_info['model_name']}

**Architecture:** {model_info['architecture']}

**Classes:** {model_info['num_classes']}

**Device:** {model_info['device']}

**Improvements:** Temperature scaling, bias correction, ensemble prediction"""


def classify_waste(image):
    """Classify a waste image and build the five UI output values.

    Args:
        image: PIL image from the upload component, or ``None`` when the
            component is empty (e.g. after the user clears it).

    Returns:
        A 5-tuple ``(prediction_markdown, confidence_text,
        disposal_instructions, predictions_table_markdown,
        model_info_markdown)``. On missing input or any failure, the first
        element carries a message and the other four are empty strings, so
        the tuple always matches the five wired output components.
    """
    if image is None:
        return _blank_outputs("Please upload an image.")

    try:
        # Ensemble prediction over augmented views for more stable results.
        result = classifier.classify_image(image, top_k=5, use_ensemble=True)

        if not result['success']:
            return _blank_outputs(f"Error: {result['error']}")

        predicted_class = result['predicted_class']
        confidence = result['confidence']
        top_predictions = result['top_predictions']

        prediction_text, confidence_text = _format_prediction(predicted_class, confidence)
        instructions = classifier.get_disposal_instructions(predicted_class)
        table_md = _format_predictions_table(top_predictions)
        info_text = _format_model_info(classifier.get_model_info())

        return prediction_text, confidence_text, instructions, table_md, info_text

    except Exception as e:
        # UI boundary: show the error in the app rather than failing the event.
        return _blank_outputs(f"Error processing image: {str(e)}")

# Create Gradio interface with improved design.
# Long static CSS/HTML fragments live in module constants so the layout code
# below stays readable; their text is passed to Gradio unchanged.
_APP_CSS = """

    .gradio-container {

        max-width: 1200px !important;

    }

    .header {

        text-align: center;

        padding: 20px;

        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);

        color: white;

        border-radius: 10px;

        margin-bottom: 20px;

    }

    .improvement-box {

        background: #e8f5e8;

        border: 2px solid #4caf50;

        border-radius: 8px;

        padding: 15px;

        margin: 10px 0;

    }

    .warning-box {

        background: #fff3cd;

        border: 2px solid #ffc107;

        border-radius: 8px;

        padding: 15px;

        margin: 10px 0;

    }

    """

_HEADER_HTML = """

    <div class="header">

        <h1>πŸ—‚οΈ Improved MAE Waste Classifier</h1>

        <p>Enhanced AI-powered waste classification with bias correction and uncertainty handling</p>

        <p><strong>✨ New Features:</strong> Temperature scaling β€’ Cardboard bias reduction β€’ Uncertainty detection β€’ Ensemble predictions</p>

    </div>

    """

_IMPROVEMENTS_HTML = """

    <div class="improvement-box">

        <h3>πŸŽ‰ Recent Improvements</h3>

        <ul>

            <li><strong>βœ… Reduced Cardboard Bias:</strong> From 83% to 17% false cardboard predictions</li>

            <li><strong>βœ… Better Confidence:</strong> 39% reduction in overconfident predictions</li>

            <li><strong>βœ… Uncertainty Handling:</strong> Shows "Uncertain" for low-confidence predictions</li>

            <li><strong>βœ… Ensemble Predictions:</strong> Uses multiple augmentations for stability</li>

        </ul>

    </div>

    """

_TIPS_HTML = """

            <div class="warning-box">

                <h4>πŸ“‹ Tips for Better Results:</h4>

                <ul>

                    <li>Use clear, well-lit photos</li>

                    <li>Center the item in frame</li>

                    <li>Avoid cluttered backgrounds</li>

                    <li>Try different angles if uncertain</li>

                </ul>

            </div>

            """

_ABOUT_HTML = """

        <div style="padding: 20px;">

            <h4>πŸ§  Model Architecture</h4>

            <p>This classifier uses a <strong>Vision Transformer (ViT-Base)</strong> pre-trained with <strong>Masked Autoencoder (MAE)</strong> and fine-tuned on the RealWaste dataset.</p>

            

            <h4>✨ Key Improvements</h4>

            <ul>

                <li><strong>Temperature Scaling (T=2.5):</strong> Reduces overconfident predictions</li>

                <li><strong>Cardboard Bias Correction:</strong> Applies 0.8x penalty to cardboard predictions</li>

                <li><strong>Class-specific Thresholds:</strong> Higher threshold (0.8) for cardboard, lower (0.4) for textile</li>

                <li><strong>Ensemble Prediction:</strong> Averages 5 augmented predictions for stability</li>

                <li><strong>Uncertainty Detection:</strong> Shows "Uncertain" when confidence is too low</li>

            </ul>

            

            <h4>πŸ“Š Performance Metrics</h4>

            <ul>

                <li><strong>Original Validation Accuracy:</strong> 93.27%</li>

                <li><strong>Cardboard Bias Reduction:</strong> 66.6% improvement</li>

                <li><strong>Confidence Calibration:</strong> 38.7% reduction in overconfidence</li>

                <li><strong>Classes:</strong> 9 waste categories</li>

            </ul>

            

            <h4>πŸ—‚οΈ Waste Categories</h4>

            <p><strong>Cardboard, Food Organics, Glass, Metal, Miscellaneous Trash, Paper, Plastic, Textile Trash, Vegetation</strong></p>

        </div>

        """

with gr.Blocks(
    title="πŸ—‚οΈ Improved MAE Waste Classifier",
    theme=gr.themes.Soft(),
    css=_APP_CSS,
) as demo:

    # Banner plus the "recent improvements" notice.
    gr.HTML(_HEADER_HTML)
    gr.HTML(_IMPROVEMENTS_HTML)

    with gr.Row():
        # Left column: image upload, trigger button and photo tips.
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="πŸ“Έ Upload Waste Image",
                type="pil",
                height=400,
            )
            classify_btn = gr.Button(
                "πŸ” Classify Waste",
                variant="primary",
                size="lg",
            )
            gr.HTML(_TIPS_HTML)

        # Right column (twice the left's scale): headline result widgets.
        with gr.Column(scale=2):
            with gr.Group():
                gr.HTML("<h3>🎯 Classification Results</h3>")
                prediction_output = gr.Markdown(
                    label="Prediction",
                    value="Upload an image to get started!",
                )
                confidence_output = gr.Textbox(
                    label="πŸ“Š Confidence Score",
                    interactive=False,
                )
                instructions_output = gr.Textbox(
                    label="♻️ Disposal Instructions",
                    lines=3,
                    interactive=False,
                )

    # Second row: ranked predictions table next to model details.
    with gr.Row():
        with gr.Column():
            gr.HTML("<h3>πŸ“Š Detailed Predictions</h3>")
            predictions_table = gr.Markdown(
                label="Top 5 Predictions",
                value="| Rank | Class | Confidence |\n|------|-------|------------|\n| - | Upload image first | - |",
            )
        with gr.Column():
            gr.HTML("<h3>πŸ€– Model Information</h3>")
            model_info_output = gr.Markdown(
                label="Model Details",
                value="Model information will appear here after classification.",
            )

    # Collapsed "about" panel.
    with gr.Accordion("ℹ️ About This Improved Model", open=False):
        gr.HTML(_ABOUT_HTML)

    # Order must match the 5-tuple returned by classify_waste.
    _result_outputs = [
        prediction_output,
        confidence_output,
        instructions_output,
        predictions_table,
        model_info_output,
    ]

    # Explicit classification via the button ...
    classify_btn.click(
        fn=classify_waste,
        inputs=[image_input],
        outputs=list(_result_outputs),
    )
    # ... and automatic classification on any image value change
    # (classify_waste handles the None value itself).
    image_input.change(
        fn=classify_waste,
        inputs=[image_input],
        outputs=list(_result_outputs),
    )

if __name__ == "__main__":
    # Serve on all network interfaces at a fixed port; no public share link.
    demo.launch(share=False, server_name="0.0.0.0", server_port=7863)