Harshithtd commited on
Commit
d1844bf
·
verified ·
1 Parent(s): db4e299

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -20,7 +20,7 @@ class_names = [line.strip() for line in open("classes.txt")]
20
  # Load the model
21
  model = torchvision.models.vit_b_16(weights=None) # Initialize the model architecture
22
  model.heads = nn.Linear(in_features=768, out_features=len(class_names)) # Adjust the classifier head
23
- checkpoint = torch.load('08_pretrained_vit_feature_extractor_pizza_steak_sushi.pth')
24
  model.load_state_dict(checkpoint, strict=False)
25
  model = model.to(device)
26
  model.eval()
@@ -69,7 +69,7 @@ iface = gr.Interface(
69
  fn=predict,
70
  inputs=gr.Image(type="numpy"),
71
  outputs=[gr.Textbox(label="Top Prediction"), gr.Plot()], # Textbox for top prediction and Plot for the bar chart
72
- examples=[r"C:\Users\Asus\Desktop\download (1).jpg", r"C:\Users\Asus\Desktop\download (3).jpg"] # Optional: Add paths to example images
73
  )
74
 
75
  # Launch the Gradio app
 
20
  # Load the model
21
  model = torchvision.models.vit_b_16(weights=None) # Initialize the model architecture
22
  model.heads = nn.Linear(in_features=768, out_features=len(class_names)) # Adjust the classifier head
23
+ checkpoint = torch.load('08_pretrained_vit_feature_extractor_pizza_steak_sushi.pth', map_location=torch.device('cpu'))
24
  model.load_state_dict(checkpoint, strict=False)
25
  model = model.to(device)
26
  model.eval()
 
69
  fn=predict,
70
  inputs=gr.Image(type="numpy"),
71
  outputs=[gr.Textbox(label="Top Prediction"), gr.Plot()], # Textbox for top prediction and Plot for the bar chart
72
+ examples=["Image 1.jpg", "Image 2"] # Optional: Add paths to example images
73
  )
74
 
75
  # Launch the Gradio app