# Spaces:
# Sleeping
# Sleeping
import base64
import io
import random
import textwrap
import traceback

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.models import resnet18
# Model architecture definition
class ResNet18_Dropout(nn.Module):
    """ResNet-18 whose stem accepts ``in_channels`` bands, with dropout before the classifier.

    Args:
        in_channels: number of input channels expected by the first convolution.
        num_classes: number of output logits.
        dropout_rate: dropout probability applied right before the final linear layer.
    """

    def __init__(self, in_channels, num_classes, dropout_rate=0.3):
        super().__init__()
        # weights=None: randomly initialized backbone, no pretrained download.
        backbone = resnet18(weights=None)
        # Swap the RGB stem for one that takes `in_channels` bands (same geometry as the stock conv1).
        backbone.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(head_inputs, num_classes),
        )
        self.model = backbone

    def forward(self, x):
        """Run the wrapped ResNet and return its class logits."""
        return self.model(x)
def transform_multispectral_map(example, augment=True):
    """Normalize one 13-band sample and convert it to tensors (usable with ``Dataset.map``).

    Args:
        example: mapping with ``"image"`` (array-like of shape (H, W, 13)) and
            ``"label"`` (integer class id).
        augment: when True (the default, matching the historical behaviour),
            apply a random horizontal flip, a random vertical flip and a random
            90/180/270-degree rotation, each with probability 0.5. Pass False
            for deterministic evaluation / inference data.

    Returns:
        dict with ``"image"``: float32 tensor of shape (13, H', W') with values
        in [0, 1], and ``"label"``: int64 scalar tensor.

    Raises:
        ValueError: if the image is not of shape (H, W, 13).
    """
    image = np.array(example["image"], dtype=np.float32)
    if image.ndim != 3 or image.shape[2] != 13:
        raise ValueError(f"Expected shape (H, W, 13), got {image.shape}")

    # Normalize raw band values by a fixed cap of 2750 and clip to [0, 1].
    image = np.clip(image / 2750.0, 0, 1)

    if augment:
        # The RNG call order (flip, flip, rotate, choice) matches the original implementation.
        if random.random() < 0.5:
            image = np.flip(image, axis=1)  # horizontal flip
        if random.random() < 0.5:
            image = np.flip(image, axis=0)  # vertical flip
        if random.random() < 0.5:
            image = np.rot90(image, k=random.choice([1, 2, 3]), axes=(0, 1))

    # (H, W, C) -> (C, H, W). Flips/rotations are strided views (possibly with
    # negative strides, which torch rejects), so materialize a contiguous copy once.
    image = np.ascontiguousarray(image.transpose(2, 0, 1))
    return {
        "image": torch.tensor(image, dtype=torch.float32),
        "label": torch.tensor(example["label"], dtype=torch.long),
    }
# RGB conversion functions
def load_rgb_from_multispectral_sample(numpy_array, max_value=2750):
    """Build an 8-bit RGB preview from a raw 13-band array of shape (H, W, 13).

    Uses bands 4, 3, 2 (red, green, blue), i.e. 0-based channel indices 3, 2, 1,
    and scales each linearly so that 0..max_value maps to 0..255. Values
    outside that range are clipped (GDAL-style linear scaling).

    Args:
        numpy_array: np.ndarray of shape (H, W, 13) holding raw band values.
        max_value: raw value mapped to 255. Defaults to 2750.

    Returns:
        np.ndarray of dtype uint8 and shape (H, W, 3).

    Raises:
        TypeError: if ``numpy_array`` is not an np.ndarray.
        ValueError: if the array is not 3-D or does not have 13 channels.
    """
    if not isinstance(numpy_array, np.ndarray):
        raise TypeError("Input must be a NumPy array")
    # Check the rank first: the indexing below assumes exactly (H, W, C).
    if numpy_array.ndim != 3:
        raise ValueError(f"Input array must have shape (H, W, 13), got {numpy_array.shape}")
    if numpy_array.shape[-1] != 13:
        raise ValueError(f"Input array must have 13 channels, but got {numpy_array.shape[-1]}")

    # Fancy indexing selects R, G, B in that order -> shape (H, W, 3).
    rgb = numpy_array[:, :, [3, 2, 1]]
    return np.clip((rgb / max_value) * 255, 0, 255).astype(np.uint8)
def load_rgb_from_transformed_tensor(tensor_image):
    """Build an 8-bit RGB preview from a normalized 13-band tensor of shape (C=13, H, W).

    Uses channels 3, 2, 1 (bands 4-3-2: red, green, blue). Values are expected
    to be normalized to [0, 1]; they are multiplied by 255 and clipped.

    Args:
        tensor_image: torch.Tensor of shape (13, H, W). It may require grad or
            live on any device.

    Returns:
        C-contiguous np.ndarray of dtype uint8 and shape (H, W, 3).

    Raises:
        TypeError: if ``tensor_image`` is not a torch.Tensor.
        ValueError: if it is not 3-D or does not have 13 channels on dim 0.
    """
    if not isinstance(tensor_image, torch.Tensor):
        raise TypeError("Input must be a torch.Tensor")
    # A batched (N, 13, H, W) tensor with N == 13 would otherwise slip through.
    if tensor_image.ndim != 3:
        raise ValueError(f"Expected a (C, H, W) tensor, got shape {tuple(tensor_image.shape)}")
    if tensor_image.shape[0] != 13:
        raise ValueError(f"Expected 13 channels, got {tensor_image.shape[0]}")

    # detach()/cpu(): plain .numpy() raises for tensors that require grad or are not on the CPU.
    np_image = tensor_image.detach().cpu().numpy()
    rgb = np.moveaxis(np_image[[3, 2, 1]], 0, -1)  # (3, H, W) -> (H, W, 3)
    return np.ascontiguousarray(np.clip(rgb * 255, 0, 255).astype(np.uint8))
# Module-level state populated by load_model_and_data() and read by
# predict_images(); everything starts out as None.
model = dataset = None
label_names = label2id = id2label = None
def load_model_and_data():
    """Load the EuroSAT MSI dataset and the trained ResNet18 weights.

    On success, sets the module globals ``model``, ``dataset``, ``label_names``,
    ``label2id`` and ``id2label``, then returns True. On any failure, prints the
    error and its traceback and returns False. The globals are assigned only
    after everything has loaded, so a half-initialized model (e.g. one whose
    ``load_state_dict`` failed) is never published.
    """
    global model, dataset, label_names, label2id, id2label
    try:
        print("Loading dataset...")
        ds = load_dataset("blanchon/EuroSAT_MSI", cache_dir="./hf_cache", streaming=False)
        # NOTE(review): transform_multispectral_map applies random flips/rotations
        # unless told otherwise; an evaluation split normally should not be augmented.
        ds["test"] = ds["test"].map(transform_multispectral_map)
        ds["test"].set_format(type="torch", columns=["image", "label"])

        names = ds["train"].features["label"].names
        name_to_id = {name: i for i, name in enumerate(names)}
        id_to_name = {i: name for name, i in name_to_id.items()}

        print("Loading model...")
        model_path = hf_hub_download(
            repo_id="Rhodham96/Resnet18DropoutSentinel", filename="pytorch_model.bin"
        )
        net = ResNet18_Dropout(in_channels=13, num_classes=len(names))
        # torch.load unpickles its input; this file comes from a remote repo, so
        # restrict unpickling to tensors/plain containers (a state dict needs nothing more).
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        net.load_state_dict(state_dict)
        net.eval()
    except Exception as e:
        # Top-level boundary: report (with traceback) and let the caller show an error UI.
        print(f"Error loading model or dataset: {str(e)}")
        traceback.print_exc()
        return False

    model, dataset = net, ds
    label_names, label2id, id2label = names, name_to_id, id_to_name
    print("Model and dataset loaded successfully!")
    print(f"Classes: {label_names}")
    return True
def predict_images():
    """Classify 10 random test images and render them in a labeled 2x5 grid.

    Returns:
        (PIL.Image.Image, str): the rendered figure and a summary containing
        the accuracy on the displayed images plus the class list. If the model
        or dataset is not loaded, or anything fails, returns a gray placeholder
        image and a message instead, so both Gradio outputs are always filled.
    """
    num_images = 10  # matches the 2x5 grid below
    if model is None or dataset is None:
        placeholder = Image.new('RGB', (800, 600), color='lightgray')
        return placeholder, "Model or dataset not loaded. Please wait for initialization."

    fig = None
    try:
        test_ds = dataset["test"]
        k = min(num_images, len(test_ds))
        # Sample the displayed images directly, instead of running several
        # shuffled batches through the model and discarding most predictions.
        indices = random.sample(range(len(test_ds)), k)
        samples = [test_ds[i] for i in indices]
        images = torch.stack([s["image"] for s in samples])
        labels = torch.stack([s["label"] for s in samples])

        model.eval()
        with torch.no_grad():
            preds = model(images).argmax(dim=1)

        fig, axes = plt.subplots(2, 5, figsize=(15, 6))
        for ax in axes.flat:
            ax.axis("off")  # also hides unused cells when fewer than 10 samples exist
        for ax, img, true_id, pred_id in zip(axes.flat, images, labels.tolist(), preds.tolist()):
            ax.imshow(load_rgb_from_transformed_tensor(img))
            true_label = id2label[true_id]
            pred_label = id2label[pred_id]
            color = "green" if pred_label == true_label else "red"
            ax.set_title(f"T: {true_label}\nP: {pred_label}", color=color)
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        buf.seek(0)
        result_image = Image.open(buf)

        correct_predictions = (preds == labels).sum().item()
        accuracy = correct_predictions / k * 100
        summary = f"Accuracy: {correct_predictions}/{k} ({accuracy:.1f}%)\n"
        summary += f"Classes: {', '.join(label_names)}"
        return result_image, summary
    except Exception as e:
        error_msg = f"Error during prediction: {str(e)}"
        print(error_msg)
        # Return a placeholder image and error message
        placeholder = Image.new('RGB', (800, 600), color='lightgray')
        return placeholder, error_msg
    finally:
        # Close exactly this figure; pyplot keeps every open figure alive otherwise.
        if fig is not None:
            plt.close(fig)
def create_interface():
    """Load the model and data, then build the Gradio app.

    Returns:
        gr.Interface: the classification UI, or, if load_model_and_data()
        returns False, a minimal interface that only reports the failure.
    """
    if not load_model_and_data():
        def error_function():
            placeholder = Image.new('RGB', (800, 600), color='red')
            return placeholder, "Failed to load model or dataset. Please check the logs."

        return gr.Interface(
            fn=error_function,
            inputs=[],
            outputs=[
                gr.Image(type="pil", label="Results"),
                gr.Textbox(label="Summary"),
            ],
            title="🛰️ Satellite Image Classification - ERROR",
            description="Failed to initialize the application.",
        )

    # dedent: otherwise the source indentation turns the Markdown into a code block.
    # Blank lines keep the bold "Dataset"/"Model" lines from being absorbed
    # into the last bullet item.
    description = textwrap.dedent(
        """
        This app classifies satellite images from the EuroSAT dataset using a trained ResNet18 model.

        **How it works:**
        - Loads 10 random satellite images from the test set
        - Each image has 13 spectral bands, converted to RGB for display
        - Shows true labels vs predicted labels
        - Green titles = correct predictions, Red titles = incorrect predictions

        **Dataset:** EuroSAT with 13 multispectral bands

        **Model:** ResNet18 with dropout, trained on 13-channel input

        Click "Generate" to process 10 new random images!
        """
    )

    return gr.Interface(
        fn=predict_images,
        inputs=[],
        outputs=[
            gr.Image(type="pil", label="Classification Results (10 Random Images)"),
            gr.Textbox(label="Summary", lines=3),
        ],
        title="🛰️ Satellite Image Classification with ResNet18",
        description=description,
        examples=[],
        cache_examples=False,
        allow_flagging="never",
    )
if __name__ == "__main__":
    # Build the UI (create_interface() also loads the model and dataset),
    # then start the server; share=True asks Gradio for a public share link.
    demo = create_interface()
    demo.launch(share=True)