Rhodham96 committed
Commit ab58b14 · 1 Parent(s): 316853f

First commit

Files changed (4)
  1. .gitignore +90 -0
  2. Dockerfile +41 -0
  3. SatelliteClassification.py +292 -0
  4. requirements.txt +8 -0
.gitignore ADDED
@@ -0,0 +1,90 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # Pyre type checker
+ .pyre/
+
+ # VS Code
+ .vscode/
+
+ # System files
+ .DS_Store
+ Thumbs.db
+
+ # Hugging Face cache
+ hf_cache/
+
+ # Docker
+ *.log
+
+ # Gradio temp
+ *.gradio
+
+ # Model checkpoints
+ *.pth
+ *.pt
+
+ # Environment files
+ .env
+ .env.*
+
+ # Ignore data
+ *.arrow
+ *.lock
+
+ # Ignore images and plots
+ *.png
+ *.jpg
+ *.jpeg
+ *.bmp
+ *.gif
+
+ # Ignore other temp files
+ *.tmp
+ *.temp
+
+ # Ignore README artifacts
+ *.md~
Dockerfile ADDED
@@ -0,0 +1,41 @@
+ # Use the official Python image with a version compatible with torch and gradio
+ FROM python:3.11-slim
+
+ # Set environment variables
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     HF_HOME=/app/hf_cache
+
+ # Set work directory
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends \
+         build-essential \
+         git \
+         libglib2.0-0 \
+         libsm6 \
+         libxext6 \
+         libxrender-dev \
+         ffmpeg \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements
+ COPY requirements.txt ./
+
+ # Install Python dependencies
+ RUN pip install --upgrade pip && \
+     pip install --no-cache-dir -r requirements.txt
+
+ # Copy the rest of the code
+ COPY . .
+
+ # Expose port for Gradio
+ EXPOSE 7860
+
+ # Set Gradio to listen on all interfaces (required for Spaces)
+ ENV GRADIO_SERVER_NAME=0.0.0.0
+
+ # Run the app
+ CMD ["python", "SatelliteClassification.py"]
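
A quick way to sanity-check the image built from this Dockerfile (not part of the commit): start the container with the Gradio port mapped, for example `docker run -p 7860:7860 <image>`, then probe it with a short Python snippet such as the sketch below.

import urllib.request

# Minimal reachability check; assumes the container is already running and
# mapped to localhost:7860 as described above.
with urllib.request.urlopen("http://localhost:7860", timeout=10) as resp:
    # Gradio serves its UI at the root path, so HTTP 200 here means the app started.
    print("Gradio app reachable, HTTP status:", resp.status)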
SatelliteClassification.py ADDED
@@ -0,0 +1,292 @@
+ import gradio as gr
+ import torch
+ import torch.nn as nn
+ from torchvision.models import resnet18
+ from datasets import load_dataset
+ from huggingface_hub import hf_hub_download
+ import numpy as np
+ import random
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import io
+ from torch.utils.data import DataLoader
+ import base64
+
+ # Model architecture definition
+ class ResNet18_Dropout(nn.Module):
+     def __init__(self, in_channels, num_classes, dropout_rate=0.3):
+         super().__init__()
+         self.model = resnet18(weights=None)
+         self.model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
+         in_features = self.model.fc.in_features
+         self.model.fc = nn.Sequential(
+             nn.Dropout(dropout_rate),
+             nn.Linear(in_features, num_classes)
+         )
+
+     def forward(self, x):
+         return self.model(x)
+
+ def transform_multispectral_map(example):
+     image = np.array(example["image"], dtype=np.float32)
+
+     if image.ndim != 3 or image.shape[2] != 13:
+         raise ValueError(f"Expected shape (H, W, 13), got {image.shape}")
+
+     # Normalize
+     image = image / 2750.0
+     image = np.clip(image, 0, 1)
+
+     # === DATA AUGMENTATION ===
+     # Horizontal flip
+     if random.random() < 0.5:
+         image = np.flip(image, axis=1).copy()
+
+     # Vertical flip
+     if random.random() < 0.5:
+         image = np.flip(image, axis=0).copy()
+
+     # Rotation (by 90, 180, 270)
+     if random.random() < 0.5:
+         k = random.choice([1, 2, 3])
+         image = np.rot90(image, k=k, axes=(0, 1)).copy()
+
+     # === SHAPE FORMAT ===
+     image = image.transpose(2, 0, 1)  # (C=13, H, W)
+
+     return {
+         "image": torch.tensor(image, dtype=torch.float32),
+         "label": torch.tensor(example["label"], dtype=torch.long)
+     }
+
+ # RGB conversion functions
+ def load_rgb_from_multispectral_sample(numpy_array):
+     """
+     Takes a NumPy array with 13 multispectral bands and returns a scaled RGB NumPy array.
+     Equivalent to loading bands 4-3-2 and scaling as GDAL would.
+     """
+     # GDAL-style scaling: scale 0–2750 -> 0–255
+     def scale_band(band):
+         band = np.clip((band / 2750) * 255, 0, 255)
+         return band.astype(np.uint8)
+
+     # Bands 4 (red), 3 (green), 2 (blue) => indices 3, 2, 1 (0-based)
+     bands = [3, 2, 1]
+
+     # Ensure the input is a NumPy array
+     if not isinstance(numpy_array, np.ndarray):
+         raise TypeError("Input must be a NumPy array")
+
+     # Check if the array has the expected number of channels (13)
+     if numpy_array.shape[-1] != 13:
+         raise ValueError(f"Input array must have 13 channels, but got {numpy_array.shape[-1]}")
+
+     # Extract and scale the RGB bands from the NumPy array
+     rgb = np.stack([scale_band(numpy_array[:, :, b]) for b in bands], axis=-1)
+     return rgb
+
+ def load_rgb_from_transformed_tensor(tensor_image):
+     """
+     Takes a torch.Tensor with 13 multispectral bands (C, H, W) and returns a scaled RGB NumPy array.
+     """
+     if not isinstance(tensor_image, torch.Tensor):
+         raise TypeError("Input must be a torch.Tensor")
+     if tensor_image.shape[0] != 13:
+         raise ValueError(f"Expected 13 channels, got {tensor_image.shape[0]}")
+
+     # Convert to NumPy (C, H, W) → (H, W, C)
+     np_image = tensor_image.numpy()
+     np_image = np.transpose(np_image, (1, 2, 0))  # (H, W, 13)
+
+     # Bands 4-3-2 → indices 3, 2, 1
+     bands = [3, 2, 1]
+
+     # Values are already normalized to [0, 1], so just rescale to 0–255
+     def scale_band(band):
+         band = np.clip((band * 255), 0, 255)
+         return band.astype(np.uint8)
+
+     rgb = np.stack([scale_band(np_image[:, :, b]) for b in bands], axis=-1)  # (H, W, 3)
+     return rgb
+
+ # Global variables for model and dataset
+ model = None
+ dataset = None
+ label_names = None
+ label2id = None
+ id2label = None
+
+ def load_model_and_data():
+     """Load the model and dataset"""
+     global model, dataset, label_names, label2id, id2label
+
+     try:
+         # Load dataset
+         print("Loading dataset...")
+         dataset = load_dataset("blanchon/EuroSAT_MSI", cache_dir="./hf_cache", streaming=False)
+         dataset["test"] = dataset["test"].map(transform_multispectral_map)
+         dataset["test"].set_format(type="torch", columns=["image", "label"])
+
+         # Setup labels
+         label_names = dataset["train"].features['label'].names
+         label2id = {name: i for i, name in enumerate(label_names)}
+         id2label = {v: k for k, v in label2id.items()}
+         num_classes = len(label_names)
+
+         # Load model
+         print("Loading model...")
+         model_path = hf_hub_download(repo_id="Rhodham96/Resnet18DropoutSentinel", filename="pytorch_model.bin")
+         model = ResNet18_Dropout(in_channels=13, num_classes=num_classes)
+         model.load_state_dict(torch.load(model_path, map_location='cpu'))
+         model.eval()
+
+         print("Model and dataset loaded successfully!")
+         print(f"Classes: {label_names}")
+         return True
+
+     except Exception as e:
+         print(f"Error loading model or dataset: {str(e)}")
+         return False
+
+ def predict_images():
+     """Classify 10 random test images and return a results plot and summary"""
+     global model, dataset, id2label
+
+     if model is None or dataset is None:
+         return Image.new('RGB', (800, 600), color='lightgray'), "Model or dataset not loaded. Please wait for initialization."
+
+     test_dataloader = DataLoader(dataset["test"], batch_size=32, shuffle=True)
+
+     try:
+         # Collect a few random batches from the test set
+
+         num_batches = 5
+         collected_images = []
+         collected_labels = []
+         collected_preds = []
+         #criterion = nn.CrossEntropyLoss()
+         model.eval()
+         with torch.no_grad():
+             for i, batch in enumerate(test_dataloader):
+                 if i >= num_batches:
+                     break
+                 images = batch['image']
+                 labels = batch['label']
+
+                 outputs = model(images)
+                 _, preds = outputs.max(1)
+
+                 collected_images.append(images.cpu())
+                 collected_labels.append(labels.cpu())
+                 collected_preds.append(preds.cpu())
+
+         # Concatenate all samples
+         images = torch.cat(collected_images)
+         labels = torch.cat(collected_labels)
+         preds = torch.cat(collected_preds)
+
+         # Randomly select 10 indices
+         indices = random.sample(range(len(images)), 10)
+
+         # Prepare for plotting
+         selected_images = images[indices]
+         selected_labels = labels[indices]
+         selected_preds = preds[indices]
+         image_to_see_layers = selected_images[0]
+         label_to_see_layers = selected_labels[0]
+         # Plot
+         fig, axes = plt.subplots(2, 5, figsize=(15, 6))
+         axes = axes.flatten()
+
+         for i in range(10):
+             img = load_rgb_from_transformed_tensor(selected_images[i])
+
+             axes[i].imshow(img)
+             axes[i].axis("off")
+             true_label = id2label[selected_labels[i].item()]
+             pred_label = id2label[selected_preds[i].item()]
+             color = "green" if pred_label == true_label else "red"
+             axes[i].set_title(f"T: {true_label}\nP: {pred_label}", color=color)
+
+         plt.tight_layout()
+
+         # Convert plot to image
+         buf = io.BytesIO()
+         plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
+         buf.seek(0)
+         plt.close()
+
+         # Convert to PIL Image
+         result_image = Image.open(buf)
+
+         # Calculate accuracy
+         correct_predictions = (selected_preds == selected_labels).sum().item()
+         accuracy = correct_predictions / len(selected_labels) * 100
+         summary = f"Accuracy: {correct_predictions}/{len(selected_labels)} ({accuracy:.1f}%)\n"
+         summary += f"Classes: {', '.join(label_names)}"
+
+         return result_image, summary
+
+     except Exception as e:
+         error_msg = f"Error during prediction: {str(e)}"
+         print(error_msg)
+         # Return a placeholder image and error message
+         placeholder = Image.new('RGB', (800, 600), color='lightgray')
+         return placeholder, error_msg
+
+ def create_interface():
+     """Create the Gradio interface"""
+
+     # Initialize model and data
+     init_success = load_model_and_data()
+
+     if not init_success:
+         def error_function():
+             placeholder = Image.new('RGB', (800, 600), color='red')
+             return placeholder, "Failed to load model or dataset. Please check the logs."
+
+         interface = gr.Interface(
+             fn=error_function,
+             inputs=[],
+             outputs=[
+                 gr.Image(type="pil", label="Results"),
+                 gr.Textbox(label="Summary")
+             ],
+             title="🛰️ Satellite Image Classification - ERROR",
+             description="Failed to initialize the application."
+         )
+         return interface
+
+     # Create the main interface
+     interface = gr.Interface(
+         fn=predict_images,
+         inputs=[],
+         outputs=[
+             gr.Image(type="pil", label="Classification Results (10 Random Images)"),
+             gr.Textbox(label="Summary", lines=3)
+         ],
+         title="🛰️ Satellite Image Classification with ResNet18",
+         description="""
+         This app classifies satellite images from the EuroSAT dataset using a trained ResNet18 model.
+
+         **How it works:**
+         - Loads 10 random satellite images from the test set
+         - Each image has 13 spectral bands, converted to RGB for display
+         - Shows true labels vs predicted labels
+         - Green titles = correct predictions, Red titles = incorrect predictions
+
+         **Dataset:** EuroSAT with 13 multispectral bands
+         **Model:** ResNet18 with dropout, trained on 13-channel input
+
+         Click "Submit" to classify 10 new random images!
+         """,
+         examples=[],
+         cache_examples=False,
+         allow_flagging="never"
+     )
+
+     return interface
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch(share=True)
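
For reference, here is a minimal sketch (not part of this commit) of how the same checkpoint and preprocessing could be used for a single prediction outside Gradio, reusing the definitions from SatelliteClassification.py. It assumes the repo id and checkpoint filename shown above, and that the EuroSAT_MSI download fits in ./hf_cache.

import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download

from SatelliteClassification import ResNet18_Dropout, transform_multispectral_map

# Load the dataset and label names the same way the app does
dataset = load_dataset("blanchon/EuroSAT_MSI", cache_dir="./hf_cache")
label_names = dataset["train"].features["label"].names

# Download the checkpoint and rebuild the 13-channel ResNet18
checkpoint = hf_hub_download(repo_id="Rhodham96/Resnet18DropoutSentinel", filename="pytorch_model.bin")
model = ResNet18_Dropout(in_channels=13, num_classes=len(label_names))
model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
model.eval()

# Preprocess one test sample and classify it
sample = transform_multispectral_map(dataset["test"][0])  # {"image": (13, H, W) tensor, "label": tensor}
with torch.no_grad():
    logits = model(sample["image"].unsqueeze(0))  # add a batch dimension
pred = logits.argmax(dim=1).item()
print("true:", label_names[sample["label"].item()], "| predicted:", label_names[pred])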
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio>=4.0.0
+ torch>=2.0.0
+ torchvision>=0.15.0
+ datasets>=2.0.0
+ huggingface_hub>=0.14.0
+ numpy
+ pillow
+ matplotlib
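
Once these requirements are installed, the preprocessing can also be smoke-tested locally with a synthetic 13-band patch; the sketch below (not part of the commit) exercises transform_multispectral_map and the RGB conversion without downloading EuroSAT.

import numpy as np

from SatelliteClassification import (
    load_rgb_from_transformed_tensor,
    transform_multispectral_map,
)

# Fake 64x64 patch with 13 bands in the 0–2750 range used by the normalization above
fake_example = {
    "image": np.random.randint(0, 2750, size=(64, 64, 13)).astype(np.float32),
    "label": 0,
}

out = transform_multispectral_map(fake_example)
print(out["image"].shape)        # expected: torch.Size([13, 64, 64])

rgb = load_rgb_from_transformed_tensor(out["image"])
print(rgb.shape, rgb.dtype)      # expected: (64, 64, 3) uint8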