Spaces:
Sleeping
Sleeping
File size: 10,035 Bytes
ab58b14 1d0d8c3 ab58b14 1d0d8c3 ab58b14 1d0d8c3 ab58b14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 |
import gradio as gr
import torch
import torch.nn as nn
from torchvision.models import resnet18
from datasets import load_dataset
from huggingface_hub import hf_hub_download
import numpy as np
import random
from PIL import Image
import matplotlib.pyplot as plt
import io
from torch.utils.data import DataLoader
import base64
# Model architecture definition
class ResNet18_Dropout(nn.Module):
    """ResNet-18 with a configurable number of input channels and a dropout head.

    The torchvision ResNet-18 (randomly initialised, no pretrained weights) is
    wrapped as ``self.model``. Its stem convolution is rebuilt to accept
    ``in_channels`` inputs, and its classifier becomes ``Dropout -> Linear``.

    NOTE: the attribute names ``model``, ``model.conv1`` and ``model.fc`` (with
    Dropout at index 0 and Linear at index 1) define the state-dict keys.
    Keep them stable so that saved checkpoints still load.
    """

    def __init__(self, in_channels, num_classes, dropout_rate=0.3):
        super().__init__()
        backbone = resnet18(weights=None)

        # Swap the stem so the network consumes `in_channels` bands.
        backbone.conv1 = nn.Conv2d(
            in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

        # Replace the classifier with dropout followed by a linear layer,
        # keeping the feature width the backbone already produces.
        feature_width = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(feature_width, num_classes),
        )

        self.model = backbone

    def forward(self, x):
        """Return the wrapped ResNet-18's logits for input batch ``x``."""
        return self.model(x)
def transform_multispectral_map(example, augment=True):
    """Preprocess one multispectral example into model-ready tensors.

    Intended for ``datasets.Dataset.map``. The record must provide
    ``"image"`` (array-like, shape ``(H, W, 13)``) and ``"label"``.

    Steps:
      1. Divide raw band values by 2750 and clip the result to [0, 1].
      2. If ``augment`` is true, apply random augmentations. Each of the
         following is applied independently with probability 0.5, drawing
         from the global :mod:`random` state: a horizontal flip, a vertical
         flip, and a rotation by 90/180/270 degrees.
      3. Reorder the axes from ``(H, W, C)`` to ``(C, H, W)``.

    Args:
        example: mapping with ``"image"`` and ``"label"`` keys.
        augment: whether to apply the random flips and rotation. The default
            ``True`` keeps the historical behaviour. Pass ``False`` (e.g. via
            ``Dataset.map(..., fn_kwargs={"augment": False})``) for
            deterministic evaluation or display.

    Returns:
        dict with ``"image"`` (float32 tensor of shape ``(13, H, W)``, or
        ``(13, W, H)`` after an odd rotation of a non-square image) and
        ``"label"`` (int64 scalar tensor).

    Raises:
        ValueError: if the image is not 3-D with 13 channels on the last axis.
    """
    image = np.array(example["image"], dtype=np.float32)
    if image.ndim != 3 or image.shape[2] != 13:
        raise ValueError(f"Expected shape (H, W, 13), got {image.shape}")

    # Normalize raw band values into [0, 1].
    image = np.clip(image / 2750.0, 0, 1)

    if augment:
        if random.random() < 0.5:
            image = np.flip(image, axis=1)  # horizontal flip (W axis)
        if random.random() < 0.5:
            image = np.flip(image, axis=0)  # vertical flip (H axis)
        if random.random() < 0.5:
            image = np.rot90(image, k=random.choice([1, 2, 3]), axes=(0, 1))

    # (H, W, C) -> (C, H, W). Flips and rotations return views with negative
    # strides, which torch cannot ingest. One contiguous copy here replaces a
    # copy after every augmentation step.
    image = np.ascontiguousarray(image.transpose(2, 0, 1))

    return {
        "image": torch.tensor(image, dtype=torch.float32),
        "label": torch.tensor(example["label"], dtype=torch.long),
    }
# RGB conversion functions
def load_rgb_from_multispectral_sample(numpy_array):
    """Build a displayable RGB image from a raw 13-band multispectral array.

    Selects bands 4 (red), 3 (green) and 2 (blue), i.e. indices 3, 2, 1 on
    the last axis. It then applies GDAL-style linear scaling: raw values
    0–2750 map to 0–255. Values above 2750 saturate at 255, negative values
    clip to 0, and fractional results are truncated by the ``uint8`` cast.

    Args:
        numpy_array: NumPy array of shape ``(H, W, 13)``.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``.

    Raises:
        TypeError: if ``numpy_array`` is not a NumPy array.
        ValueError: if the array is not 3-D or does not have 13 channels.
    """
    if not isinstance(numpy_array, np.ndarray):
        raise TypeError("Input must be a NumPy array")
    # Check rank first: a 2-D array whose last dim is 13 would otherwise pass
    # the channel check and fail later with an opaque IndexError.
    if numpy_array.ndim != 3:
        raise ValueError(f"Input array must have shape (H, W, 13), got {numpy_array.shape}")
    if numpy_array.shape[-1] != 13:
        raise ValueError(f"Input array must have 13 channels, but got {numpy_array.shape[-1]}")

    # Bands 4-3-2 -> 0-based indices 3, 2, 1; scale all three in one pass.
    rgb = numpy_array[:, :, [3, 2, 1]]
    return np.clip((rgb / 2750) * 255, 0, 255).astype(np.uint8)
def load_rgb_from_transformed_tensor(tensor_image):
    """Convert a preprocessed 13-band ``(C, H, W)`` tensor into an RGB image.

    Takes bands 4-3-2 (channel indices 3, 2, 1), multiplies them by 255,
    clips to [0, 255] and casts to ``uint8``. Inputs are therefore expected
    to be already normalised to roughly [0, 1].

    Args:
        tensor_image: ``torch.Tensor`` of shape ``(13, H, W)``. It may
            require grad or live on any device.

    Returns:
        C-contiguous ``uint8`` NumPy array of shape ``(H, W, 3)``.

    Raises:
        TypeError: if ``tensor_image`` is not a ``torch.Tensor``.
        ValueError: if it is not 3-D or does not have 13 channels.
    """
    if not isinstance(tensor_image, torch.Tensor):
        raise TypeError("Input must be a torch.Tensor")
    if tensor_image.ndim != 3:
        raise ValueError(f"Expected a (C, H, W) tensor, got shape {tuple(tensor_image.shape)}")
    if tensor_image.shape[0] != 13:
        raise ValueError(f"Expected 13 channels, got {tensor_image.shape[0]}")

    # detach()/cpu() let .numpy() work on tensors that require grad or are on
    # an accelerator. Band selection 4-3-2 -> indices 3, 2, 1 gives (3, H, W).
    bands = tensor_image.detach().cpu()[[3, 2, 1]].numpy()
    scaled = np.clip(bands * 255, 0, 255).astype(np.uint8)
    # (3, H, W) -> (H, W, 3), materialised as a C-contiguous array.
    return np.ascontiguousarray(scaled.transpose(1, 2, 0))
# Module-level state shared by the loader and the prediction callback.
# All five names start as None; load_model_and_data() assigns them
# (model, dataset, label_names, label2id, id2label).
model = dataset = label_names = label2id = id2label = None
def load_model_and_data():
    """Download the EuroSAT-MSI dataset and the trained ResNet18 checkpoint.

    The test split is preprocessed with ``transform_multispectral_map`` and
    formatted as torch tensors. Label names are read from the train split.

    On success, the module globals ``model``, ``dataset``, ``label_names``,
    ``label2id`` and ``id2label`` are all set together and ``True`` is
    returned. On any failure, the error and its traceback are printed, the
    globals are left untouched (no half-initialised state such as a model
    with random weights is published), and ``False`` is returned.
    """
    global model, dataset, label_names, label2id, id2label
    import traceback

    try:
        print("Loading dataset...")
        loaded = load_dataset("blanchon/EuroSAT_MSI", cache_dir="./hf_cache", streaming=False)
        # NOTE(review): transform_multispectral_map applies random flips and
        # rotations, so the test images used for display are augmented copies.
        loaded["test"] = loaded["test"].map(transform_multispectral_map)
        loaded["test"].set_format(type="torch", columns=["image", "label"])

        names = loaded["train"].features['label'].names
        name_to_id = {name: i for i, name in enumerate(names)}
        id_to_name = {i: name for name, i in name_to_id.items()}

        print("Loading model...")
        model_path = hf_hub_download(repo_id="Rhodham96/Resnet18DropoutSentinel", filename="pytorch_model.bin")
        net = ResNet18_Dropout(in_channels=13, num_classes=len(names))
        # The checkpoint is a downloaded pickle. weights_only=True limits
        # unpickling to tensors and plain containers, which is all a state
        # dict needs, so a tampered file cannot run arbitrary code.
        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
        net.load_state_dict(state_dict)
        net.eval()
    except Exception as e:
        print(f"Error loading model or dataset: {str(e)}")
        traceback.print_exc()
        return False

    # Publish only after every step has succeeded.
    model, dataset = net, loaded
    label_names, label2id, id2label = names, name_to_id, id_to_name
    print("Model and dataset loaded successfully!")
    print(f"Classes: {label_names}")
    return True
def _render_prediction_grid(images, labels, preds):
    """Plot up to 10 images in a 2x5 grid titled with true/predicted labels; return a PIL image.

    Titles read "T: <true>\\nP: <pred>" and are green when the labels match,
    red otherwise. Unused panels (fewer than 10 images) are left blank.
    """
    # Use an explicit figure object rather than pyplot's implicit "current
    # figure", and always close it, even if plotting fails.
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    try:
        axes = axes.flatten()
        for ax, image, true_id, pred_id in zip(axes, images, labels, preds):
            ax.imshow(load_rgb_from_transformed_tensor(image))
            true_label = id2label[true_id.item()]
            pred_label = id2label[pred_id.item()]
            color = "green" if pred_label == true_label else "red"
            ax.set_title(f"T: {true_label}\nP: {pred_label}", color=color)
        for ax in axes:
            ax.axis("off")
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def predict_images():
    """Classify 10 random test images and render the results.

    Runs ``model`` on shuffled samples from ``dataset["test"]`` and randomly
    keeps 10 of them (or fewer if the split is smaller). The selected images
    are drawn with their true and predicted labels, and accuracy over that
    sample is reported.

    Returns:
        ``(PIL.Image, str)`` with the rendered grid and a text summary. If
        the model or dataset is not loaded, or anything fails, a light-gray
        placeholder image and an error message are returned instead. This
        keeps the Gradio interface's two outputs always filled.
    """
    import traceback

    num_display = 10  # the grid in _render_prediction_grid holds 10 panels

    if model is None or dataset is None or id2label is None:
        placeholder = Image.new('RGB', (800, 600), color='lightgray')
        return placeholder, "Model or dataset not loaded. Please wait for initialization."

    try:
        # With shuffle=True each batch is a uniformly random draw from the
        # test split, so inference can stop once enough samples are collected.
        test_dataloader = DataLoader(dataset["test"], batch_size=32, shuffle=True)
        collected_images, collected_labels, collected_preds = [], [], []
        n_collected = 0

        model.eval()
        with torch.no_grad():
            for batch in test_dataloader:
                images = batch['image']
                labels = batch['label']
                preds = model(images).argmax(dim=1)
                collected_images.append(images.cpu())
                collected_labels.append(labels.cpu())
                collected_preds.append(preds.cpu())
                n_collected += len(labels)
                if n_collected >= num_display:
                    break

        if not collected_images:
            raise ValueError("the test split is empty")

        images = torch.cat(collected_images)
        labels = torch.cat(collected_labels)
        preds = torch.cat(collected_preds)

        # Never ask for more samples than were collected.
        k = min(num_display, len(images))
        indices = random.sample(range(len(images)), k)
        selected_images = images[indices]
        selected_labels = labels[indices]
        selected_preds = preds[indices]

        result_image = _render_prediction_grid(selected_images, selected_labels, selected_preds)

        correct_predictions = (selected_preds == selected_labels).sum().item()
        accuracy = correct_predictions / len(selected_labels) * 100
        summary = f"Accuracy: {correct_predictions}/{len(selected_labels)} ({accuracy:.1f}%)\n"
        summary += f"Classes: {', '.join(label_names)}"
        return result_image, summary
    except Exception as e:
        error_msg = f"Error during prediction: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        placeholder = Image.new('RGB', (800, 600), color='lightgray')
        return placeholder, error_msg
def create_interface():
    """Load the model and data, then build the Gradio interface.

    If ``load_model_and_data()`` fails, returns a minimal interface. Its only
    action shows a red placeholder image and an error message. Otherwise,
    returns the main interface wired to ``predict_images``, which has no
    inputs and produces an image and a summary textbox.
    """
    init_success = load_model_and_data()

    if not init_success:
        def error_function():
            placeholder = Image.new('RGB', (800, 600), color='red')
            return placeholder, "Failed to load model or dataset. Please check the logs."

        return gr.Interface(
            fn=error_function,
            inputs=[],
            outputs=[
                gr.Image(type="pil", label="Results"),
                gr.Textbox(label="Summary"),
            ],
            title="🛰️ Satellite Image Classification - ERROR",
            description="Failed to initialize the application.",
        )

    # Blank lines separate the Markdown paragraphs. Without them the heading
    # lines and the list run together when rendered.
    description = (
        "This app classifies satellite images from the EuroSAT dataset using a trained ResNet18 model.\n"
        "\n"
        "**How it works:**\n"
        "- Loads 10 random satellite images from the test set\n"
        "- Each image has 13 spectral bands, converted to RGB for display\n"
        "- Shows true labels vs predicted labels\n"
        "- Green titles = correct predictions, Red titles = incorrect predictions\n"
        "\n"
        "**Dataset:** EuroSAT with 13 multispectral bands\n"
        "\n"
        "**Model:** ResNet18 with dropout, trained on 13-channel input\n"
        "\n"
        'Click "Generate" to process 10 new random images!\n'
    )

    return gr.Interface(
        fn=predict_images,
        inputs=[],
        outputs=[
            gr.Image(type="pil", label="Classification Results (10 Random Images)"),
            gr.Textbox(label="Summary", lines=3),
        ],
        title="🛰️ Satellite Image Classification with ResNet18",
        description=description,
        examples=[],
        cache_examples=False,
        allow_flagging="never",
    )
# Launch the app
if __name__ == "__main__":
    # Build the interface (this also loads the model and dataset) and serve it.
    demo = create_interface()
    demo.launch(share=True)