#!/usr/bin/env python3
"""

Improved MAE Waste Classifier with temperature scaling and bias correction

"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
import warnings
warnings.filterwarnings("ignore")

# Import MAE model
from mae.models_vit import vit_base_patch16

class ImprovedMAEWasteClassifier:
    def __init__(self,
                 model_path=None,
                 hf_model_id=None,
                 device=None,
                 temperature=2.5,        # Temperature scaling to reduce overconfidence
                 cardboard_penalty=0.8):  # Penalty factor for cardboard predictions
        """

        Initialize improved MAE waste classifier with bias correction

        

        Args:

            model_path: Local path to model file

            hf_model_id: Hugging Face model ID

            device: Device to run on

            temperature: Temperature scaling factor (>1 reduces confidence)

            cardboard_penalty: Penalty factor for cardboard predictions

        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.temperature = temperature
        self.cardboard_penalty = cardboard_penalty
        
        # Class names (must match training order)
        self.class_names = [
            'Cardboard', 'Food Organics', 'Glass', 'Metal', 
            'Miscellaneous Trash', 'Paper', 'Plastic', 'Textile Trash', 'Vegetation'
        ]
        
        # Class-specific confidence thresholds
        self.class_thresholds = {
            'Cardboard': 0.8,  # Higher threshold for cardboard
            'Plastic': 0.6,
            'Metal': 0.6,
            'Glass': 0.6,
            'Paper': 0.6,
            'Food Organics': 0.5,
            'Miscellaneous Trash': 0.5,
            'Textile Trash': 0.4,  # Lower threshold for underrepresented class
            'Vegetation': 0.5
        }
        
        # Load model
        self.model = self._load_model(model_path, hf_model_id)
        self.model.eval()
        
        # Data preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
        
        print(f"βœ… Improved MAE Classifier loaded on {self.device}")
        print(f"🌑️ Temperature scaling: {self.temperature}")
        print(f"πŸ—‚οΈ Cardboard penalty: {self.cardboard_penalty}")

    def _load_model(self, model_path=None, hf_model_id=None):
        """Load the finetuned MAE model"""
        
        # Determine model path
        if model_path and os.path.exists(model_path):
            checkpoint_path = model_path
            print(f"πŸ“ Loading local model from {model_path}")
        elif hf_model_id:
            print(f"🌐 Downloading model from HF Hub: {hf_model_id}")
            checkpoint_path = hf_hub_download(
                repo_id=hf_model_id,
                filename="best_model.pth",
                cache_dir="./hf_cache"
            )
            print(f"βœ… Downloaded model to: {checkpoint_path}")
        else:
            # Try local file
            local_path = "output_simple_mae/best_model.pth"
            if os.path.exists(local_path):
                checkpoint_path = local_path
                print(f"πŸ“ Using local model: {local_path}")
            else:
                raise FileNotFoundError("No model found. Provide model_path or hf_model_id")
        
        # Create model
        model = vit_base_patch16(num_classes=len(self.class_names))
        
        # Load checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        
        # Handle different checkpoint formats
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        
        # Load state dict
        model.load_state_dict(state_dict, strict=False)
        model = model.to(self.device)
        
        print(f"βœ… Loaded finetuned MAE model from {checkpoint_path}")
        return model

    def _apply_temperature_scaling(self, logits):
        """Apply temperature scaling to reduce overconfidence"""
        # Dividing logits by T > 1 flattens the softmax distribution,
        # lowering the top-1 probability without changing the ranking
        return logits / self.temperature

    def _apply_class_bias_correction(self, probs):
        """Apply bias correction to reduce cardboard overconfidence (expects a 1D probability tensor)"""
        probs_corrected = probs.clone()
        
        # Find cardboard class index
        cardboard_idx = self.class_names.index('Cardboard')
        
        # Apply penalty to cardboard predictions
        probs_corrected[cardboard_idx] *= self.cardboard_penalty
        
        # Renormalize probabilities
        probs_corrected = probs_corrected / probs_corrected.sum()
        
        return probs_corrected

    def _ensemble_prediction(self, image, num_crops=5):
        """Use ensemble of augmented predictions for better stability"""
        
        # Different augmentation transforms
        augment_transforms = [
            transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]),
            transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.RandomHorizontalFlip(p=1.0),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]),
            transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]),
            transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ColorJitter(brightness=0.1, contrast=0.1),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]),
            # Standard transform
            self.transform
        ]
        
        all_probs = []
        
        with torch.no_grad():
            for transform in augment_transforms[:num_crops]:
                # Apply transform
                input_tensor = transform(image).unsqueeze(0).to(self.device)
                
                # Get prediction
                logits = self.model(input_tensor)
                
                # Apply temperature scaling
                scaled_logits = self._apply_temperature_scaling(logits)
                
                # Get probabilities
                probs = F.softmax(scaled_logits, dim=1).squeeze(0)
                
                # Apply bias correction
                corrected_probs = self._apply_class_bias_correction(probs)
                
                all_probs.append(corrected_probs.cpu().numpy())
        
        # Average ensemble predictions
        ensemble_probs = np.mean(all_probs, axis=0)
        
        return ensemble_probs

    def classify_image(self, image, top_k=5, use_ensemble=True):
        """

        Classify a waste image with improved confidence calibration

        

        Args:

            image: PIL Image or path to image

            top_k: Number of top predictions to return

            use_ensemble: Whether to use ensemble prediction

        

        Returns:

            Dictionary with classification results

        """
        try:
            # Load image if path provided
            if isinstance(image, str):
                image = Image.open(image).convert('RGB')
            elif not isinstance(image, Image.Image):
                raise ValueError("Image must be PIL Image or file path")
            
            # Get predictions
            if use_ensemble:
                probs = self._ensemble_prediction(image)
            else:
                # Single prediction with improvements
                input_tensor = self.transform(image).unsqueeze(0).to(self.device)
                
                with torch.no_grad():
                    logits = self.model(input_tensor)
                    scaled_logits = self._apply_temperature_scaling(logits)
                    probs = F.softmax(scaled_logits, dim=1).squeeze(0)
                    probs = self._apply_class_bias_correction(probs)
                    probs = probs.cpu().numpy()
            
            # Get top predictions
            top_indices = np.argsort(probs)[::-1][:top_k]
            top_predictions = []
            
            for idx in top_indices:
                class_name = self.class_names[idx]
                confidence = float(probs[idx])
                
                top_predictions.append({
                    'class': class_name,
                    'confidence': confidence
                })
            
            # Determine final prediction with class-specific thresholds
            predicted_class = top_predictions[0]['class']
            predicted_confidence = top_predictions[0]['confidence']
            
            # Check if prediction meets class-specific threshold
            threshold = self.class_thresholds.get(predicted_class, 0.5)
            
            if predicted_confidence < threshold:
                # Below the class-specific threshold: mark as uncertain but keep
                # the raw confidence so callers can still inspect it
                predicted_class = "Uncertain"
            
            return {
                'success': True,
                'predicted_class': predicted_class,
                'confidence': predicted_confidence,
                'top_predictions': top_predictions,
                'ensemble_used': use_ensemble,
                'temperature': self.temperature
            }
            
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_disposal_instructions(self, class_name):
        """Get disposal instructions for a waste class"""
        instructions = {
            'Cardboard': 'Flatten and place in recycling bin. Remove any tape or staples.',
            'Food Organics': 'Place in compost bin or organic waste collection.',
            'Glass': 'Rinse and place in glass recycling bin. Remove caps and lids.',
            'Metal': 'Rinse cans and place in metal recycling bin.',
            'Miscellaneous Trash': 'Place in general waste bin.',
            'Paper': 'Place in paper recycling bin. Remove any plastic components.',
            'Plastic': 'Check recycling number and place in appropriate plastic recycling bin.',
            'Textile Trash': 'Donate if in good condition, otherwise place in textile recycling.',
            'Vegetation': 'Compost or place in yard waste collection.',
            'Uncertain': 'Please take another photo from a different angle or with better lighting.'
        }
        
        return instructions.get(class_name, 'Please consult local waste management guidelines.')

    def get_model_info(self):
        """Get model information"""
        return {
            'model_name': 'Improved ViT-Base MAE',
            'architecture': 'Vision Transformer (ViT-Base)',
            'pretrained': 'MAE (Masked Autoencoder)',
            'num_classes': len(self.class_names),
            'device': str(self.device),
            'temperature': self.temperature,
            'cardboard_penalty': self.cardboard_penalty,
            'improvements': [
                'Temperature scaling for confidence calibration',
                'Class-specific bias correction',
                'Ensemble predictions for stability',
                'Class-specific confidence thresholds'
            ]
        }

def test_improved_classifier():
    """Test the improved classifier"""
    print("πŸ§ͺ Testing Improved MAE Waste Classifier...")
    
    # Load improved classifier
    classifier = ImprovedMAEWasteClassifier(hf_model_id="ysfad/mae-waste-classifier")
    
    # Test with a sample image
    test_image = "fail_images/image.webp"
    if os.path.exists(test_image):
        print(f"\nπŸ” Testing with {test_image}")
        
        # Test both single and ensemble prediction
        print("\n1. Single prediction:")
        result1 = classifier.classify_image(test_image, use_ensemble=False)
        if result1['success']:
            print(f"🎯 Predicted: {result1['predicted_class']} ({result1['confidence']:.3f})")
        
        print("\n2. Ensemble prediction:")
        result2 = classifier.classify_image(test_image, use_ensemble=True)
        if result2['success']:
            print(f"🎯 Predicted: {result2['predicted_class']} ({result2['confidence']:.3f})")
            print("πŸ“Š Top predictions:")
            for i, pred in enumerate(result2['top_predictions'], 1):
                print(f"  {i}. {pred['class']}: {pred['confidence']:.3f}")
    
    print("\nπŸ€– Model Info:")
    info = classifier.get_model_info()
    for key, value in info.items():
        if isinstance(value, list):
            print(f"  {key}:")
            for item in value:
                print(f"    - {item}")
        else:
            print(f"  {key}: {value}")

if __name__ == "__main__":
    test_improved_classifier()