SkullFaceFire commited on
Commit
4027e4e
·
verified ·
1 Parent(s): 6c1a206

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -3
app.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import torch.nn as nn
3
  from torchvision import models,transforms
4
  from PIL import Image
 
5
  import gradio as gr
6
  from torchvision.transforms import transforms
7
 
@@ -13,16 +14,52 @@ t=transforms.Compose([ transforms.ToTensor(),
13
  transforms.RandomHorizontalFlip(0.5),
14
  transforms.RandomRotation(10),
15
  ])
16
- class_name=[f"c{i}" for i in range(1,10)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- model=torch.load("model.pth")
19
  print(model)
 
20
  def predict(image):
 
21
  image=t(image).unsqueeze(0)
22
  with torch.no_grad():
23
  output=model(image)
24
  _,predicted=torch.max(output,1)
25
- predicted_class=predicted.item()
 
26
  return predicted_class
27
 
28
  interface=gr.Interface(
 
2
  import torch.nn as nn
3
  from torchvision import models,transforms
4
  from PIL import Image
5
+ import torch.nn.functional as f
6
  import gradio as gr
7
  from torchvision.transforms import transforms
8
 
 
14
  transforms.RandomHorizontalFlip(0.5),
15
  transforms.RandomRotation(10),
16
  ])
17
+ class_name=["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship","truck"]
18
+
19
+ class CIFAR_Module(nn.Module):
20
+ def __init__(self,in_channel):
21
+ self.in_channel=in_channel
22
+ super(CIFAR_Module,self).__init__()
23
+ self.con1=nn.Conv2d(in_channel,6*in_channel,5)
24
+ self.pool1=nn.MaxPool2d(5,stride=2)
25
+ self.con2=nn.Conv2d(6*in_channel,16*in_channel,5)
26
+ self.pool2=nn.MaxPool2d(5,stride=2)
27
+ self.flat=nn.Flatten()
28
+ self.fc1=nn.Linear(192,100*in_channel)
29
+ self.fc2=nn.Linear(100*in_channel,40*in_channel)
30
+ self.fc3=nn.Linear(40*in_channel,10)
31
+ def forward(self,x):
32
+ x=self.con1(x)
33
+ x=f.relu(x)
34
+ x=self.pool1(x)
35
+ x=f.relu(x)
36
+ x=self.con2(x)
37
+ x=f.relu(x)
38
+ x=self.pool2(x)
39
+ x=self.flat(x)
40
+ x=self.fc1(x)
41
+ x=f.relu(x)
42
+ x=self.fc2(x)
43
+ x=f.relu(x)
44
+ x=self.fc3(x)
45
+ return x
46
+
47
+
48
+ model=CIFAR_Module(3)
49
+ model.load_state_dict(torch.load("model.pth",weights_only=True))
50
+ model.eval()
51
+
52
 
 
53
  print(model)
54
+
55
  def predict(image):
56
+ image=image.resize((32,32))
57
  image=t(image).unsqueeze(0)
58
  with torch.no_grad():
59
  output=model(image)
60
  _,predicted=torch.max(output,1)
61
+ print(output)
62
+ predicted_class=class_name[predicted.item()-1]
63
  return predicted_class
64
 
65
  interface=gr.Interface(