piyushgrover commited on
Commit
00a1c70
·
verified ·
1 Parent(s): 823b8ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -8
app.py CHANGED
@@ -1,15 +1,27 @@
1
  import gradio as gr
2
  import torch
3
- from torchvision import transforms, models
 
4
  from PIL import Image
 
 
5
 
6
- # Load your trained ResNet50 model checkpoint
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
- model = models.resnet50(num_classes=1000) # Ensure the model matches your architecture
9
- checkpoint = torch.load("resnet50_40epoch_imagenet1k.ckpt", map_location=device) # Replace with your checkpoint path
10
- model.load_state_dict(checkpoint['model_state_dict']) # Load state_dict from your checkpoint
11
  model = model.to(device)
12
- model.eval()
13
 
14
  # Load ImageNet class labels
15
  with open("classes.txt") as f:
@@ -23,7 +35,7 @@ preprocess = transforms.Compose([
23
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
24
  ])
25
 
26
- # Function to predict top-5 classes
27
  def predict_top5(image):
28
  # Preprocess the image
29
  image = preprocess(image).unsqueeze(0).to(device)
@@ -35,7 +47,7 @@ def predict_top5(image):
35
 
36
  # Get top-5 predictions
37
  top5_prob, top5_catid = torch.topk(probabilities, 5)
38
- top5_results = {class_labels[catid]: prob.item() for prob, catid in zip(top5_prob, top5_catid)}
39
 
40
  return top5_results
41
 
 
1
  import gradio as gr
2
  import torch
3
+ import pytorch_lightning as pl
4
+ from torchvision import transforms
5
  from PIL import Image
6
+ from torchvision import models
7
+ import torch.nn as nn
8
 
9
+ # Define the LightningModule class (should match the training code)
10
+ class ResNet50Lightning(pl.LightningModule):
11
+ def __init__(self, num_classes=1000):
12
+ super().__init__()
13
+ self.model = models.resnet50(pretrained=False)
14
+ self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
15
+
16
+ def forward(self, x):
17
+ return self.model(x)
18
+
19
+ # Load the model from PyTorch Lightning checkpoint
20
+ checkpoint_path = "./resnet50_40epoch_imagenet1k.ckpt" # Replace with your checkpoint file path
21
+ model = ResNet50Lightning.load_from_checkpoint(checkpoint_path)
22
+ model.eval() # Set the model to evaluation mode
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
24
  model = model.to(device)
 
25
 
26
  # Load ImageNet class labels
27
  with open("classes.txt") as f:
 
35
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
36
  ])
37
 
38
+ # Define the prediction function
39
  def predict_top5(image):
40
  # Preprocess the image
41
  image = preprocess(image).unsqueeze(0).to(device)
 
47
 
48
  # Get top-5 predictions
49
  top5_prob, top5_catid = torch.topk(probabilities, 5)
50
+ top5_results = {class_labels[catid]: f"{prob.item():.4f}" for prob, catid in zip(top5_prob, top5_catid)}
51
 
52
  return top5_results
53