Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -20,7 +20,7 @@ class_names = [line.strip() for line in open("classes.txt")]
|
|
20 |
# Load the model
|
21 |
model = torchvision.models.vit_b_16(weights=None) # Initialize the model architecture
|
22 |
model.heads = nn.Linear(in_features=768, out_features=len(class_names)) # Adjust the classifier head
|
23 |
-
checkpoint = torch.load('08_pretrained_vit_feature_extractor_pizza_steak_sushi.pth')
|
24 |
model.load_state_dict(checkpoint, strict=False)
|
25 |
model = model.to(device)
|
26 |
model.eval()
|
@@ -69,7 +69,7 @@ iface = gr.Interface(
|
|
69 |
fn=predict,
|
70 |
inputs=gr.Image(type="numpy"),
|
71 |
outputs=[gr.Textbox(label="Top Prediction"), gr.Plot()], # Textbox for top prediction and Plot for the bar chart
|
72 |
-
examples=[
|
73 |
)
|
74 |
|
75 |
# Launch the Gradio app
|
|
|
20 |
# Load the model
|
21 |
model = torchvision.models.vit_b_16(weights=None) # Initialize the model architecture
|
22 |
model.heads = nn.Linear(in_features=768, out_features=len(class_names)) # Adjust the classifier head
|
23 |
+
checkpoint = torch.load('08_pretrained_vit_feature_extractor_pizza_steak_sushi.pth', map_location=torch.device('cpu'))
|
24 |
model.load_state_dict(checkpoint, strict=False)
|
25 |
model = model.to(device)
|
26 |
model.eval()
|
|
|
69 |
fn=predict,
|
70 |
inputs=gr.Image(type="numpy"),
|
71 |
outputs=[gr.Textbox(label="Top Prediction"), gr.Plot()], # Textbox for top prediction and Plot for the bar chart
|
72 |
+
examples=["Image 1.jpg", "Image 2"] # Optional: Add paths to example images
|
73 |
)
|
74 |
|
75 |
# Launch the Gradio app
|