# Scraped Hugging Face Spaces page header (status banner): Sleeping
import gradio as gr | |
import torch | |
from torchvision import transforms, models | |
from PIL import Image | |
# Load the trained ResNet-50 model from its checkpoint.
# Run on the default CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The architecture must match the checkpoint: a ResNet-50 with a 1000-way head.
model = models.resnet50(num_classes=1000)

# NOTE: torch.load unpickles the file, so only point it at a trusted checkpoint.
# Tensors are mapped directly onto `device`. Replace the path with your own checkpoint.
checkpoint = torch.load(
    "resnet50_40epoch_imagenet1k.ckpt",
    map_location=device,
)
# The weights are stored under the "model_state_dict" key of the saved dict.
model.load_state_dict(checkpoint["model_state_dict"])

# Move the model to the target device and put it in eval mode so batch-norm
# layers use their running statistics during inference.
model = model.to(device).eval()
# Load ImageNet class labels: one label per line, where line i names class index i
# (predict_top5 indexes this list with the model's output index).
# An explicit UTF-8 encoding avoids depending on the platform's default encoding,
# which can mis-decode or reject non-ASCII label text.
# Blank lines are kept (as "") so that indices stay aligned with the file.
with open("classes.txt", encoding="utf-8") as f:
    class_labels = [line.strip() for line in f]
# Define the preprocessing pipeline.
# Per-channel ImageNet mean and standard deviation used for normalization.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize the shorter side to 256, take the central 224x224 patch, convert to a
# float tensor (scaled to [0, 1] for 8-bit images), then normalize each channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
# Function to predict the top-k (default top-5) classes
def predict_top5(image, k=5):
    """Classify an image and return its ``k`` most probable class labels.

    Args:
        image: A ``PIL.Image.Image`` as delivered by the Gradio
            ``Image(type="pil")`` input, or ``None`` when the user submits
            without uploading an image.
        k: Number of top predictions to return (default 5). Clamped to the
            range ``[0, number of model outputs]``.

    Returns:
        A ``{label: probability}`` dict (the format ``gr.Label`` consumes),
        inserted from most to least likely. Empty when ``image`` is ``None``.
    """
    if image is None:
        return {}

    # ToTensor keeps the source channel count, so RGBA, grayscale or palette
    # images would not match the 3-channel Normalize step or the model input.
    batch = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    # torch.topk raises if k exceeds the number of elements.
    k = max(0, min(k, probabilities.numel()))
    top_prob, top_catid = torch.topk(probabilities, k)

    results = {}
    for prob, catid in zip(top_prob.tolist(), top_catid.tolist()):
        # Fall back to the numeric id if the labels file is shorter than the
        # model's output layer, instead of raising IndexError.
        label = class_labels[catid] if catid < len(class_labels) else str(catid)
        # Label files may repeat names (common ImageNet lists contain e.g.
        # "crane" twice); disambiguate so an entry is not silently overwritten.
        if label in results:
            label = f"{label} ({catid})"
        results[label] = prob
    return results
# Create the Gradio interface.
# gr.inputs.* / gr.outputs.* were deprecated in Gradio 3.x and removed in 4.0;
# the top-level gr.Image / gr.Label components are the supported equivalents.
interface = gr.Interface(
    fn=predict_top5,
    inputs=gr.Image(type="pil"),  # hand the function a PIL image
    outputs=gr.Label(num_top_classes=5),  # show the five highest-scoring labels
    title="ResNet50 Image Classification",
    description="Upload an image to get the top-5 class predictions from the ResNet50 model trained on ImageNet 1k.",
)
# Launch the app only when this file is executed directly, not when imported.
def _main():
    """Start the Gradio web server for the classification demo."""
    interface.launch()


if __name__ == "__main__":
    _main()