# Uploaded by sudhakar272 ("Upload app.py", commit a4c51c5, verified).
import gradio as gr
import torch
import pytorch_lightning as pl
from torchvision import transforms
from PIL import Image
from torchvision import models
import torch.nn as nn
# Define the LightningModule class. It must match the architecture used in the
# training code so the checkpoint's weights can be loaded into it.
class ResNet50Image2k(pl.LightningModule):
    """ResNet-50 backbone with a resized classification head, wrapped for Lightning.

    Only the architecture is defined here; in this app the weights are restored
    afterwards via ``load_from_checkpoint``.

    Args:
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # ``weights=None`` builds an untrained ResNet-50. It is the supported
        # replacement for the deprecated ``pretrained=False`` flag
        # (torchvision >= 0.13) and has the same meaning.
        self.model = models.resnet50(weights=None)
        # Replace the stock head with one sized to ``num_classes``.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Return the raw (pre-softmax) logits of the wrapped ResNet-50."""
        return self.model(x)
# Load the model from a PyTorch Lightning checkpoint.
checkpoint_path = "./resnet50_exp.ckpt"  # Replace with your checkpoint file path
# Choose the target device first so the checkpoint tensors can be mapped
# straight onto it. Depending on the Lightning version, a checkpoint saved
# from a GPU run may otherwise try to restore onto CUDA on a CPU-only host.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet50Image2k.load_from_checkpoint(checkpoint_path, map_location=device)
model.eval()  # inference behaviour for BatchNorm (uses running statistics)
model = model.to(device)
# Load the class labels: one per line, where line i is the display name for
# output index i of the model (predict_top5 indexes this list by class id).
# The encoding is explicit so decoding does not depend on the platform default.
with open("classes.txt", encoding="utf-8") as f:
    class_labels = [line.strip() for line in f]
# Preprocessing pipeline applied to every input image. This is the common
# ImageNet evaluation recipe; presumably identical to the training-time
# validation transform (confirm against the training code).
_eval_transforms = [
    transforms.Resize(256),  # shorter side -> 256 px, aspect ratio kept
    transforms.CenterCrop(224),  # central 224x224 patch
    transforms.ToTensor(),  # PIL image -> float CHW tensor scaled to [0, 1]
    # Per-channel (R, G, B) mean/std normalization.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
preprocess = transforms.Compose(_eval_transforms)
# Define the prediction function
def predict_top5(image):
# Preprocess the image
image = preprocess(image).unsqueeze(0).to(device)
# Get predictions
with torch.no_grad():
outputs = model(image)
probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
# Get top-5 predictions
top5_prob, top5_catid = torch.topk(probabilities, 5)
top5_results = {class_labels[catid]: f"{prob.item():.4f}" for prob, catid in zip(top5_prob, top5_catid)}
return top5_results
# Example inputs offered beneath the upload widget. Gradio takes one inner
# list per example, holding one value per interface input (here only the
# image path).
_EXAMPLE_IMAGE_PATHS = (
    "Images/Bird.JPEG",
    "Images/Chamelion.JPEG",  # spelling kept as-is; presumably matches the file on disk
    "Images/Lizard.JPEG",
    "Images/Shark.JPEG",
    "Images/Turtle.JPEG",
)
examples = [[path] for path in _EXAMPLE_IMAGE_PATHS]
# Build the Gradio UI: one image in, a top-5 label widget out.
_image_input = gr.Image(type="pil")  # pass the function a PIL image
_label_output = gr.Label(num_top_classes=5)  # show at most five classes

interface = gr.Interface(
    fn=predict_top5,
    inputs=_image_input,
    outputs=_label_output,
    examples=examples,
    title="ResNet50 Image Classification",
    description="Upload an image for top-5 class predictions from the ResNet50 ImageNet 1k Model.",
)


def _main():
    """Launch the Gradio app defined above."""
    interface.launch()


# Serve the app only when run as a script, not when this module is imported.
if __name__ == "__main__":
    _main()