# Scraped Hugging Face Spaces page header: "Spaces: Sleeping Sleeping"
from typing import Dict, Optional

import gradio as gr
import pytorch_lightning as pl
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models
from torchvision import transforms
# Define the LightningModule class. It must match the architecture used in
# training so the checkpoint's state_dict keys line up when it is loaded.
class ResNet50Image2k(pl.LightningModule):
    """ResNet-50 backbone whose final layer emits ``num_classes`` logits.

    Only ``forward`` is implemented; in this app the module is used purely
    for inference after its weights are restored from a checkpoint.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # weights=None builds the architecture with random initialization; the
        # trained weights are restored from the Lightning checkpoint afterwards.
        # (`pretrained=False` is the deprecated pre-torchvision-0.13 spelling.)
        self.model = models.resnet50(weights=None)
        # Swap the classification head so it produces num_classes outputs.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Return the unnormalized class scores (logits) for the batch ``x``."""
        return self.model(x)
# Load the model from a PyTorch Lightning checkpoint.
checkpoint_path = "./resnet50_exp.ckpt"  # Replace with your checkpoint file path

# Choose the inference device first so the checkpoint tensors are mapped
# straight onto it. Without map_location, a checkpoint saved from a GPU run may
# fail to deserialize on a CPU-only host (depending on the Lightning version).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet50Image2k.load_from_checkpoint(checkpoint_path, map_location=device)
model.eval()  # Set the model to evaluation mode (e.g. BatchNorm uses running stats)
model = model.to(device)
# Load ImageNet class labels, one per line: line i names model output index i.
# Read as UTF-8 explicitly so decoding does not depend on the host locale.
with open("classes.txt", encoding="utf-8") as f:
    class_labels = [line.strip() for line in f]
# Preprocessing applied to every input image before it reaches the model:
# resize, take a 224x224 center crop, convert to a tensor, and normalize each
# of the three channels with the standard ImageNet mean/std values below.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Define the prediction function
def predict_top5(image: Optional["Image.Image"]) -> Optional[Dict[str, float]]:
    """Classify one image and return its five most likely classes.

    Args:
        image: The image delivered by the ``gr.Image(type="pil")`` input, or
            None when the user submits without uploading anything.

    Returns:
        A dict mapping class label -> softmax probability for the top-5
        classes (fewer if the model has fewer than five outputs), or None when
        no image was provided.
    """
    if image is None:
        # Nothing uploaded: return None (an empty Label) instead of crashing
        # inside preprocess.
        return None
    # Uploads may be RGBA (PNG), grayscale or palette images; the 3-value
    # Normalize in preprocess requires exactly three channels.
    image = image.convert("RGB")
    # Preprocess and add a leading batch dimension of size 1.
    batch = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    # Single-image batch: softmax over that image's class scores.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    # Guard against models with fewer than five classes (topk would raise).
    k = min(5, probabilities.numel())
    top_prob, top_catid = torch.topk(probabilities, k)
    # gr.Label expects numeric confidences, so return floats rather than
    # pre-formatted strings; the component formats them for display.
    return {
        class_labels[catid]: prob
        for prob, catid in zip(top_prob.tolist(), top_catid.tolist())
    }
# Sample images listed under the interface; each inner list holds the value
# for the interface's single image input.
examples = [
    [path]
    for path in (
        "Images/Bird.JPEG",
        "Images/Chamelion.JPEG",
        "Images/Lizard.JPEG",
        "Images/Shark.JPEG",
        "Images/Turtle.JPEG",
    )
]
# Assemble the Gradio UI: one PIL image goes in, a top-5 label readout comes out.
interface = gr.Interface(
    fn=predict_top5,
    inputs=gr.Image(type="pil"),  # hands the upload to fn as a PIL image
    outputs=gr.Label(num_top_classes=5),  # shows at most five classes
    examples=examples,
    title="ResNet50 Image Classification",
    description="Upload an image for top-5 class predictions from the ResNet50 ImageNet 1k Model.",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()