File size: 2,437 Bytes
53bd501
099e047
 
 
 
 
 
53bd501
099e047
960ec28
099e047
 
 
 
 
 
 
53bd501
099e047
960ec28
 
099e047
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
960ec28
 
 
 
 
099e047
 
 
 
 
 
 
 
a4c51c5
099e047
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
import torch
import pytorch_lightning as pl
from torchvision import transforms
from PIL import Image
from torchvision import models
import torch.nn as nn

# Define the LightningModule class (should match the training code)
class ResNet50Image2k(pl.LightningModule):
    """LightningModule wrapper around a torchvision ResNet-50 classifier.

    The network is stored under the attribute ``model`` and its final ``fc``
    layer is replaced by a fresh linear head. This layout is meant to mirror
    the module used at training time so that its checkpoint can be restored.
    """

    def __init__(self, num_classes=1000):
        """Build an untrained ResNet-50 whose head emits ``num_classes`` logits.

        Args:
            num_classes: Output width of the replacement ``fc`` layer.
        """
        super().__init__()
        backbone = models.resnet50(pretrained=False)
        # Swap the stock classifier for one sized to ``num_classes``,
        # keeping the input width of the layer it replaces.
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Linear(head_inputs, num_classes)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        logits = self.model(x)
        return logits

# Restore the trained network from the PyTorch Lightning checkpoint and
# place it on the GPU when one is available, otherwise on the CPU.
checkpoint_path = "./resnet50_exp.ckpt"  # Replace with your checkpoint file path
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ResNet50Image2k.load_from_checkpoint(checkpoint_path)
model = model.to(device)
model.eval()  # Evaluation mode for inference

# Load ImageNet class labels: one label per line, so a label's position in
# the list is the index used to look it up from a predicted class id.
# The encoding is pinned so decoding does not depend on the platform's
# locale default; the file is iterated directly instead of via readlines().
with open("classes.txt", encoding="utf-8") as f:
    class_labels = [line.strip() for line in f]

# Preprocessing applied to every input image before it reaches the model:
# resize to 256, center-crop to 224, convert to a tensor, then normalise
# each channel with the mean/std values below.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
preprocess = transforms.Compose(_preprocess_steps)

# Define the prediction function
def predict_top5(image):
    """Classify ``image`` and return its highest-probability classes.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None``
            when the request carries no image.

    Returns:
        dict mapping class label to its softmax probability as a ``float``
        rounded to four decimals, for at most five classes. An empty dict
        is returned when ``image`` is ``None``.
    """
    if image is None:
        # Nothing was uploaded: report "no prediction" rather than failing
        # inside the preprocessing pipeline.
        return {}

    # The normalisation step is configured for three channels, so grayscale,
    # palette, and RGBA inputs are converted first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess the image and add the leading batch dimension.
    batch = preprocess(image).unsqueeze(0).to(device)

    # Get predictions without tracking gradients.
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    # topk raises if asked for more entries than exist, so clamp k for
    # models whose head has fewer than five classes.
    k = min(5, probabilities.numel())
    top_prob, top_idx = torch.topk(probabilities, k)

    # Plain Python floats/ints: the label widget is documented to take
    # numeric confidences, and list indexing gets ordinary ints.
    return {
        class_labels[idx]: round(prob, 4)
        for prob, idx in zip(top_prob.tolist(), top_idx.tolist())
    }

# Sample inputs offered in the UI; each row holds the single value (an image
# path) for the interface's one input component.
examples = [
    [image_path]
    for image_path in (
        "Images/Bird.JPEG",  # Example 1
        "Images/Chamelion.JPEG",  # Example 2
        "Images/Lizard.JPEG",  # Example 3
        "Images/Shark.JPEG",  # Example 4
        "Images/Turtle.JPEG",  # Example 5
    )
]

# Assemble the Gradio UI: a PIL image input feeding predict_top5, with the
# result rendered by a label widget limited to five classes.
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=5)
interface = gr.Interface(
    fn=predict_top5,
    inputs=_image_input,
    outputs=_label_output,
    title="ResNet50 Image Classification",
    description="Upload an image for top-5 class predictions from the ResNet50 ImageNet 1k Model.",
    examples=examples,
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()