# ResNet50 / app.py
# Uploaded by piyushgrover in commit 823b8ba ("Upload 4 files", verified); 1.82 kB.
import gradio as gr
import torch
from torchvision import transforms, models
from PIL import Image
# Load the trained ResNet50 model checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The architecture must match the checkpoint: torchvision ResNet50 with a
# 1000-way classification head.
model = models.resnet50(num_classes=1000)
# Load onto the CPU first: with map_location=device the whole checkpoint dict
# (weights plus whatever else was saved alongside them) would be materialised
# on the GPU. Replace the path with your own checkpoint if needed.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load("resnet50_40epoch_imagenet1k.ckpt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
# load_state_dict copies the tensors into the model's own parameters, so the
# loaded dict is no longer needed; drop it rather than keeping a second copy
# alive for the lifetime of the app.
del checkpoint
model = model.to(device)
# Inference mode for layers such as BatchNorm/Dropout.
model.eval()
# Load the class labels: one label per line. predict_top5 indexes this list by
# the model's output index, so line order must match the model's class order.
# An explicit encoding avoids depending on the platform default (e.g. cp1252
# on Windows) for non-ASCII label text.
with open("classes.txt", encoding="utf-8") as f:
    class_labels = [line.strip() for line in f]
# Per-channel RGB normalisation statistics (the values conventionally used with
# ImageNet-trained torchvision models; presumably the same ones used at
# training time -- confirm against the training script).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image before it reaches the model:
# scale the shorter side to 256 px, take the central 224x224 crop, convert to
# a float tensor and normalise each channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Function to predict the top-k (default 5) classes for one image.
def predict_top5(image, k=5):
    """Classify *image* and return its *k* most likely labels with probabilities.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.
        k: Number of highest-probability classes to return (default 5).

    Returns:
        dict mapping label -> softmax probability (float), in descending order of
        probability, as consumed by ``gr.Label``. Empty when *image* is None.
    """
    if image is None:
        return {}
    # ToTensor + Normalize with 3-value mean/std need exactly 3 channels; RGBA,
    # grayscale or palette images would otherwise fail inside Normalize.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    batch = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    k = min(k, probabilities.numel())
    top_prob, top_idx = torch.topk(probabilities, k)

    results = {}
    # One host transfer each instead of a device sync per .item() call.
    for prob, idx in zip(top_prob.tolist(), top_idx.tolist()):
        label = class_labels[idx] if idx < len(class_labels) else f"class {idx}"
        # Some ImageNet-1k label lists repeat a name (e.g. "crane" for both the
        # bird and the machine); a plain dict would silently overwrite the
        # earlier, more probable entry, so disambiguate with the class index.
        if label in results:
            label = f"{label} ({idx})"
        results[label] = prob
    return results
# Create the Gradio interface.
# gr.inputs.* / gr.outputs.* were deprecated in Gradio 3.x and removed in 4.0;
# the top-level gr.Image / gr.Label components work on both.
interface = gr.Interface(
    fn=predict_top5,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="ResNet50 Image Classification",
    description="Upload an image to get the top-5 class predictions from the ResNet50 model trained on ImageNet 1k.",
)
def main():
    """Start the Gradio server hosting the classification interface."""
    interface.launch()


# Launch only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()