Weicheng HE commited on
Commit
0349d8b
·
1 Parent(s): 4af5909

add torch-lightning

Browse files
Files changed (1) hide show
  1. app.py +69 -21
app.py CHANGED
@@ -1,34 +1,82 @@
1
- import gradio as gr
2
- from PIL import Image
3
- import numpy as np
4
  import torch
 
5
  import torchvision
6
- from torchvision import transforms
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  def predict_image(image):
 
 
9
  image = image.convert('RGB')
10
  test_transforms = transforms.Compose([
11
  transforms.Resize([224, 224]),
12
- transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
13
  ])
14
- classes = ('Speed limit (20km/h)',
15
- 'Speed limit (30km/h)',
16
- 'Speed limit (50km/h)',
17
- 'Speed limit (60km/h)',
18
- 'Speed limit (70km/h)',
19
- 'Speed limit (80km/h)',
20
- 'Speed limit (100km/h)',
21
- 'Speed limit (120km/h)')
22
-
23
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
- model = torch.load('vgg11.pt')
25
- model.eval()
26
  image_tensor = test_transforms(image).float()
27
  image_tensor = image_tensor.unsqueeze_(0)
28
- input = image_tensor.to(device)
29
- output = model(input)
30
- probs = torch.exp(output.data.cpu().squeeze()).numpy()
31
- return dict(zip(classes, map(float, probs)))
 
 
32
  image = gr.Image(type='pil')
33
  label = gr.Label()
34
  examples = ['1.png', '2.png', '3.png']
 
 
 
 
1
  import torch
2
+ from torch.nn import functional as F
3
  import torchvision
4
+ from torchvision import transforms, models
5
+ import pytorch_lightning as pl
6
+ from pytorch_lightning import LightningModule, Trainer
7
+ from PIL import Image
8
+ import gradio as gr
9
+
10
+ classes = ['Speed limit (20km/h)',
11
+ 'Speed limit (30km/h)',
12
+ 'Speed limit (50km/h)',
13
+ 'Speed limit (60km/h)',
14
+ 'Speed limit (70km/h)',
15
+ 'Speed limit (80km/h)',
16
+ 'End of speed limit (80km/h)',
17
+ 'Speed limit (100km/h)',
18
+ 'Speed limit (120km/h)',
19
+ 'No passing',
20
+ 'No passing veh over 3.5 tons',
21
+ 'Right-of-way at intersection',
22
+ 'Priority road',
23
+ 'Yield',
24
+ 'Stop',
25
+ 'No vehicles',
26
+ 'Veh > 3.5 tons prohibited',
27
+ 'No entry',
28
+ 'General caution',
29
+ 'Dangerous curve left',
30
+ 'Dangerous curve right',
31
+ 'Double curve',
32
+ 'Bumpy road',
33
+ 'Slippery road',
34
+ 'Road narrows on the right',
35
+ 'Road work',
36
+ 'Traffic signals',
37
+ 'Pedestrians',
38
+ 'Children crossing',
39
+ 'Bicycles crossing',
40
+ 'Beware of ice/snow',
41
+ 'Wild animals crossing',
42
+ 'End speed + passing limits',
43
+ 'Turn right ahead',
44
+ 'Turn left ahead',
45
+ 'Ahead only',
46
+ 'Go straight or right',
47
+ 'Go straight or left',
48
+ 'Keep right',
49
+ 'Keep left',
50
+ 'Roundabout mandatory',
51
+ 'End of no passing',
52
+ 'End no passing veh > 3.5 tons']
53
+
54
+ class LitGTSRB(pl.LightningModule):
55
+ def __init__(self):
56
+ super().__init__()
57
+ self.model = models.resnet18(pretrained=False, num_classes=43)
58
+
59
+ def forward(self, x):
60
+ out = self.model(x)
61
+ return F.log_softmax(out, dim=1)
62
 
63
  def predict_image(image):
64
+ model = LitGTSRB().load_from_checkpoint('resnet18.ckpt')
65
+ model.eval()
66
  image = image.convert('RGB')
67
  test_transforms = transforms.Compose([
68
  transforms.Resize([224, 224]),
69
+ transforms.ToTensor(),
70
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
71
  ])
 
 
 
 
 
 
 
 
 
 
 
 
72
  image_tensor = test_transforms(image).float()
73
  image_tensor = image_tensor.unsqueeze_(0)
74
+ with torch.no_grad():
75
+ output = model(image_tensor)
76
+ probs = torch.exp(output.data.cpu().squeeze())
77
+ prediction_score , pred_label_idx = torch.topk(probs,5)
78
+ class_top5 = [classes[idx] for idx in pred_label_idx.numpy()]
79
+ return dict(zip(class_top5, map(float, prediction_score.numpy())))
80
  image = gr.Image(type='pil')
81
  label = gr.Label()
82
  examples = ['1.png', '2.png', '3.png']