sudhakar272 commited on
Commit
960ec28
·
verified ·
1 Parent(s): 6e2223a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -7,7 +7,7 @@ from torchvision import models
7
  import torch.nn as nn
8
 
9
  # Define the LightningModule class (should match the training code)
10
- class ResNet50Lightning(pl.LightningModule):
11
  def __init__(self, num_classes=1000):
12
  super().__init__()
13
  self.model = models.resnet50(pretrained=False)
@@ -17,8 +17,8 @@ class ResNet50Lightning(pl.LightningModule):
17
  return self.model(x)
18
 
19
  # Load the model from PyTorch Lightning checkpoint
20
- checkpoint_path = "./resnet50_40epoch_imagenet1k.ckpt" # Replace with your checkpoint file path
21
- model = ResNet50Lightning.load_from_checkpoint(checkpoint_path)
22
  model.eval() # Set the model to evaluation mode
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  model = model.to(device)
@@ -52,9 +52,11 @@ def predict_top5(image):
52
  return top5_results
53
 
54
  examples = [
55
- ["examples/espresso.jpg.webp"], # Example 1
56
- ["examples/american_bullfrog.jpg"], # Example 2
57
- ["examples/tiger_shark.jpg"], # Example 3
 
 
58
  ]
59
 
60
  # Create the Gradio interface
 
7
  import torch.nn as nn
8
 
9
  # Define the LightningModule class (should match the training code)
10
+ class ResNet50Image2k(pl.LightningModule):
11
  def __init__(self, num_classes=1000):
12
  super().__init__()
13
  self.model = models.resnet50(pretrained=False)
 
17
  return self.model(x)
18
 
19
  # Load the model from PyTorch Lightning checkpoint
20
+ checkpoint_path = "./resnet50_exp.ckpt" # Replace with your checkpoint file path
21
+ model = ResNet50Image2k.load_from_checkpoint(checkpoint_path)
22
  model.eval() # Set the model to evaluation mode
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  model = model.to(device)
 
52
  return top5_results
53
 
54
  examples = [
55
+ ["Images/Bird.JPEG"], # Example 1
56
+ ["Images/Chamelion.JPEG"], # Example 2
57
+ ["Images/Lizard.JPEG"], # Example 3
58
+ ["Images/Shark.JPEG"], # Example 4
59
+ ["Images/Turtle.JPEG"], # Example 5
60
  ]
61
 
62
  # Create the Gradio interface