Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
from torchvision import models,transforms
|
4 |
from PIL import Image
|
|
|
5 |
import gradio as gr
|
6 |
from torchvision.transforms import transforms
|
7 |
|
@@ -13,16 +14,52 @@ t=transforms.Compose([ transforms.ToTensor(),
|
|
13 |
transforms.RandomHorizontalFlip(0.5),
|
14 |
transforms.RandomRotation(10),
|
15 |
])
|
16 |
-
class_name=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
model=torch.load("model.pth")
|
19 |
print(model)
|
|
|
20 |
def predict(image):
|
|
|
21 |
image=t(image).unsqueeze(0)
|
22 |
with torch.no_grad():
|
23 |
output=model(image)
|
24 |
_,predicted=torch.max(output,1)
|
25 |
-
|
|
|
26 |
return predicted_class
|
27 |
|
28 |
interface=gr.Interface(
|
|
|
2 |
import torch.nn as nn
|
3 |
from torchvision import models,transforms
|
4 |
from PIL import Image
|
5 |
+
import torch.nn.functional as f
|
6 |
import gradio as gr
|
7 |
from torchvision.transforms import transforms
|
8 |
|
|
|
14 |
transforms.RandomHorizontalFlip(0.5),
|
15 |
transforms.RandomRotation(10),
|
16 |
])
|
17 |
+
class_name=["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship","truck"]
|
18 |
+
|
19 |
+
class CIFAR_Module(nn.Module):
|
20 |
+
def __init__(self,in_channel):
|
21 |
+
self.in_channel=in_channel
|
22 |
+
super(CIFAR_Module,self).__init__()
|
23 |
+
self.con1=nn.Conv2d(in_channel,6*in_channel,5)
|
24 |
+
self.pool1=nn.MaxPool2d(5,stride=2)
|
25 |
+
self.con2=nn.Conv2d(6*in_channel,16*in_channel,5)
|
26 |
+
self.pool2=nn.MaxPool2d(5,stride=2)
|
27 |
+
self.flat=nn.Flatten()
|
28 |
+
self.fc1=nn.Linear(192,100*in_channel)
|
29 |
+
self.fc2=nn.Linear(100*in_channel,40*in_channel)
|
30 |
+
self.fc3=nn.Linear(40*in_channel,10)
|
31 |
+
def forward(self,x):
|
32 |
+
x=self.con1(x)
|
33 |
+
x=f.relu(x)
|
34 |
+
x=self.pool1(x)
|
35 |
+
x=f.relu(x)
|
36 |
+
x=self.con2(x)
|
37 |
+
x=f.relu(x)
|
38 |
+
x=self.pool2(x)
|
39 |
+
x=self.flat(x)
|
40 |
+
x=self.fc1(x)
|
41 |
+
x=f.relu(x)
|
42 |
+
x=self.fc2(x)
|
43 |
+
x=f.relu(x)
|
44 |
+
x=self.fc3(x)
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
model=CIFAR_Module(3)
|
49 |
+
model.load_state_dict(torch.load("model.pth",weights_only=True))
|
50 |
+
model.eval()
|
51 |
+
|
52 |
|
|
|
53 |
print(model)
|
54 |
+
|
55 |
def predict(image):
|
56 |
+
image=image.resize((32,32))
|
57 |
image=t(image).unsqueeze(0)
|
58 |
with torch.no_grad():
|
59 |
output=model(image)
|
60 |
_,predicted=torch.max(output,1)
|
61 |
+
print(output)
|
62 |
+
predicted_class=class_name[predicted.item()-1]
|
63 |
return predicted_class
|
64 |
|
65 |
interface=gr.Interface(
|