"""Gradio demo: classify EuroSAT multispectral tiles with a ResNet18 model.

On start-up the app loads the EuroSAT_MSI dataset and a pretrained
13-channel ResNet18 checkpoint from the Hugging Face Hub.  Each click draws a
few random test tiles, runs the model on them and renders a grid of RGB
previews titled with the true and predicted class.
"""

import base64
import io
import random

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.models import resnet18

# Number of spectral bands every sample is required to have.
NUM_BANDS = 13
# Raw band values are divided by this before being clipped to [0, 1].
REFLECTANCE_SCALE = 2750.0
# Bands 4 (red), 3 (green), 2 (blue) => indices 3, 2, 1 (0-based).
RGB_BAND_INDICES = (3, 2, 1)
# Layout of the result figure; one tile per displayed image.
GRID_ROWS = 2
GRID_COLS = 5
NUM_DISPLAY_IMAGES = GRID_ROWS * GRID_COLS
# Size of the image returned when there is nothing to plot.
PLACEHOLDER_SIZE = (800, 600)

APP_DESCRIPTION = (
    "This app classifies satellite images from the EuroSAT dataset "
    "using a trained ResNet18 model.\n"
    "\n"
    "**How it works:**\n"
    "- Loads 10 random satellite images from the test set\n"
    "- Each image has 13 spectral bands, converted to RGB for display\n"
    "- Shows true labels vs predicted labels\n"
    "- Green titles = correct predictions, Red titles = incorrect predictions\n"
    "\n"
    "**Dataset:** EuroSAT with 13 multispectral bands\n"
    "**Model:** ResNet18 with dropout, trained on 13-channel input\n"
    "\n"
    'Click "Generate" to process 10 new random images!\n'
)


# Model architecture definition
class ResNet18_Dropout(nn.Module):
    """ResNet18 with a configurable input channel count and a dropout head.

    The stock first convolution is replaced so the network accepts
    ``in_channels`` input planes, and the final fully connected layer is
    replaced by ``Dropout -> Linear(num_classes)``.
    """

    def __init__(self, in_channels, num_classes, dropout_rate=0.3):
        super().__init__()
        self.model = resnet18(weights=None)
        # Same geometry as the stock stem, only the input channel count differs.
        self.model.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x):
        """Return the class logits for the batch ``x``."""
        return self.model(x)


def transform_multispectral_map(example, augment=True):
    """Convert one dataset example into normalised ``(C, H, W)`` tensors.

    Args:
        example: Mapping with an ``"image"`` entry convertible to an
            ``(H, W, 13)`` array and a ``"label"`` entry.
        augment: When true (the default, matching the historical behaviour)
            apply random horizontal/vertical flips and a random 90-degree
            rotation.  Pass ``False`` for evaluation data.

    Returns:
        Dict with ``"image"`` (float32 tensor, channels first, values in
        [0, 1]) and ``"label"`` (int64 tensor).

    Raises:
        ValueError: If the image is not 3-D with 13 channels on the last axis.
    """
    image = np.array(example["image"], dtype=np.float32)

    if image.ndim != 3 or image.shape[2] != NUM_BANDS:
        raise ValueError(f"Expected shape (H, W, {NUM_BANDS}), got {image.shape}")

    # Normalize
    image = np.clip(image / REFLECTANCE_SCALE, 0, 1)

    if augment:
        # Horizontal flip
        if random.random() < 0.5:
            image = np.flip(image, axis=1)
        # Vertical flip
        if random.random() < 0.5:
            image = np.flip(image, axis=0)
        # Rotation (by 90, 180, 270)
        if random.random() < 0.5:
            k = random.choice([1, 2, 3])
            image = np.rot90(image, k=k, axes=(0, 1))

    # (H, W, C) -> (C, H, W).  The flips above produce negative-stride views,
    # so materialise one contiguous copy before handing the array to torch.
    image = np.ascontiguousarray(image.transpose(2, 0, 1))

    return {
        "image": torch.tensor(image, dtype=torch.float32),
        "label": torch.tensor(example["label"], dtype=torch.long),
    }


# RGB conversion functions
def load_rgb_from_multispectral_sample(numpy_array):
    """Build a displayable RGB image from a raw 13-band sample.

    Bands 4-3-2 (indices 3, 2, 1) are selected and raw values in 0-2750 are
    scaled linearly to 0-255, clipping anything outside that range.

    Args:
        numpy_array: ``(H, W, 13)`` NumPy array of raw band values.

    Returns:
        ``(H, W, 3)`` ``uint8`` array.

    Raises:
        TypeError: If the input is not a NumPy array.
        ValueError: If the input is not 3-D or its last axis is not 13 long.
    """
    # Ensure the input is a NumPy array
    if not isinstance(numpy_array, np.ndarray):
        raise TypeError("Input must be a NumPy array")

    # The indexing below needs exactly (H, W, C).
    if numpy_array.ndim != 3:
        raise ValueError(
            f"Input array must have shape (H, W, {NUM_BANDS}), "
            f"but got {numpy_array.shape}"
        )

    # Check if the array has the expected number of channels (13)
    if numpy_array.shape[-1] != NUM_BANDS:
        raise ValueError(
            f"Input array must have {NUM_BANDS} channels, "
            f"but got {numpy_array.shape[-1]}"
        )

    # Select the three display bands in one step, then scale and clip.
    rgb = numpy_array[:, :, list(RGB_BAND_INDICES)]
    rgb = np.clip((rgb / 2750) * 255, 0, 255)
    return rgb.astype(np.uint8)


def load_rgb_from_transformed_tensor(tensor_image):
    """Build a displayable RGB image from a normalised 13-band tensor.

    Args:
        tensor_image: ``torch.Tensor`` laid out as ``(13, H, W)`` whose
            values are already scaled to [0, 1] (as produced by
            ``transform_multispectral_map``).

    Returns:
        ``(H, W, 3)`` ``uint8`` NumPy array.

    Raises:
        TypeError: If the input is not a ``torch.Tensor``.
        ValueError: If the first axis is not 13 long.
    """
    if not isinstance(tensor_image, torch.Tensor):
        raise TypeError("Input must be a torch.Tensor")

    if tensor_image.shape[0] != NUM_BANDS:
        raise ValueError(f"Expected {NUM_BANDS} channels, got {tensor_image.shape[0]}")

    # detach()/cpu() make the conversion safe for tensors that track
    # gradients or live on an accelerator; both are no-ops otherwise.
    np_image = tensor_image.detach().cpu().numpy()
    np_image = np.transpose(np_image, (1, 2, 0))  # (C, H, W) -> (H, W, 13)

    rgb = np_image[:, :, list(RGB_BAND_INDICES)]  # (H, W, 3)
    rgb = np.clip(rgb * 255, 0, 255)
    return rgb.astype(np.uint8)


# Global variables for model and dataset (populated by load_model_and_data)
model = None
dataset = None
label_names = None
label2id = None
id2label = None


def load_model_and_data():
    """Load the dataset and the pretrained model into the module globals.

    Returns:
        ``True`` when everything loaded, ``False`` otherwise (the error is
        printed rather than raised so the UI can show a fallback page).
    """
    global model, dataset, label_names, label2id, id2label

    try:
        # Load dataset
        print("Loading dataset...")
        dataset = load_dataset(
            "blanchon/EuroSAT_MSI", cache_dir="./hf_cache", streaming=False
        )
        # Evaluation data must not be randomly flipped/rotated, so the
        # augmentation step of the transform is switched off here.
        dataset["test"] = dataset["test"].map(
            transform_multispectral_map, fn_kwargs={"augment": False}
        )
        dataset["test"].set_format(type="torch", columns=["image", "label"])

        # Setup labels
        label_names = dataset["train"].features["label"].names
        label2id = {name: i for i, name in enumerate(label_names)}
        id2label = {v: k for k, v in label2id.items()}
        num_classes = len(label_names)

        # Load model
        print("Loading model...")
        model_path = hf_hub_download(
            repo_id="Rhodham96/Resnet18DropoutSentinel",
            filename="pytorch_model.bin",
        )
        model = ResNet18_Dropout(in_channels=NUM_BANDS, num_classes=num_classes)
        # NOTE(security): torch.load unpickles the downloaded file.  Depending
        # on the installed torch version this may execute arbitrary code from
        # the checkpoint; consider weights_only=True if the version supports it.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        model.eval()

        print("Model and dataset loaded successfully!")
        print(f"Classes: {label_names}")
        return True

    except Exception as e:  # start-up boundary: report and fall back
        print(f"Error loading model or dataset: {str(e)}")
        return False


def _placeholder_image(color="lightgray"):
    """Return a blank image used whenever there is no result grid to show."""
    return Image.new("RGB", PLACEHOLDER_SIZE, color=color)


def _render_prediction_grid(images, labels, preds):
    """Plot the samples in a titled grid and return it as a PIL image.

    Titles are green when the prediction matches the true label and red
    otherwise.  Grid cells without a sample are left blank.
    """
    fig, axes = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(15, 6))
    try:
        for idx, ax in enumerate(axes.flatten()):
            if idx >= len(images):
                # Fewer samples than cells: keep the cell empty.
                ax.axis("off")
                continue

            ax.imshow(load_rgb_from_transformed_tensor(images[idx]))
            ax.axis("off")

            true_label = id2label[labels[idx].item()]
            pred_label = id2label[preds[idx].item()]
            color = "green" if pred_label == true_label else "red"
            ax.set_title(f"T: {true_label}\nP: {pred_label}", color=color)

        fig.tight_layout()

        # Convert plot to image
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting failed part-way.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)


def predict_images():
    """Classify 10 random test images and return ``(grid_image, summary)``.

    Returns:
        A 2-tuple matching the interface outputs: a PIL image with the
        prediction grid (or a grey placeholder on failure) and a text
        summary (or an error message).
    """
    if model is None or dataset is None:
        # Two outputs are declared on the interface, so always return a pair.
        return (
            _placeholder_image(),
            "Model or dataset not loaded. Please wait for initialization.",
        )

    try:
        # A shuffled loader's first batch is already a uniform random sample,
        # so one batch of the display size is all that has to be evaluated.
        test_dataloader = DataLoader(
            dataset["test"], batch_size=NUM_DISPLAY_IMAGES, shuffle=True
        )
        batch = next(iter(test_dataloader), None)
        if batch is None:
            error_msg = "Error during prediction: the test split is empty."
            print(error_msg)
            return _placeholder_image(), error_msg

        images = batch["image"]
        labels = batch["label"]

        model.eval()
        with torch.no_grad():
            preds = model(images).argmax(dim=1)

        images, labels, preds = images.cpu(), labels.cpu(), preds.cpu()
        result_image = _render_prediction_grid(images, labels, preds)

        # Calculate accuracy
        total = len(labels)
        correct_predictions = (preds == labels).sum().item()
        accuracy = correct_predictions / total * 100

        summary = f"Accuracy: {correct_predictions}/{total} ({accuracy:.1f}%)\n"
        summary += f"Classes: {', '.join(label_names)}"

        return result_image, summary

    except Exception as e:  # UI boundary: show the error instead of crashing
        error_msg = f"Error during prediction: {str(e)}"
        print(error_msg)
        # Return a placeholder image and error message
        return _placeholder_image(), error_msg


def create_interface():
    """Create the Gradio interface.

    Loads the model and dataset first; if that fails, an error-only
    interface is returned instead of the classifier UI.
    """
    # Initialize model and data
    init_success = load_model_and_data()

    if not init_success:

        def error_function():
            placeholder = _placeholder_image(color="red")
            return (
                placeholder,
                "Failed to load model or dataset. Please check the logs.",
            )

        interface = gr.Interface(
            fn=error_function,
            inputs=[],
            outputs=[
                gr.Image(type="pil", label="Results"),
                gr.Textbox(label="Summary"),
            ],
            title="🛰️ Satellite Image Classification - ERROR",
            description="Failed to initialize the application.",
        )
        return interface

    # Create the main interface
    interface = gr.Interface(
        fn=predict_images,
        inputs=[],
        outputs=[
            gr.Image(type="pil", label="Classification Results (10 Random Images)"),
            gr.Textbox(label="Summary", lines=3),
        ],
        title="🛰️ Satellite Image Classification with ResNet18",
        description=APP_DESCRIPTION,
        examples=[],
        cache_examples=False,
        allow_flagging="never",
    )

    return interface


# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)