from typing import List
import cv2
import torch
import numpy as np
import gradio as gr
import config as modelConfig
from pytorch_grad_cam.utils.image import show_cam_on_image
from yolov3 import YOLOv3
import utils
from utils import cells_to_bboxes, non_max_suppression, draw_bounding_boxes, YoloGradCAM
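
# Build the YOLOv3 model and restore the trained Pascal VOC weights from the checkpoint.
# The optimizer is created here only because load_checkpoint expects one.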
model = YOLOv3(num_classes=len(modelConfig.PASCAL_CLASSES))
optimizer = torch.optim.Adam(model.parameters(), lr=0.00072/100, weight_decay=1e-4)
utils.load_checkpoint("checkpoint.pth.tar", model, optimizer, 0.00072/100)
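# Inference-only app: switch to eval mode so BatchNorm uses its running statistics.
model.eval()

# Scale the anchor boxes to each of the three prediction grid sizes in modelConfig.S.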
scaled_anchors = (
    torch.tensor(modelConfig.ANCHORS)
    * torch.tensor(modelConfig.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to(modelConfig.DEVICE)
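
# Class-activation maps are computed from the second-to-last block in model.layers (CPU only).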
yolo_grad_cam = YoloGradCAM(model=model, target_layers=[model.layers[-2]], use_cuda=False)


@torch.inference_mode()
def detect_objects(
    image: np.ndarray,
    iou_thresh: float = 0.5,
    thresh: float = 0.4,
    enable_grad_cam: bool = False,
    transparency: float = 0.5,
) -> List[np.ndarray]:
    # Preprocess the image and run a single forward pass (batch size 1).
    transformed_image = modelConfig.transforms(image=image)["image"].unsqueeze(0)
    #transformed_image = transformed_image.cuda()
    output = model(transformed_image)

    # Decode the predictions from all three scales into one list of boxes per image.
    bboxes = [[] for _ in range(1)]
    for i in range(3):
        batch_size, A, S, _, _ = output[i].shape
        anchor = scaled_anchors[i]
        boxes_scale_i = cells_to_bboxes(output[i], anchor, S=S, is_preds=True)
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box

    # Suppress overlapping detections and draw the survivors on a copy of the input.
    nms_boxes = non_max_suppression(
        bboxes[0], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
    )
    plot_img_with_bboxes = draw_bounding_boxes(image.copy(), nms_boxes, class_labels=modelConfig.PASCAL_CLASSES)
    if not enable_grad_cam:
        return [plot_img_with_bboxes]

    # Overlay the class-activation map on the resized input; `transparency`
    # controls how strongly the original image shows through the heatmap.
    grayscale_cam = yolo_grad_cam(transformed_image, scaled_anchors)[0, :, :]
    img = cv2.resize(image, (416, 416))
    img = np.float32(img) / 255
    grad_cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True, image_weight=transparency)
    return [plot_img_with_bboxes, grad_cam_image]


def inference(
    image: np.ndarray,
    iou_thresh: float,
    thresh: float,
    enable_grad_cam: bool,
    transparency: float,
) -> List[np.ndarray]:
    return detect_objects(image, iou_thresh, thresh, enable_grad_cam, transparency)
title = "Object detection application using YoloV3 Model"
description = f"Object detection application using pre-trained YoloV3 model for Pascal VOC dataset. This app has GradCAM option also. \n The 20 classes in Pascal voc dataset are : {', '.join(modelConfig.PASCAL_CLASSES)}"
examples = [
    ["images/000811.jpg", 0.6, 0.6, True, 0.6],
    ["images/000830.jpg", 0.5, 0.5, True, 0.6],
    ["images/000842.jpg", 0.6, 0.6, True, 0.6],
    ["images/001114.jpg", 0.4, 0.5, True, 0.6],
    ["images/001133.jpg", 0.7, 0.7, True, 0.6],
    ["images/001155.jpg", 0.7, 0.69, True, 0.6],
    ["images/000008.jpg", 0.66, 0.69, True, 0.6],
    ["images/000031.jpg", 0.6, 0.6, True, 0.6],
    ["images/000175.jpg", 0.6, 0.6, True, 0.6],
]
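
# Wire everything into a Gradio interface: an image plus threshold sliders in,
# a gallery with the annotated image (and optionally the Grad-CAM overlay) out.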
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="IOU Threshold"),
        gr.Slider(0, 1, value=0.4, label="Threshold"),
        gr.Checkbox(label="Show Grad Cam"),
        gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
    ],
    outputs=[
        gr.Gallery(rows=2, columns=1),
    ],
    title=title,
    description=description,
    examples=examples,
)
demo.launch(debug=True)