SkullFaceFire commited on
Commit
65602f1
·
verified ·
1 Parent(s): 33403a1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models,transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+ from torchvision.transforms import transforms
7
+
8
+
9
+ # model=models.resnet18(pretrained=True)
10
+ # model.fc=nn.Linear(model.fc.in_features,10)
11
+ t=transforms.Compose([ transforms.ToTensor(),
12
+ transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
13
+ transforms.RandomHorizontalFlip(0.5),
14
+ transforms.RandomRotation(10),
15
+ ])
16
+ class_name=[f"c{i}" for i in range(1,10)]
17
+
18
+ model=torch.load("model.pth")
19
+ print(model)
20
+ def predict(image):
21
+ image=t(image).unsqueeze(0)
22
+ with torch.no_grad():
23
+ output=model(image)
24
+ _,predicted=torch.max(output,1)
25
+ predicted_class=predicted.item()
26
+ return predicted_class
27
+
28
+ interface=gr.Interface(
29
+ fn=predict,
30
+ inputs=gr.Image(type="pil"),
31
+ outputs="text",
32
+ title="cifar dataset prediction",
33
+ description="upload an image to get its class"
34
+ )
35
+
36
+ interface.launch(share=True)