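# Gradio demo for a YOLOv3 model (PyTorch Lightning) trained on PASCAL VOC,
# with an optional Grad-CAM view of what the model attends to.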
import gradio as gr
from typing import List
import cv2
import torch
import numpy as np
from pytorch_grad_cam.utils.image import show_cam_on_image
from models import YoloV3Lightning
from utils import load_model_from_checkpoint
import utils
import config as cfg
import matplotlib
matplotlib.use('Agg')  # headless backend: figures are saved to disk, never shown
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from dataset import YOLODataset          # used by the commented-out eval snippet below
from torch.utils.data import DataLoader  # used by the commented-out eval snippet below
from grad_cam import YoloGradCAM

device = torch.device('cpu')
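# Build the detector, load weights from the local checkpoint, and switch to
# eval mode. scaled_anchors rescales the normalised cfg.ANCHORS to the three
# prediction-grid sizes in cfg.S, as cells_to_bboxes expects.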
dataset_mean, dataset_std = (0.4914, 0.4822, 0.4465), \
                            (0.2470, 0.2435, 0.2616)  # not referenced below

model = YoloV3Lightning.YOLOv3LightningModel(num_classes=cfg.NUM_CLASSES, anchors=cfg.ANCHORS, S=cfg.S)
ckpt_file = 'ckpt_light.pth'
checkpoint = load_model_from_checkpoint(device, file_name=ckpt_file)
model.load_state_dict(checkpoint['model'], strict=False)
model.eval()

scaled_anchors = (
    torch.tensor(cfg.ANCHORS)
    * torch.tensor(cfg.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to(model.device)

cam = YoloGradCAM(model=model, target_layers=[model.layers[-2]], scaled_anchors=scaled_anchors, use_cuda=False)
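# Grad-CAM is attached to the penultimate layer of the detector. YoloGradCAM
# (local grad_cam module) is presumably a wrapper that adapts pytorch-grad-cam
# to YOLO's multi-scale, anchor-based output.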
# Commented-out sanity check: plots predictions for 25 sample images via the
# evaluation pipeline.
'''
cfg.IMG_DIR = cfg.DATASET + "/images/"
cfg.LABEL_DIR = cfg.DATASET + "/labels/"
eval_dataset = YOLODataset(
    cfg.DATASET + "/25examples.csv",
    transform=cfg.test_transforms,
    S=[cfg.IMAGE_SIZE // 32, cfg.IMAGE_SIZE // 16, cfg.IMAGE_SIZE // 8],
    img_dir=cfg.IMG_DIR,
    label_dir=cfg.LABEL_DIR,
    anchors=cfg.ANCHORS,
    mosaic=False
)
eval_loader = DataLoader(
    dataset=eval_dataset,
    batch_size=cfg.BATCH_SIZE,
    num_workers=cfg.NUM_WORKERS,
    pin_memory=cfg.PIN_MEMORY,
    shuffle=True,
    drop_last=False,
)
scaled_anchors = (
    torch.tensor(cfg.ANCHORS)
    * torch.tensor(cfg.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)
scaled_anchors = scaled_anchors.to(cfg.DEVICE)
utils.plot_examples(model, eval_loader, 0.5, 0.6, scaled_anchors)
'''
# Gallery images shipped with the Space: images/000001.jpg ... images/000025.jpg
sample_images = [[f'images/{i:06d}.jpg'] for i in range(1, 26)]
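# UI: a gallery of bundled sample images next to an upload widget, Submit/Clear
# buttons, a Grad-CAM toggle, and two output panes (detections and CAM overlay).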
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # YoloV3 App!
            ## The model is trained on PASCAL-VOC data to predict the following classes -
            """)
    with gr.Row():
        gr.HTML(
            """
            <table>
              <tr>
                <th>aeroplane</th>
                <th>bicycle</th>
                <th>bird</th>
                <th>boat</th>
                <th>bottle</th>
                <th>bus</th>
                <th>car</th>
                <th>cat</th>
              </tr>
              <tr>
                <th>chair</th>
                <th>cow</th>
                <th>diningtable</th>
                <th>dog</th>
                <th>horse</th>
                <th>motorbike</th>
                <th>person</th>
                <th>pottedplant</th>
              </tr>
              <tr>
                <th>sheep</th>
                <th>sofa</th>
                <th>train</th>
                <th>tvmonitor</th>
              </tr>
            </table>
            <p>
              <a href='https://github.com/piygr/yolov3/blob/main/models/YoloV3Lightning.py'>Click to see the model architecture / code</a>
            </p>
            """
        )
    with gr.Row(visible=True) as top_pred_cls_col:
        with gr.Column():
            example_images = gr.Gallery(allow_preview=False, label='Select image',
                                        value=[img[0] for img in sample_images], columns=5, rows=2)
        with gr.Column():
            top_pred_image = gr.Image(label='Upload Image or Select from the gallery')
            with gr.Row():
                top_class_btn = gr.Button("Submit", variant='primary')
                tc_clear_btn = gr.ClearButton()
            with gr.Row():
                if_show_grad_cam = gr.Checkbox(value=True, label='Show Class Activation Map (what the model sees)?')

    with gr.Row(visible=True):
        top_class_output_img = gr.Image(interactive=False, label='Prediction Output')
    with gr.Row(visible=True):
        grad_cam_out = gr.Image(interactive=False, visible=True, label='CAM Outcome')
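    # Show or hide the CAM pane whenever the checkbox is toggled.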
    def show_cam_output(show):
        return {
            grad_cam_out: gr.update(visible=show)
        }

    if_show_grad_cam.change(
        show_cam_output,
        if_show_grad_cam,
        grad_cam_out
    )

    def clear_data():
        # Reset both the input image and the detection output.
        return {
            top_pred_image: None,
            top_class_output_img: None
        }

    tc_clear_btn.click(clear_data, None, [top_pred_image, top_class_output_img])

    def on_select(evt: gr.SelectData):
        # Copy the clicked gallery thumbnail into the input image slot.
        return {
            top_pred_image: sample_images[evt.index][0]
        }

    example_images.select(on_select, None, top_pred_image)
    def plot_image(image, boxes):
        """Plots predicted bounding boxes on the image and saves it to output.png."""
        cmap = plt.get_cmap("tab20b")
        class_labels = cfg.PASCAL_CLASSES
        colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
        im = np.array(image)
        height, width, _ = im.shape

        # Create figure and axes, then display the image
        fig, ax = plt.subplots(1)
        ax.imshow(im)

        # Each box is [class pred, confidence, x, y, width, height], where
        # (x, y) is the box midpoint and all coordinates are normalised to [0, 1].
        for box in boxes:
            assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
            class_pred = box[0]
            box = box[2:]
            upper_left_x = box[0] - box[2] / 2
            upper_left_y = box[1] - box[3] / 2
            rect = patches.Rectangle(
                (upper_left_x * width, upper_left_y * height),
                box[2] * width,
                box[3] * height,
                linewidth=2,
                edgecolor=colors[int(class_pred)],
                facecolor="none",
            )
            # Add the patch to the Axes and label it with the class name
            ax.add_patch(rect)
            plt.text(
                upper_left_x * width,
                upper_left_y * height,
                s=class_labels[int(class_pred)],
                color="white",
                verticalalignment="top",
                bbox={"color": colors[int(class_pred)], "pad": 0},
            )
        plt.savefig('output.png')
        plt.close(fig)  # free the figure; plt.show() has no effect on a headless server
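    # Inference pipeline: transform the input, run the model, decode all three
    # YOLO scales into boxes, apply non-max suppression, draw the detections,
    # and optionally compute the Grad-CAM overlay.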
    def predict(image: np.ndarray, iou_thresh: float = 0.5, thresh: float = 0.4, show_cam: bool = False,
                transparency: float = 0.5) -> List[np.ndarray]:
        with torch.no_grad():
            transformed_image = cfg.grad_cam_transforms(image=image)["image"].unsqueeze(0)
            output = model(transformed_image)

            # Collect boxes from all three prediction scales (batch size is 1)
            bboxes = [[]]
            for i in range(3):
                _, _, S, _, _ = output[i].shape
                anchor = scaled_anchors[i]
                boxes_scale_i = utils.cells_to_bboxes(
                    output[i], anchor, S=S, is_preds=True
                )
                for idx, box in enumerate(boxes_scale_i):
                    bboxes[idx] += box

            nms_boxes = utils.non_max_suppression(
                bboxes[0], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
            )
            plot_image(image, nms_boxes)
            plotted_img = 'output.png'

        if not show_cam:
            return [plotted_img, None]

        # Grad-CAM needs gradients, so the cam call sits outside torch.no_grad()
        grayscale_cam = cam(transformed_image)[0, :, :]
        img = cv2.resize(image, (416, 416))
        img = np.float32(img) / 255
        cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True, image_weight=transparency)
        return [plotted_img, cam_image]
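    # Submit handler: run predict() on the current image and route the two
    # resulting images to the output panes.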
    def top_class_img_upload(input_img, if_cam):
        if input_img is None:
            # Nothing to predict; clear both outputs instead of returning nothing.
            return {
                top_class_output_img: None,
                grad_cam_out: None
            }
        imgs = predict(input_img, show_cam=if_cam)
        return {
            top_class_output_img: imgs[0],
            grad_cam_out: imgs[1]
        }

    top_class_btn.click(
        top_class_img_upload,
        [top_pred_image, if_show_grad_cam],
        [top_class_output_img, grad_cam_out]
    )
# Launch the app
app.launch()