ysfad committed
Commit 905ac99 · verified · 1 Parent(s): de63d9f

Update: Enhanced classifier with temperature scaling

Files changed (1)
  1. improved_mae_classifier.py +346 -0
improved_mae_classifier.py ADDED
@@ -0,0 +1,346 @@
+ #!/usr/bin/env python3
+ """
+ Improved MAE Waste Classifier with temperature scaling and bias correction
+ """
+
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from PIL import Image
+ from torchvision import transforms
+ from huggingface_hub import hf_hub_download
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ # Import MAE model
+ from mae.models_vit import vit_base_patch16
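+ # Note: this import expects a local `mae` package providing models_vit.py (e.g., the
+ # module from the official MAE repository) to be importable from the working directory.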
+
+ class ImprovedMAEWasteClassifier:
+     def __init__(self,
+                  model_path=None,
+                  hf_model_id=None,
+                  device=None,
+                  temperature=2.5,         # Temperature scaling to reduce overconfidence
+                  cardboard_penalty=0.8):  # Penalty factor for cardboard predictions
+         """
+         Initialize improved MAE waste classifier with bias correction
+
+         Args:
+             model_path: Local path to model file
+             hf_model_id: Hugging Face model ID
+             device: Device to run on
+             temperature: Temperature scaling factor (>1 reduces confidence)
+             cardboard_penalty: Penalty factor for cardboard predictions
+         """
+         self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+         self.temperature = temperature
+         self.cardboard_penalty = cardboard_penalty
+
+         # Class names (must match training order)
+         self.class_names = [
+             'Cardboard', 'Food Organics', 'Glass', 'Metal',
+             'Miscellaneous Trash', 'Paper', 'Plastic', 'Textile Trash', 'Vegetation'
+         ]
+
+         # Class-specific confidence thresholds
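+         # Predictions whose top score falls below the threshold for that class are
+         # reported as "Uncertain" (see classify_image below).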
+         self.class_thresholds = {
+             'Cardboard': 0.8,  # Higher threshold for cardboard
+             'Plastic': 0.6,
+             'Metal': 0.6,
+             'Glass': 0.6,
+             'Paper': 0.6,
+             'Food Organics': 0.5,
+             'Miscellaneous Trash': 0.5,
+             'Textile Trash': 0.4,  # Lower threshold for underrepresented class
+             'Vegetation': 0.5
+         }
+
+         # Load model
+         self.model = self._load_model(model_path, hf_model_id)
+         self.model.eval()
+
+         # Data preprocessing
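+         # 224x224 inputs normalized with ImageNet statistics, the standard
+         # preprocessing for ViT/MAE backbones.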
+         self.transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+             transforms.Normalize(
+                 mean=[0.485, 0.456, 0.406],
+                 std=[0.229, 0.224, 0.225]
+             )
+         ])
+
+         print(f"✅ Improved MAE Classifier loaded on {self.device}")
+         print(f"🌡️ Temperature scaling: {self.temperature}")
+         print(f"🗂️ Cardboard penalty: {self.cardboard_penalty}")
+
+     def _load_model(self, model_path=None, hf_model_id=None):
+         """Load the finetuned MAE model"""
+
+         # Determine model path
+         if model_path and os.path.exists(model_path):
+             checkpoint_path = model_path
+             print(f"📁 Loading local model from {model_path}")
+         elif hf_model_id:
+             print(f"🌐 Downloading model from HF Hub: {hf_model_id}")
+             checkpoint_path = hf_hub_download(
+                 repo_id=hf_model_id,
+                 filename="best_model.pth",
+                 cache_dir="./hf_cache"
+             )
+             print(f"✅ Downloaded model to: {checkpoint_path}")
+         else:
+             # Try local file
+             local_path = "output_simple_mae/best_model.pth"
+             if os.path.exists(local_path):
+                 checkpoint_path = local_path
+                 print(f"📁 Using local model: {local_path}")
+             else:
+                 raise FileNotFoundError("No model found. Provide model_path or hf_model_id")
+
+         # Create model
+         model = vit_base_patch16(num_classes=len(self.class_names))
+
+         # Load checkpoint
+         checkpoint = torch.load(checkpoint_path, map_location=self.device)
+
+         # Handle different checkpoint formats
+         if 'model_state_dict' in checkpoint:
+             state_dict = checkpoint['model_state_dict']
+         elif 'model' in checkpoint:
+             state_dict = checkpoint['model']
+         else:
+             state_dict = checkpoint
+
+         # Load state dict
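+         # strict=False skips keys that are missing from or unexpected in the checkpoint
+         # instead of raising an error.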
+         model.load_state_dict(state_dict, strict=False)
+         model = model.to(self.device)
+
+         print(f"✅ Loaded finetuned MAE model from {checkpoint_path}")
+         return model
+
+     def _apply_temperature_scaling(self, logits):
+         """Apply temperature scaling to reduce overconfidence"""
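+         # softmax(logits / T) with T > 1 flattens the probability distribution, lowering
+         # peak confidence without changing which class scores highest.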
+         return logits / self.temperature
+
+     def _apply_class_bias_correction(self, probs):
+         """Apply bias correction to reduce cardboard overconfidence"""
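+         # Scaling one class probability by a factor < 1 and renormalizing redistributes
+         # that probability mass across the remaining classes.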
+         probs_corrected = probs.clone()
+
+         # Find cardboard class index
+         cardboard_idx = self.class_names.index('Cardboard')
+
+         # Apply penalty to cardboard predictions
+         probs_corrected[cardboard_idx] *= self.cardboard_penalty
+
+         # Renormalize probabilities
+         probs_corrected = probs_corrected / probs_corrected.sum()
+
+         return probs_corrected
+
+     def _ensemble_prediction(self, image, num_crops=5):
+         """Use ensemble of augmented predictions for better stability"""
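+         # Test-time augmentation: classify several mildly augmented views of the same
+         # image and average the class probabilities to reduce view-dependent variance.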
+
+         # Different augmentation transforms
+         augment_transforms = [
+             transforms.Compose([
+                 transforms.Resize((256, 256)),
+                 transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+             ]),
+             transforms.Compose([
+                 transforms.Resize((224, 224)),
+                 transforms.RandomHorizontalFlip(p=1.0),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+             ]),
+             transforms.Compose([
+                 transforms.Resize((256, 256)),
+                 transforms.CenterCrop(224),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+             ]),
+             transforms.Compose([
+                 transforms.Resize((224, 224)),
+                 transforms.ColorJitter(brightness=0.1, contrast=0.1),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+             ]),
+             # Standard transform
+             self.transform
+         ]
+
+         all_probs = []
+
+         with torch.no_grad():
+             for transform in augment_transforms[:num_crops]:
+                 # Apply transform
+                 input_tensor = transform(image).unsqueeze(0).to(self.device)
+
+                 # Get prediction
+                 logits = self.model(input_tensor)
+
+                 # Apply temperature scaling
+                 scaled_logits = self._apply_temperature_scaling(logits)
+
+                 # Get probabilities
+                 probs = F.softmax(scaled_logits, dim=1).squeeze(0)
+
+                 # Apply bias correction
+                 corrected_probs = self._apply_class_bias_correction(probs)
+
+                 all_probs.append(corrected_probs.cpu().numpy())
+
+         # Average ensemble predictions
+         ensemble_probs = np.mean(all_probs, axis=0)
+
+         return ensemble_probs
+
+     def classify_image(self, image, top_k=5, use_ensemble=True):
+         """
+         Classify a waste image with improved confidence calibration
+
+         Args:
+             image: PIL Image or path to image
+             top_k: Number of top predictions to return
+             use_ensemble: Whether to use ensemble prediction
+
+         Returns:
+             Dictionary with classification results
+         """
+         try:
+             # Load image if path provided
+             if isinstance(image, str):
+                 image = Image.open(image).convert('RGB')
+             elif not isinstance(image, Image.Image):
+                 raise ValueError("Image must be PIL Image or file path")
+
+             # Get predictions
+             if use_ensemble:
+                 probs = self._ensemble_prediction(image)
+             else:
+                 # Single prediction with improvements
+                 input_tensor = self.transform(image).unsqueeze(0).to(self.device)
+
+                 with torch.no_grad():
+                     logits = self.model(input_tensor)
+                     scaled_logits = self._apply_temperature_scaling(logits)
+                     probs = F.softmax(scaled_logits, dim=1).squeeze(0)
+                     probs = self._apply_class_bias_correction(probs)
+                     probs = probs.cpu().numpy()
+
+             # Get top predictions
+             top_indices = np.argsort(probs)[::-1][:top_k]
+             top_predictions = []
+
+             for idx in top_indices:
+                 class_name = self.class_names[idx]
+                 confidence = float(probs[idx])
+
+                 top_predictions.append({
+                     'class': class_name,
+                     'confidence': confidence
+                 })
+
+             # Determine final prediction with class-specific thresholds
+             predicted_class = top_predictions[0]['class']
+             predicted_confidence = top_predictions[0]['confidence']
+
+             # Check if prediction meets class-specific threshold
+             threshold = self.class_thresholds.get(predicted_class, 0.5)
+
+             if predicted_confidence < threshold:
+                 # Below the class threshold: report as Uncertain but keep the raw confidence
+                 predicted_class = "Uncertain"
+
+             return {
+                 'success': True,
+                 'predicted_class': predicted_class,
+                 'confidence': predicted_confidence,
+                 'top_predictions': top_predictions,
+                 'ensemble_used': use_ensemble,
+                 'temperature': self.temperature
+             }
+
+         except Exception as e:
+             return {
+                 'success': False,
+                 'error': str(e)
+             }
+
+     def get_disposal_instructions(self, class_name):
+         """Get disposal instructions for a waste class"""
+         instructions = {
+             'Cardboard': 'Flatten and place in recycling bin. Remove any tape or staples.',
+             'Food Organics': 'Place in compost bin or organic waste collection.',
+             'Glass': 'Rinse and place in glass recycling bin. Remove caps and lids.',
+             'Metal': 'Rinse cans and place in metal recycling bin.',
+             'Miscellaneous Trash': 'Place in general waste bin.',
+             'Paper': 'Place in paper recycling bin. Remove any plastic components.',
+             'Plastic': 'Check recycling number and place in appropriate plastic recycling bin.',
+             'Textile Trash': 'Donate if in good condition, otherwise place in textile recycling.',
+             'Vegetation': 'Compost or place in yard waste collection.',
+             'Uncertain': 'Please take another photo from a different angle or with better lighting.'
+         }
+
+         return instructions.get(class_name, 'Please consult local waste management guidelines.')
+
+     def get_model_info(self):
+         """Get model information"""
+         return {
+             'model_name': 'Improved ViT-Base MAE',
+             'architecture': 'Vision Transformer (ViT-Base)',
+             'pretrained': 'MAE (Masked Autoencoder)',
+             'num_classes': len(self.class_names),
+             'device': str(self.device),
+             'temperature': self.temperature,
+             'cardboard_penalty': self.cardboard_penalty,
+             'improvements': [
+                 'Temperature scaling for confidence calibration',
+                 'Class-specific bias correction',
+                 'Ensemble predictions for stability',
+                 'Class-specific confidence thresholds'
+             ]
+         }
+
+ def test_improved_classifier():
+     """Test the improved classifier"""
+     print("🧪 Testing Improved MAE Waste Classifier...")
+
+     # Load improved classifier
+     classifier = ImprovedMAEWasteClassifier(hf_model_id="ysfad/mae-waste-classifier")
+
+     # Test with a sample image
+     test_image = "fail_images/image.webp"
+     if os.path.exists(test_image):
+         print(f"\n🔍 Testing with {test_image}")
+
+         # Test both single and ensemble prediction
+         print("\n1. Single prediction:")
+         result1 = classifier.classify_image(test_image, use_ensemble=False)
+         if result1['success']:
+             print(f"🎯 Predicted: {result1['predicted_class']} ({result1['confidence']:.3f})")
+
+         print("\n2. Ensemble prediction:")
+         result2 = classifier.classify_image(test_image, use_ensemble=True)
+         if result2['success']:
+             print(f"🎯 Predicted: {result2['predicted_class']} ({result2['confidence']:.3f})")
+             print("📊 Top predictions:")
+             for i, pred in enumerate(result2['top_predictions'], 1):
+                 print(f" {i}. {pred['class']}: {pred['confidence']:.3f}")
+
+     print("\n🤖 Model Info:")
+     info = classifier.get_model_info()
+     for key, value in info.items():
+         if isinstance(value, list):
+             print(f" {key}:")
+             for item in value:
+                 print(f" - {item}")
+         else:
+             print(f" {key}: {value}")
+
+ if __name__ == "__main__":
+     test_improved_classifier()