ysfad committed on
Commit 0007f63 · verified · 1 Parent(s): e15cf70

Upload mae_waste_classifier.py with huggingface_hub

Files changed (1)
  1. mae_waste_classifier.py +209 -0
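The commit title above is the default message produced when a file is pushed with huggingface_hub's upload API. A minimal sketch of such an upload is shown below; the repo id mirrors the classifier's default hf_model_id, while the local file path and the authentication setup (huggingface-cli login or HF_TOKEN) are assumptions, not details taken from this page.

    from huggingface_hub import HfApi

    # Sketch of an upload that would produce a commit like this one.
    api = HfApi()  # picks up the token from `huggingface-cli login` or HF_TOKEN
    api.upload_file(
        path_or_fileobj="mae_waste_classifier.py",   # assumed local path
        path_in_repo="mae_waste_classifier.py",
        repo_id="ysfad/mae-waste-classifier",        # mirrors the classifier's default hf_model_id
        repo_type="model",
        commit_message="Upload mae_waste_classifier.py with huggingface_hub",
    )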
mae_waste_classifier.py ADDED
@@ -0,0 +1,209 @@
+#!/usr/bin/env python3
+"""MAE ViT-Base waste classifier for inference."""
+
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+from PIL import Image
+import timm
+import os
+import json
+from huggingface_hub import hf_hub_download
+
+class MAEWasteClassifier:
+    """Waste classifier using finetuned MAE ViT-Base model."""
+
+    def __init__(self, model_path=None, hf_model_id="ysfad/mae-waste-classifier", device=None):
+        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+        self.hf_model_id = hf_model_id
+
+        # Try to load model from different sources
+        if model_path and os.path.exists(model_path):
+            self.model_path = model_path
+            print(f"📁 Using local model: {model_path}")
+        else:
+            # Try to download from HF Hub
+            try:
+                print(f"🌐 Downloading model from HF Hub: {hf_model_id}")
+                self.model_path = hf_hub_download(
+                    repo_id=hf_model_id,
+                    filename="best_model.pth",
+                    cache_dir="./hf_cache"
+                )
+                print(f"✅ Downloaded model to: {self.model_path}")
+            except Exception as e:
+                print(f"⚠️ Could not download from HF Hub: {e}")
+                # Fallback to local path
+                self.model_path = "output_simple_mae/best_model.pth"
+                if not os.path.exists(self.model_path):
+                    raise FileNotFoundError(f"Model not found locally at {self.model_path} and could not download from HF Hub")
+
+        # Class names from training
+        self.class_names = [
+            'Cardboard', 'Food Organics', 'Glass', 'Metal',
+            'Miscellaneous Trash', 'Paper', 'Plastic',
+            'Textile Trash', 'Vegetation'
+        ]
+
+        # Load disposal instructions
+        self.disposal_instructions = {
+            "Cardboard": "Flatten and place in recycling bin. Remove any tape or staples.",
+            "Food Organics": "Compost in organic waste bin or home composter.",
+            "Glass": "Rinse and place in glass recycling. Remove lids and caps.",
+            "Metal": "Rinse aluminum/steel cans and place in recycling bin.",
+            "Miscellaneous Trash": "Dispose in general waste bin. Cannot be recycled.",
+            "Paper": "Place clean paper in recycling. Remove plastic windows from envelopes.",
+            "Plastic": "Check recycling number. Rinse containers before recycling.",
+            "Textile Trash": "Donate if reusable, otherwise dispose in textile recycling.",
+            "Vegetation": "Compost in organic waste or use for mulch in garden."
+        }
+
+        # Load model
+        self.model = self._load_model()
+
+        # Image preprocessing
+        self.transform = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+        print(f"✅ MAE Waste Classifier loaded on {self.device}")
+        print(f"📊 Model: ViT-Base MAE, Classes: {len(self.class_names)}")
+
+    def _load_model(self):
+        """Load the finetuned MAE model."""
+        try:
+            # Create ViT model using timm
+            model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=len(self.class_names))
+
+            # Load checkpoint
+            checkpoint = torch.load(self.model_path, map_location=self.device)
+
+            # Load state dict
+            if 'model_state_dict' in checkpoint:
+                model.load_state_dict(checkpoint['model_state_dict'])
+            else:
+                model.load_state_dict(checkpoint)
+
+            model.to(self.device)
+            model.eval()
+
+            print(f"✅ Loaded finetuned MAE model from {self.model_path}")
+            return model
+
+        except Exception as e:
+            print(f"❌ Error loading model: {e}")
+            raise
+
+    def classify_image(self, image, top_k=5):
+        """
+        Classify a waste image.
+
+        Args:
+            image: PIL Image or path to image
+            top_k: Number of top predictions to return
+
+        Returns:
+            dict: Classification results
+        """
+        try:
+            # Load and preprocess image
+            if isinstance(image, str):
+                image = Image.open(image).convert('RGB')
+            elif not isinstance(image, Image.Image):
+                raise ValueError("Image must be PIL Image or path string")
+
+            # Preprocess
+            input_tensor = self.transform(image).unsqueeze(0).to(self.device)
+
+            # Inference
+            with torch.no_grad():
+                outputs = self.model(input_tensor)
+                probabilities = F.softmax(outputs, dim=1)
+
+            # Get top predictions
+            top_probs, top_indices = torch.topk(probabilities, k=min(top_k, len(self.class_names)))
+
+            top_predictions = []
+            for prob, idx in zip(top_probs[0], top_indices[0]):
+                top_predictions.append({
+                    'class': self.class_names[idx.item()],
+                    'confidence': prob.item()
+                })
+
+            # Best prediction
+            best_pred = top_predictions[0]
+
+            return {
+                'success': True,
+                'predicted_class': best_pred['class'],
+                'confidence': best_pred['confidence'],
+                'top_predictions': top_predictions
+            }
+
+        except Exception as e:
+            return {
+                'success': False,
+                'error': str(e)
+            }
+
+    def get_disposal_instructions(self, class_name):
+        """Get disposal instructions for a waste class."""
+        return self.disposal_instructions.get(class_name, "No specific instructions available.")
+
+    def get_model_info(self):
+        """Get information about the loaded model."""
+        return {
+            'model_name': 'ViT-Base MAE',
+            'architecture': 'Vision Transformer (ViT-Base)',
+            'pretrained': 'MAE (Masked Autoencoder)',
+            'num_classes': len(self.class_names),
+            'device': self.device,
+            'model_path': self.model_path
+        }
+
+# Test the classifier
+if __name__ == "__main__":
+    print("🧪 Testing MAE Waste Classifier...")
+
+    try:
+        # Initialize classifier
+        classifier = MAEWasteClassifier()
+
+        # Test with a sample image if available
+        test_images = [
+            "fail_images/image.webp",
+            "fail_images/IMG_9501.webp"
+        ]
+
+        for img_path in test_images:
+            if os.path.exists(img_path):
+                print(f"\n🔍 Testing with {img_path}")
+                result = classifier.classify_image(img_path)
+
+                if result['success']:
+                    print(f"✅ Predicted: {result['predicted_class']} ({result['confidence']:.3f})")
+                    print(f"📋 Instructions: {classifier.get_disposal_instructions(result['predicted_class'])}")
+
+                    print("\n📊 Top predictions:")
+                    for i, pred in enumerate(result['top_predictions'][:3], 1):
+                        print(f" {i}. {pred['class']}: {pred['confidence']:.3f}")
+                else:
+                    print(f"❌ Error: {result['error']}")
+                break
+            else:
+                print("ℹ️ No test images found, but classifier loaded successfully!")
+
+        # Print model info
+        info = classifier.get_model_info()
+        print(f"\n🤖 Model Info:")
+        for key, value in info.items():
+            print(f" {key}: {value}")
+
+        print("\nSuccess!")
+
+    except Exception as e:
+        print(f"❌ Error: {e}")
+        import traceback
+        traceback.print_exc()
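For reference, a minimal usage sketch of the class added above, assuming the script runs from a directory containing mae_waste_classifier.py; the image path is a placeholder, while the calls follow the file's own API (classify_image and get_disposal_instructions).

    from mae_waste_classifier import MAEWasteClassifier

    # Downloads best_model.pth from ysfad/mae-waste-classifier on first use,
    # or pass model_path= to point at a local checkpoint instead.
    classifier = MAEWasteClassifier()

    # "waste_photo.jpg" is a placeholder path; classify_image also accepts a PIL.Image.
    result = classifier.classify_image("waste_photo.jpg", top_k=3)
    if result['success']:
        print(result['predicted_class'], result['confidence'])
        print(classifier.get_disposal_instructions(result['predicted_class']))
    else:
        print("Classification failed:", result['error'])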