Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.nn import functional as F | |
| import torchvision | |
| from torchvision import transforms, models | |
| import pytorch_lightning as pl | |
| from pytorch_lightning import LightningModule, Trainer | |
| from PIL import Image | |
| import gradio as gr | |
# Human-readable sign names. Position i is the label reported for model
# output index i (see the lookup in predict_image), so order matters.
classes = [
    # 0-8: speed limits
    "Speed limit (20km/h)",
    "Speed limit (30km/h)",
    "Speed limit (50km/h)",
    "Speed limit (60km/h)",
    "Speed limit (70km/h)",
    "Speed limit (80km/h)",
    "End of speed limit (80km/h)",
    "Speed limit (100km/h)",
    "Speed limit (120km/h)",
    # 9-17: passing, priority and prohibition signs
    "No passing",
    "No passing veh over 3.5 tons",
    "Right-of-way at intersection",
    "Priority road",
    "Yield",
    "Stop",
    "No vehicles",
    "Veh > 3.5 tons prohibited",
    "No entry",
    # 18-31: warnings
    "General caution",
    "Dangerous curve left",
    "Dangerous curve right",
    "Double curve",
    "Bumpy road",
    "Slippery road",
    "Road narrows on the right",
    "Road work",
    "Traffic signals",
    "Pedestrians",
    "Children crossing",
    "Bicycles crossing",
    "Beware of ice/snow",
    "Wild animals crossing",
    # 32-42: end-of-restriction and mandatory direction signs
    "End speed + passing limits",
    "Turn right ahead",
    "Turn left ahead",
    "Ahead only",
    "Go straight or right",
    "Go straight or left",
    "Keep right",
    "Keep left",
    "Roundabout mandatory",
    "End of no passing",
    "End no passing veh > 3.5 tons",
]
class LitGTSRB(LightningModule):
    """Lightning wrapper around a ResNet-18 with a 43-way output layer.

    The network is created without pretrained weights; parameters are
    expected to come from a checkpoint. The backbone is stored as
    ``self.model``, so checkpoint keys are prefixed with ``model.``.
    """

    def __init__(self):
        """Build the untrained ResNet-18 backbone with 43 output classes."""
        super().__init__()
        backbone = models.resnet18(pretrained=False, num_classes=43)
        self.model = backbone

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Softmax is taken over dim 1, i.e. the class dimension of the
        backbone's output.
        """
        return F.log_softmax(self.model(x), dim=1)
# Loaded models keyed by checkpoint path, so the weights are read from disk
# once per process instead of once per request.
_MODEL_CACHE = {}


def _load_model(checkpoint_path='resnet18.ckpt'):
    """Return the eval-mode classifier for ``checkpoint_path``, loading it at most once."""
    model = _MODEL_CACHE.get(checkpoint_path)
    if model is None:
        # load_from_checkpoint is a classmethod: call it on the class, not on
        # a freshly built instance. The instance form constructs a throwaway
        # network and is rejected outright by newer Lightning releases.
        # map_location='cpu' keeps the weights on the same device as the
        # input tensor, which is always created on the CPU below.
        model = LitGTSRB.load_from_checkpoint(checkpoint_path, map_location='cpu')
        model.eval()
        _MODEL_CACHE[checkpoint_path] = model
    return model


def predict_image(image):
    """Classify a traffic-sign image and return its five most likely labels.

    Args:
        image: A PIL image (the Gradio input is configured with type='pil').

    Returns:
        dict mapping each of the top-5 class names to its probability as a
        Python float, in descending order of probability.
    """
    model = _load_model()

    rgb_image = image.convert('RGB')
    # Resize to 224x224 and normalize with the usual ImageNet channel
    # statistics -- presumably the same preprocessing used at training time.
    preprocess = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    batch = preprocess(rgb_image).float().unsqueeze(0)

    with torch.no_grad():
        log_probs = model(batch)

    # The model emits log-softmax scores; exponentiate to get probabilities.
    probs = torch.exp(log_probs.cpu().squeeze())
    top_scores, top_indices = torch.topk(probs, 5)
    return {
        classes[idx]: float(score)
        for idx, score in zip(top_indices.tolist(), top_scores.tolist())
    }
# --- Gradio UI wiring -------------------------------------------------------
# Input is delivered to predict_image as a PIL image; the returned
# {label: probability} dict is rendered by the Label component.
image = gr.Image(type='pil')
label = gr.Label()

# Sample images '1.png' .. '8.png', referenced by relative path.
examples = [f'{n}.png' for n in range(1, 9)]

intf = gr.Interface(
    fn=predict_image,
    inputs=image,
    outputs=label,
    examples=examples,
)
intf.launch(inline=True)