"""Gradio demo serving top-5 predictions from a ResNet-50 Lightning checkpoint."""

import gradio as gr
import torch
import pytorch_lightning as pl
from torchvision import transforms
from PIL import Image
from torchvision import models
import torch.nn as nn

# Number of predictions returned by predict_top5 and shown by the Label output.
TOP_K = 5


# Define the LightningModule class (should match the training code)
class ResNet50Image2k(pl.LightningModule):
    """ResNet-50 backbone with a ``num_classes``-way linear classification head.

    The architecture must match the one used at training time so that the
    checkpoint's state dict loads cleanly via ``load_from_checkpoint``.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # No pretrained download: the weights are restored from the checkpoint below.
        self.model = models.resnet50(pretrained=False)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        return self.model(x)


# Pick the device first so the checkpoint tensors are mapped straight onto it;
# without map_location a checkpoint saved on GPU can fail to load on a CPU-only host.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model from PyTorch Lightning checkpoint
checkpoint_path = "./resnet50_exp.ckpt"  # Replace with your checkpoint file path
model = ResNet50Image2k.load_from_checkpoint(checkpoint_path, map_location=device)
model.eval()  # Set the model to evaluation mode
model = model.to(device)

# Load class labels: one label per line; line index is used as the class id
# in predict_top5.
with open("classes.txt", encoding="utf-8") as f:
    class_labels = [line.strip() for line in f]

# Define the preprocessing pipeline (resize, center-crop to 224, ImageNet normalization)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict_top5(image):
    """Return the top-5 class predictions for a single image.

    Args:
        image: PIL image from the Gradio ``Image`` component (``type="pil"``);
            may be ``None`` if the user submits without uploading an image.

    Returns:
        dict mapping class label -> softmax probability (a float), which is
        the value format ``gr.Label`` uses to render confidences.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Uploads can be RGBA (PNG) or single-channel; Normalize above has
    # 3-channel mean/std, so force RGB before preprocessing.
    batch = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    # Clamp k so topk cannot fail for models with fewer than TOP_K classes.
    k = min(TOP_K, probabilities.numel())
    top_prob, top_catid = torch.topk(probabilities, k)
    return {
        class_labels[catid.item()]: prob.item()
        for prob, catid in zip(top_prob, top_catid)
    }


examples = [
    ["Images/Bird.JPEG"],  # Example 1
    ["Images/Chamelion.JPEG"],  # Example 2
    ["Images/Lizard.JPEG"],  # Example 3
    ["Images/Shark.JPEG"],  # Example 4
    ["Images/Turtle.JPEG"],  # Example 5
]

# Create the Gradio interface
interface = gr.Interface(
    fn=predict_top5,
    inputs=gr.Image(type="pil"),  # Updated syntax for image input
    outputs=gr.Label(num_top_classes=TOP_K),  # Updated syntax for label output
    title="ResNet50 Image Classification",
    description="Upload an image for top-5 class predictions from the ResNet50 ImageNet 1k Model.",
    examples=examples,
)

# Launch the app
if __name__ == "__main__":
    interface.launch()