# Scraped from a Hugging Face Space page (status: Sleeping); file size 2,437 bytes.
# Blame commits: 53bd501, 099e047, 960ec28, a4c51c5.
import gradio as gr
import torch
import pytorch_lightning as pl
from torchvision import transforms
from PIL import Image
from torchvision import models
import torch.nn as nn
class ResNet50Image2k(pl.LightningModule):
    """ResNet-50 classifier wrapped in a LightningModule.

    The layer layout must match the module used at training time so that
    ``load_from_checkpoint`` can restore the saved ``state_dict`` into it.

    Args:
        num_classes: Output width of the final linear (classification) layer.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # No pretrained download: every weight is expected to come from the checkpoint.
        backbone = models.resnet50(pretrained=False)
        # Replace the stock head so its output size matches ``num_classes``.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet-50 and return its raw logits."""
        return self.model(x)
# Load the model from a PyTorch Lightning checkpoint.
checkpoint_path = "./resnet50_exp.ckpt"  # Replace with your checkpoint file path
# Pick the device *before* loading so the checkpoint tensors are mapped straight
# onto it. Without an explicit map_location, a checkpoint saved on a GPU may fail
# to deserialize on a CPU-only host (behaviour varies across Lightning versions).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet50Image2k.load_from_checkpoint(checkpoint_path, map_location=device)
model.eval()  # Inference behaviour for BatchNorm (and any dropout) layers.
model = model.to(device)
# Load the class labels: line i of the file names model output index i.
# An explicit encoding keeps this independent of the host's locale default.
with open("classes.txt", encoding="utf-8") as f:
    class_labels = [line.strip() for line in f]
# Per-channel normalisation statistics (the widely used ImageNet RGB mean/std).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Evaluation-time preprocessing: scale the shorter side to 256, take the central
# 224x224 crop, convert to a float tensor, then normalise each channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def predict_top5(image):
    """Classify a PIL image and return its most likely classes.

    Args:
        image: The value delivered by ``gr.Image(type="pil")``, i.e. a
            ``PIL.Image.Image``, or ``None`` when the user submits without
            uploading anything.

    Returns:
        dict mapping class label -> probability (a float), the format
        ``gr.Label`` consumes. Holds five entries, or fewer if the model
        has fewer than five output classes.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable
            message instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Uploads may be RGBA, greyscale or palette images; Normalize (3-value
    # mean/std) and the model's first convolution need exactly 3 channels.
    batch = preprocess(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    # topk raises if k exceeds the number of classes.
    k = min(5, probabilities.numel())
    top_prob, top_catid = torch.topk(probabilities, k)
    # Return numeric confidences; gr.Label sorts and formats them for display.
    return {
        class_labels[catid]: prob
        for prob, catid in zip(top_prob.tolist(), top_catid.tolist())
    }
# Sample images offered as one-click examples in the UI. Each inner list holds
# the value for the interface's single image input.
_EXAMPLE_IMAGES = (
    "Images/Bird.JPEG",
    "Images/Chamelion.JPEG",  # spelled as in the path; rename file and entry together
    "Images/Lizard.JPEG",
    "Images/Shark.JPEG",
    "Images/Turtle.JPEG",
)
examples = [[path] for path in _EXAMPLE_IMAGES]
# Assemble the web UI: one PIL image in, up to five labelled confidences out.
_image_input = gr.Image(type="pil")  # predict_top5 receives a PIL.Image
_label_output = gr.Label(num_top_classes=5)
interface = gr.Interface(
    fn=predict_top5,
    inputs=_image_input,
    outputs=_label_output,
    title="ResNet50 Image Classification",
    description="Upload an image for top-5 class predictions from the ResNet50 ImageNet 1k Model.",
    examples=examples,
)


def _main():
    """Start the Gradio web server for ``interface``."""
    interface.launch()


# Launch only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    _main()