# SatelliteClassification.py
import gradio as gr
import torch
import torch.nn as nn
from torchvision.models import resnet18
from datasets import load_dataset
from huggingface_hub import hf_hub_download
import numpy as np
import random
from PIL import Image
import matplotlib.pyplot as plt
import io
from torch.utils.data import DataLoader
import base64
# Model architecture definition
class ResNet18_Dropout(nn.Module):
    def __init__(self, in_channels, num_classes, dropout_rate=0.3):
        super().__init__()
        self.model = resnet18(weights=None)
        self.model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        return self.model(x)
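# Usage sketch (assumes 13-band EuroSAT patches of size 64x64, as loaded below):
#   net = ResNet18_Dropout(in_channels=13, num_classes=10)
#   logits = net(torch.randn(1, 13, 64, 64))  # -> tensor of shape (1, 10)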
def transform_multispectral_map(example):
    image = np.array(example["image"], dtype=np.float32)
    if image.ndim != 3 or image.shape[2] != 13:
        raise ValueError(f"Expected shape (H, W, 13), got {image.shape}")

    # Normalize reflectance values to [0, 1]
    image = image / 2750.0
    image = np.clip(image, 0, 1)

    # === DATA AUGMENTATION ===
    # Horizontal flip
    if random.random() < 0.5:
        image = np.flip(image, axis=1).copy()
    # Vertical flip
    if random.random() < 0.5:
        image = np.flip(image, axis=0).copy()
    # Rotation (by 90, 180 or 270 degrees)
    if random.random() < 0.5:
        k = random.choice([1, 2, 3])
        image = np.rot90(image, k=k, axes=(0, 1)).copy()

    # === SHAPE FORMAT ===
    image = image.transpose(2, 0, 1)  # (C=13, H, W)

    return {
        "image": torch.tensor(image, dtype=torch.float32),
        "label": torch.tensor(example["label"], dtype=torch.long)
    }
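# Each mapped example becomes {"image": float tensor of shape (13, H, W), "label": long tensor},
# i.e. the channels-first layout ResNet18_Dropout expects (EuroSAT patches are 64x64 pixels).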
# RGB conversion functions
def load_rgb_from_multispectral_sample(numpy_array):
    """
    Takes a NumPy array with 13 multispectral bands and returns a scaled RGB NumPy array.
    Equivalent to loading bands 4-3-2 and scaling as GDAL would.
    """
    # GDAL-style scaling: map reflectance 0–2750 -> 0–255
    def scale_band(band):
        band = np.clip((band / 2750) * 255, 0, 255)
        return band.astype(np.uint8)

    # Bands 4 (red), 3 (green), 2 (blue) => indices 3, 2, 1 (0-based)
    bands = [3, 2, 1]

    # Ensure the input is a NumPy array
    if not isinstance(numpy_array, np.ndarray):
        raise TypeError("Input must be a NumPy array")

    # Check that the array has the expected number of channels (13)
    if numpy_array.shape[-1] != 13:
        raise ValueError(f"Input array must have 13 channels, but got {numpy_array.shape[-1]}")

    # Extract and scale the RGB bands from the NumPy array
    rgb = np.stack([scale_band(numpy_array[:, :, b]) for b in bands], axis=-1)
    return rgb
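# Usage sketch (synthetic array for illustration only; in the app the raw (H, W, 13) samples
# come from the untransformed EuroSAT split):
#   raw = np.random.randint(0, 2750, size=(64, 64, 13)).astype(np.float32)
#   rgb = load_rgb_from_multispectral_sample(raw)  # uint8 array of shape (64, 64, 3)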
def load_rgb_from_transformed_tensor(tensor_image):
    """
    Takes a torch.Tensor with 13 multispectral bands (C, H, W) and returns a scaled RGB NumPy array.
    """
    if not isinstance(tensor_image, torch.Tensor):
        raise TypeError("Input must be a torch.Tensor")
    if tensor_image.shape[0] != 13:
        raise ValueError(f"Expected 13 channels, got {tensor_image.shape[0]}")

    # Convert to NumPy: (C, H, W) → (H, W, C)
    np_image = tensor_image.numpy()
    np_image = np.transpose(np_image, (1, 2, 0))  # (H, W, 13)

    # Bands 4-3-2 → indices 3, 2, 1
    bands = [3, 2, 1]

    def scale_band(band):
        band = np.clip(band * 255, 0, 255)
        return band.astype(np.uint8)

    rgb = np.stack([scale_band(np_image[:, :, b]) for b in bands], axis=-1)  # (H, W, 3)
    return rgb
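# Unlike load_rgb_from_multispectral_sample, this expects tensors already normalized to [0, 1]
# by transform_multispectral_map, hence the plain *255 rescaling. Sketch:
#   rgb = load_rgb_from_transformed_tensor(torch.rand(13, 64, 64))  # uint8, shape (64, 64, 3)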
# Global variables for model and dataset
model = None
dataset = None
label_names = None
label2id = None
id2label = None
def load_model_and_data():
    """Load the model and dataset"""
    global model, dataset, label_names, label2id, id2label
    try:
        # Load dataset
        print("Loading dataset...")
        dataset = load_dataset("blanchon/EuroSAT_MSI", cache_dir="./hf_cache", streaming=False)
        dataset["test"] = dataset["test"].map(transform_multispectral_map)
        dataset["test"].set_format(type="torch", columns=["image", "label"])

        # Set up labels
        label_names = dataset["train"].features['label'].names
        label2id = {name: i for i, name in enumerate(label_names)}
        id2label = {v: k for k, v in label2id.items()}
        num_classes = len(label_names)

        # Load model
        print("Loading model...")
        model_path = hf_hub_download(repo_id="Rhodham96/Resnet18DropoutSentinel", filename="pytorch_model.bin")
        model = ResNet18_Dropout(in_channels=13, num_classes=num_classes)
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        model.eval()

        print("Model and dataset loaded successfully!")
        print(f"Classes: {label_names}")
        return True
    except Exception as e:
        print(f"Error loading model or dataset: {str(e)}")
        return False
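# The EuroSAT taxonomy defines 10 land-cover classes (AnnualCrop, Forest, HerbaceousVegetation,
# Highway, Industrial, Pasture, PermanentCrop, Residential, River, SeaLake); the exact names used
# here are taken from the dataset's label feature, and label2id / id2label map between those
# names and the integer indices produced by the model.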
def predict_images():
    """Run the model on random test images and return a plot of 10 predictions plus a summary."""
    global model, dataset, id2label

    if model is None or dataset is None:
        placeholder = Image.new('RGB', (800, 600), color='lightgray')
        return placeholder, "Model or dataset not loaded. Please wait for initialization."

    test_dataloader = DataLoader(dataset["test"], batch_size=32, shuffle=True)
    try:
        # Run the model over a few shuffled test batches (5 batches x 32 = 160 samples)
        num_batches = 5
        collected_images = []
        collected_labels = []
        collected_preds = []

        model.eval()
        with torch.no_grad():
            for i, batch in enumerate(test_dataloader):
                if i >= num_batches:
                    break
                images = batch['image']
                labels = batch['label']
                outputs = model(images)
                _, preds = outputs.max(1)
                collected_images.append(images.cpu())
                collected_labels.append(labels.cpu())
                collected_preds.append(preds.cpu())

        # Concatenate all collected samples
        images = torch.cat(collected_images)
        labels = torch.cat(collected_labels)
        preds = torch.cat(collected_preds)

        # Randomly select 10 samples to display
        indices = random.sample(range(len(images)), 10)
        selected_images = images[indices]
        selected_labels = labels[indices]
        selected_preds = preds[indices]
        # Plot a 2x5 grid of RGB renderings with true vs. predicted labels
        fig, axes = plt.subplots(2, 5, figsize=(15, 6))
        axes = axes.flatten()
        for i in range(10):
            img = load_rgb_from_transformed_tensor(selected_images[i])
            axes[i].imshow(img)
            axes[i].axis("off")
            true_label = id2label[selected_labels[i].item()]
            pred_label = id2label[selected_preds[i].item()]
            color = "green" if pred_label == true_label else "red"
            axes[i].set_title(f"T: {true_label}\nP: {pred_label}", color=color)
        plt.tight_layout()

        # Convert the plot to a PIL image
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        buf.seek(0)
        plt.close(fig)
        result_image = Image.open(buf)

        # Compute accuracy on the 10 displayed images
        correct_predictions = (selected_preds == selected_labels).sum().item()
        accuracy = correct_predictions / len(selected_labels) * 100
        summary = f"Accuracy: {correct_predictions}/{len(selected_labels)} ({accuracy:.1f}%)\n"
        summary += f"Classes: {', '.join(label_names)}"

        return result_image, summary
    except Exception as e:
        error_msg = f"Error during prediction: {str(e)}"
        print(error_msg)
        # Return a placeholder image and the error message
        placeholder = Image.new('RGB', (800, 600), color='lightgray')
        return placeholder, error_msg
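# predict_images takes no inputs and returns (PIL.Image, str), matching the two output
# components declared in create_interface below; a fresh random batch is drawn on every click.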
def create_interface():
    """Create the Gradio interface"""
    # Initialize model and data
    init_success = load_model_and_data()

    if not init_success:
        def error_function():
            placeholder = Image.new('RGB', (800, 600), color='red')
            return placeholder, "Failed to load model or dataset. Please check the logs."

        interface = gr.Interface(
            fn=error_function,
            inputs=[],
            outputs=[
                gr.Image(type="pil", label="Results"),
                gr.Textbox(label="Summary")
            ],
            title="🛰️ Satellite Image Classification - ERROR",
            description="Failed to initialize the application."
        )
        return interface

    # Create the main interface
    interface = gr.Interface(
        fn=predict_images,
        inputs=[],
        outputs=[
            gr.Image(type="pil", label="Classification Results (10 Random Images)"),
            gr.Textbox(label="Summary", lines=3)
        ],
        title="🛰️ Satellite Image Classification with ResNet18",
        description="""
This app classifies satellite images from the EuroSAT dataset using a trained ResNet18 model.

**How it works:**
- Loads 10 random satellite images from the test set
- Each image has 13 spectral bands, converted to RGB for display
- Shows true labels vs. predicted labels
- Green titles = correct predictions, red titles = incorrect predictions

**Dataset:** EuroSAT with 13 multispectral bands
**Model:** ResNet18 with dropout, trained on 13-channel input

Click "Generate" to process 10 new random images!
""",
        examples=[],
        cache_examples=False,
        allow_flagging="never"
    )
    return interface
# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)