"""Gradio demo app: YOLOv3 object detection with an optional Grad-CAM overlay.

Loads a YOLOv3 Lightning checkpoint on CPU, shows a gallery of sample images,
and on "Submit" draws the NMS-filtered predicted boxes and (optionally) a
class-activation map for the selected/uploaded image.
"""
import gradio as gr
from typing import List
import cv2
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import io
from models import YoloV3Lightning
from utils import load_model_from_checkpoint
import utils
import config as cfg
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from dataset import YOLODataset
from torch.utils.data import Dataset, DataLoader
from grad_cam import YoloGradCAM

device = torch.device('cpu')
dataset_mean, dataset_std = (
    (0.4914, 0.4822, 0.4465),
    (0.2470, 0.2435, 0.2616),
)

model = YoloV3Lightning.YOLOv3LightningModel(
    num_classes=cfg.NUM_CLASSES, anchors=cfg.ANCHORS, S=cfg.S
)
ckpt_file = 'ckpt_light.pth'
checkpoint = load_model_from_checkpoint(device, file_name=ckpt_file)
# strict=False: missing/unexpected keys in the checkpoint are silently ignored.
model.load_state_dict(checkpoint['model'], strict=False)
model.eval()

# Multiply each scale's anchors by that scale's grid size (cfg.S), giving one
# (3, 2) anchor set per output scale.
scaled_anchors = (
    torch.tensor(cfg.ANCHORS)
    * torch.tensor(cfg.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to(model.device)

cam = YoloGradCAM(model=model, target_layers=[model.layers[-2]],
                  scaled_anchors=scaled_anchors, use_cuda=False)

# Offline sanity check kept for reference (not executed):
#
# cfg.IMG_DIR = cfg.DATASET + "/images/"
# cfg.LABEL_DIR = cfg.DATASET + "/labels/"
# eval_dataset = YOLODataset(
#     cfg.DATASET + "/25examples.csv",
#     transform=cfg.test_transforms,
#     S=[cfg.IMAGE_SIZE // 32, cfg.IMAGE_SIZE // 16, cfg.IMAGE_SIZE // 8],
#     img_dir=cfg.IMG_DIR,
#     label_dir=cfg.LABEL_DIR,
#     anchors=cfg.ANCHORS,
#     mosaic=False,
# )
# eval_loader = DataLoader(
#     dataset=eval_dataset,
#     batch_size=cfg.BATCH_SIZE,
#     num_workers=cfg.NUM_WORKERS,
#     pin_memory=cfg.PIN_MEMORY,
#     shuffle=True,
#     drop_last=False,
# )
# scaled_anchors = (
#     torch.tensor(cfg.ANCHORS)
#     * torch.tensor(cfg.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
# ).to(cfg.DEVICE)
# utils.plot_examples(model, eval_loader, 0.5, 0.6, scaled_anchors)

# Gallery entries: [['images/000001.jpg'], ..., ['images/000025.jpg']]
sample_images = [[f'images/{i:06d}.jpg'] for i in range(1, 26)]

# NOTE(review): the layout nesting below was reconstructed from a
# whitespace-collapsed source — confirm it matches the intended UI.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
# YoloV3 App!
## Model is trained on PASCAL-VOC data to predict following classes -
"""
        )
    with gr.Row():
        gr.HTML(
            """
aeroplane bicycle bird boat bottle bus car cat
chair cow diningtable dog horse motorbike person pottedplant
sheep sofa train tvmonitor

Click to see the model architecture / code

"""
        )

    with gr.Row(visible=True) as top_pred_cls_col:
        with gr.Column():
            example_images = gr.Gallery(
                allow_preview=False, label='Select image ', info='',
                value=[img[0] for img in sample_images], columns=5, rows=2,
            )
        with gr.Column():
            top_pred_image = gr.Image(label='Upload Image or Select from the gallery')
            with gr.Row():
                top_class_btn = gr.Button("Submit", variant='primary')
                tc_clear_btn = gr.ClearButton()
            with gr.Row():
                if_show_grad_cam = gr.Checkbox(
                    value=True,
                    label='Show Class Activation Map (What the model sees)?',
                )

    with gr.Row(visible=True) as top_class_output:
        top_class_output_img = gr.Image(interactive=False, label='Prediction Output')
    with gr.Row(visible=True) as grad_cam_output_row:
        grad_cam_out = gr.Image(interactive=False, visible=True, label='CAM Outcome')

    def show_cam_output(show):
        """Show or hide the CAM output image to follow the checkbox state."""
        return {grad_cam_out: gr.update(visible=show)}

    if_show_grad_cam.change(show_cam_output, if_show_grad_cam, grad_cam_out)

    def clear_data():
        """Reset the input image and both output images."""
        return {top_pred_image: None, top_class_output_img: None, grad_cam_out: None}

    tc_clear_btn.click(clear_data, None,
                       [top_pred_image, top_class_output_img, grad_cam_out])

    def on_select(evt: gr.SelectData):
        """Copy the clicked gallery image path into the input image component."""
        return {top_pred_image: sample_images[evt.index][0]}

    example_images.select(on_select, None, top_pred_image)

    def plot_image(image, boxes, output_path='output.png'):
        """Draw predicted bounding boxes on ``image`` and save the figure to disk.

        Args:
            image: Image data accepted by ``np.array`` with shape (H, W, C).
            boxes: Iterable of ``[class_pred, confidence, x_mid, y_mid, w, h]``;
                coordinates are multiplied by the image width/height below, so
                they are expected to be image-relative.
            output_path: File the rendered figure is written to.

        Returns:
            ``output_path``.
        """
        cmap = plt.get_cmap("tab20b")
        class_labels = cfg.PASCAL_CLASSES
        colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
        im = np.array(image)
        height, width, _ = im.shape

        fig, ax = plt.subplots(1)
        try:
            ax.imshow(im)
            for box in boxes:
                assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
                class_pred = int(box[0])
                x_mid, y_mid, box_w, box_h = box[2:]
                upper_left_x = x_mid - box_w / 2
                upper_left_y = y_mid - box_h / 2
                rect = patches.Rectangle(
                    (upper_left_x * width, upper_left_y * height),
                    box_w * width,
                    box_h * height,
                    linewidth=2,
                    edgecolor=colors[class_pred],
                    facecolor="none",
                )
                ax.add_patch(rect)
                ax.text(
                    upper_left_x * width,
                    upper_left_y * height,
                    s=class_labels[class_pred],
                    color="white",
                    verticalalignment="top",
                    bbox={"color": colors[class_pred], "pad": 0},
                )
            fig.savefig(output_path)
        finally:
            # Each request creates a new figure; close it so a long-running
            # server does not accumulate open figures.
            plt.close(fig)
        return output_path

    def predict(image: np.ndarray, iou_thresh: float = 0.5, thresh: float = 0.4,
                show_cam: bool = False, transparency: float = 0.5) -> List:
        """Detect objects in ``image`` and optionally build a Grad-CAM overlay.

        Args:
            image: Input image array (as provided by ``gr.Image``).
            iou_thresh: IoU threshold passed to non-max suppression.
            thresh: Confidence threshold passed to non-max suppression.
            show_cam: Whether to compute the class-activation-map overlay.
            transparency: ``image_weight`` for ``show_cam_on_image``.

        Returns:
            ``[plotted_image_path, cam_image]``; ``cam_image`` is ``None`` when
            ``show_cam`` is false.
        """
        transformed_image = cfg.grad_cam_transforms(image=image)["image"].unsqueeze(0)

        with torch.no_grad():
            output = model(transformed_image)
            # One box list per image in the batch (the UI sends a single image).
            bboxes = [[] for _ in range(transformed_image.shape[0])]
            for scale_output, anchor in zip(output, scaled_anchors):
                grid_size = scale_output.shape[2]
                boxes_scale_i = utils.cells_to_bboxes(
                    scale_output, anchor, S=grid_size, is_preds=True
                )
                for idx, box in enumerate(boxes_scale_i):
                    bboxes[idx] += box

        nms_boxes = utils.non_max_suppression(
            bboxes[0],
            iou_threshold=iou_thresh,
            threshold=thresh,
            box_format="midpoint",
        )
        plotted_img = plot_image(image, nms_boxes)

        if not show_cam:
            return [plotted_img, None]

        # Grad-CAM weights activations by their gradients, so it must run with
        # autograd enabled, i.e. outside the torch.no_grad() block above.
        grayscale_cam = cam(transformed_image)[0, :, :]
        cam_h, cam_w = grayscale_cam.shape
        img = cv2.resize(image, (cam_w, cam_h))  # cv2 takes (width, height)
        img = np.float32(img) / 255
        cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True,
                                      image_weight=transparency)
        return [plotted_img, cam_image]

    def top_class_img_upload(input_img, if_cam):
        """Submit handler: run ``predict`` and route results to both outputs."""
        if input_img is None:
            # No image yet: leave both outputs untouched rather than returning
            # None, which does not match the two declared outputs.
            return {top_class_output_img: gr.update(), grad_cam_out: gr.update()}
        plotted_img, cam_image = predict(input_img, show_cam=if_cam)
        return {top_class_output_img: plotted_img, grad_cam_out: cam_image}

    top_class_btn.click(
        top_class_img_upload,
        [top_pred_image, if_show_grad_cam],
        [top_class_output_img, grad_cam_out],
    )

# Launch the app
if __name__ == "__main__":
    app.launch()