File size: 8,403 Bytes
0007f63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#!/usr/bin/env python3
"""MAE ViT-Base waste classifier for inference."""

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import timm
import os
import json
from huggingface_hub import hf_hub_download

class MAEWasteClassifier:
    """Waste classifier using a finetuned MAE ViT-Base model.

    The checkpoint is taken from a local path when one is given and exists,
    otherwise it is downloaded from the Hugging Face Hub, and as a last
    resort read from ``output_simple_mae/best_model.pth``.

    Attributes set by ``__init__``:
        device: Device the model runs on.
        hf_model_id: Hub repository used for the download attempt.
        model_path: Path of the checkpoint that was actually loaded.
        class_names: Output class labels, indexed by model output position.
        disposal_instructions: Mapping of class label to disposal advice.
        model: The loaded network, in eval mode.
        transform: Preprocessing pipeline applied to every input image.
    """

    def __init__(self, model_path=None, hf_model_id="ysfad/mae-waste-classifier", device=None):
        """Resolve the checkpoint, load the model and build the preprocessing.

        Args:
            model_path: Optional path to a local checkpoint. Ignored (with a
                Hub download attempted instead) when it does not exist.
            hf_model_id: Hugging Face Hub repository holding ``best_model.pth``.
            device: Device to run on; defaults to ``'cuda'`` when available,
                else ``'cpu'``.

        Raises:
            FileNotFoundError: If no local checkpoint exists and the Hub
                download fails.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.hf_model_id = hf_model_id
        self.model_path = self._resolve_model_path(model_path)

        # Class names from training; the order must match the model's outputs.
        self.class_names = [
            'Cardboard', 'Food Organics', 'Glass', 'Metal',
            'Miscellaneous Trash', 'Paper', 'Plastic',
            'Textile Trash', 'Vegetation'
        ]

        # Disposal instructions keyed by class name.
        self.disposal_instructions = {
            "Cardboard": "Flatten and place in recycling bin. Remove any tape or staples.",
            "Food Organics": "Compost in organic waste bin or home composter.",
            "Glass": "Rinse and place in glass recycling. Remove lids and caps.",
            "Metal": "Rinse aluminum/steel cans and place in recycling bin.",
            "Miscellaneous Trash": "Dispose in general waste bin. Cannot be recycled.",
            "Paper": "Place clean paper in recycling. Remove plastic windows from envelopes.",
            "Plastic": "Check recycling number. Rinse containers before recycling.",
            "Textile Trash": "Donate if reusable, otherwise dispose in textile recycling.",
            "Vegetation": "Compost in organic waste or use for mulch in garden."
        }

        # Load model (needs self.class_names for the head size).
        self.model = self._load_model()

        # Image preprocessing: fixed 224x224 input, per-channel normalisation.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        print(f"✅ MAE Waste Classifier loaded on {self.device}")
        print(f"📊 Model: ViT-Base MAE, Classes: {len(self.class_names)}")

    def _resolve_model_path(self, model_path):
        """Return the checkpoint path to load, downloading it if necessary.

        Order of preference: the given local path, a Hub download of
        ``best_model.pth``, then ``output_simple_mae/best_model.pth``.

        Raises:
            FileNotFoundError: If the download fails and the fallback file
                does not exist.
        """
        if model_path and os.path.exists(model_path):
            print(f"📁 Using local model: {model_path}")
            return model_path

        try:
            print(f"🌐 Downloading model from HF Hub: {self.hf_model_id}")
            downloaded_path = hf_hub_download(
                repo_id=self.hf_model_id,
                filename="best_model.pth",
                cache_dir="./hf_cache"
            )
        except Exception as e:
            print(f"⚠️ Could not download from HF Hub: {e}")
            # Fallback to the default local training output.
            fallback_path = "output_simple_mae/best_model.pth"
            if not os.path.exists(fallback_path):
                raise FileNotFoundError(
                    f"Model not found locally at {fallback_path} and could not download from HF Hub"
                ) from e
            return fallback_path

        print(f"✅ Downloaded model to: {downloaded_path}")
        return downloaded_path

    def _load_model(self):
        """Build the ViT-Base network and load the finetuned weights.

        Returns:
            The model, moved to ``self.device`` and switched to eval mode.

        Raises:
            Exception: Any error from model creation or checkpoint loading is
                reported on stdout and re-raised unchanged.
        """
        try:
            # Create ViT model using timm; weights come from the checkpoint.
            model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=len(self.class_names))

            # NOTE(security): torch.load unpickles the checkpoint. Depending on
            # the installed PyTorch version's `weights_only` default this can
            # execute code embedded in the file, so only load checkpoints
            # from sources you trust.
            checkpoint = torch.load(self.model_path, map_location=self.device)

            # Accept both a training checkpoint dict and a bare state dict.
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint)

            model.to(self.device)
            model.eval()

            print(f"✅ Loaded finetuned MAE model from {self.model_path}")
            return model

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def classify_image(self, image, top_k=5):
        """Classify a waste image.

        Args:
            image: PIL Image, or a path (``str`` or ``os.PathLike``) to an
                image file. Images in any mode are converted to RGB.
            top_k: Number of top predictions to return; capped at the number
                of classes. Must be at least 1.

        Returns:
            dict: On success ``{'success': True, 'predicted_class': str,
            'confidence': float, 'top_predictions': [{'class', 'confidence'},
            ...]}``. On any failure ``{'success': False, 'error': str}``;
            this method does not raise.
        """
        try:
            # Reject non-positive k up front so the caller gets a clear
            # message instead of an index or topk error.
            if top_k < 1:
                raise ValueError("top_k must be a positive integer")

            # Load the image and force 3-channel RGB: the normalisation below
            # expects exactly three channels, so RGBA / grayscale / palette
            # inputs must be converted whether they come from disk or memory.
            if isinstance(image, (str, os.PathLike)):
                with Image.open(image) as opened:
                    rgb_image = opened.convert('RGB')
            elif isinstance(image, Image.Image):
                rgb_image = image.convert('RGB')
            else:
                raise ValueError("Image must be PIL Image or path string")

            # Preprocess into a single-item batch on the model's device.
            input_tensor = self.transform(rgb_image).unsqueeze(0).to(self.device)

            # Inference
            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = F.softmax(outputs, dim=1)
                top_probs, top_indices = torch.topk(
                    probabilities, k=min(top_k, len(self.class_names))
                )

            top_predictions = [
                {'class': self.class_names[idx], 'confidence': prob}
                for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist())
            ]

            # topk returns values in descending order, so the first is best.
            best_pred = top_predictions[0]

            return {
                'success': True,
                'predicted_class': best_pred['class'],
                'confidence': best_pred['confidence'],
                'top_predictions': top_predictions
            }

        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_disposal_instructions(self, class_name):
        """Return disposal instructions for a waste class.

        Unknown class names yield a generic "no instructions" message.
        """
        return self.disposal_instructions.get(class_name, "No specific instructions available.")

    def get_model_info(self):
        """Return a dict describing the loaded model, device and checkpoint path."""
        return {
            'model_name': 'ViT-Base MAE',
            'architecture': 'Vision Transformer (ViT-Base)',
            'pretrained': 'MAE (Masked Autoencoder)',
            'num_classes': len(self.class_names),
            'device': self.device,
            'model_path': self.model_path
        }

# Smoke-test the classifier when this file is run as a script.
# Loads the model, classifies the first sample image that exists on disk,
# prints the model info, and exits with status 1 if anything raises.
if __name__ == "__main__":
    print("🧪 Testing MAE Waste Classifier...")

    try:
        # Initialize classifier (may download the checkpoint).
        classifier = MAEWasteClassifier()

        # Candidate sample images; only the first one found is classified.
        test_images = [
            "fail_images/image.webp",
            "fail_images/IMG_9501.webp"
        ]

        for img_path in test_images:
            if not os.path.exists(img_path):
                continue

            print(f"\n🔍 Testing with {img_path}")
            result = classifier.classify_image(img_path)

            if result['success']:
                print(f"✅ Predicted: {result['predicted_class']} ({result['confidence']:.3f})")
                print(f"📋 Instructions: {classifier.get_disposal_instructions(result['predicted_class'])}")

                print("\n📊 Top predictions:")
                for i, pred in enumerate(result['top_predictions'][:3], 1):
                    print(f"  {i}. {pred['class']}: {pred['confidence']:.3f}")
            else:
                print(f"❌ Error: {result['error']}")
            break
        else:
            # Loop finished without `break`: none of the sample files exist.
            print("ℹ️ No test images found, but classifier loaded successfully!")

        # Print model info
        info = classifier.get_model_info()
        print("\n🤖 Model Info:")
        for key, value in info.items():
            print(f"  {key}: {value}")

        print("\nSuccess!")

    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        # Signal failure to the caller; previously the script exited with 0
        # even when the classifier could not be loaded.
        raise SystemExit(1) from e