import functools

import torch
from torch.nn import functional as F
import torchvision
from torchvision import transforms, models
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from PIL import Image
import gradio as gr

# Human-readable label for each model output index (classes[i] names logit i).
classes = ['Speed limit (20km/h)', 'Speed limit (30km/h)', 'Speed limit (50km/h)',
           'Speed limit (60km/h)', 'Speed limit (70km/h)', 'Speed limit (80km/h)',
           'End of speed limit (80km/h)', 'Speed limit (100km/h)', 'Speed limit (120km/h)',
           'No passing', 'No passing veh over 3.5 tons', 'Right-of-way at intersection',
           'Priority road', 'Yield', 'Stop', 'No vehicles', 'Veh > 3.5 tons prohibited',
           'No entry', 'General caution', 'Dangerous curve left', 'Dangerous curve right',
           'Double curve', 'Bumpy road', 'Slippery road', 'Road narrows on the right',
           'Road work', 'Traffic signals', 'Pedestrians', 'Children crossing',
           'Bicycles crossing', 'Beware of ice/snow', 'Wild animals crossing',
           'End speed + passing limits', 'Turn right ahead', 'Turn left ahead',
           'Ahead only', 'Go straight or right', 'Go straight or left', 'Keep right',
           'Keep left', 'Roundabout mandatory', 'End of no passing',
           'End no passing veh > 3.5 tons']

# Path of the trained Lightning checkpoint loaded at inference time.
_CHECKPOINT_PATH = 'resnet18.ckpt'

# How many of the most likely classes to report back to the UI.
_TOP_K = 5

# Preprocessing applied to every uploaded image: resize to 224x224, convert to
# a float tensor, and normalize with the commonly used ImageNet mean/std.
# Built once here instead of on every request.
_TEST_TRANSFORMS = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


class LitGTSRB(pl.LightningModule):
    """Lightning wrapper around a ResNet-18 with one output per entry of ``classes``."""

    def __init__(self):
        super().__init__()
        # Trained weights are restored from the checkpoint, so no pretrained
        # backbone is requested here. len(classes) == 43, matching the head size.
        self.model = models.resnet18(pretrained=False, num_classes=len(classes))

    def forward(self, x):
        """Return per-class log-probabilities for the batch *x*.

        Assumes *x* is a float image batch of shape (N, 3, H, W); TODO confirm
        against training code. Softmax is taken over dim 1 (the class axis).
        """
        out = self.model(x)
        return F.log_softmax(out, dim=1)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the trained model from the checkpoint once and reuse it afterwards."""
    # load_from_checkpoint is a classmethod: calling it on the class avoids
    # constructing a throwaway model. map_location='cpu' keeps the weights on
    # the same device as the CPU input tensors built in predict_image.
    model = LitGTSRB.load_from_checkpoint(_CHECKPOINT_PATH, map_location='cpu')
    model.eval()
    return model


def predict_image(image):
    """Classify a traffic-sign image and return its most likely classes.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        A dict mapping class name to probability for the top ``_TOP_K``
        classes, in descending order of probability (the format ``gr.Label``
        displays). Returns ``None`` when no image was provided.
    """
    if image is None:
        return None

    model = _load_model()

    # Force 3-channel RGB so alpha/palette/grayscale uploads match the model input.
    image_tensor = _TEST_TRANSFORMS(image.convert('RGB')).float().unsqueeze(0)

    with torch.no_grad():
        log_probs = model(image_tensor)

    # forward() returns log-softmax, so exponentiate to recover probabilities.
    probs = torch.exp(log_probs.cpu().squeeze(0))
    k = min(_TOP_K, probs.numel())
    prediction_score, pred_label_idx = torch.topk(probs, k)

    return {
        classes[idx]: float(score)
        for idx, score in zip(pred_label_idx.tolist(), prediction_score.tolist())
    }


image = gr.Image(type='pil')
label = gr.Label()
examples = ['1.png', '2.png', '3.png', '4.png', '5.png', '6.png', '7.png', '8.png']

intf = gr.Interface(fn=predict_image, inputs=image, outputs=label, examples=examples)
intf.launch(inline=True)