# YoloV3_PascalVOC_Dataset / detection.py
# Author: HemaAM — last commit 0ce0c39: "Bug fix to not use cuda" (2.13 kB)
from typing import List
import cv2
import torch
import numpy as np
import config as modelConfig
from pytorch_grad_cam.utils.image import show_cam_on_image
from yolo3 import YOLOv3
from utils import cells_to_bboxes, non_max_suppression, draw_prediction_boxes, YoloGradCAM
# Build the 20-class YOLOv3 network and load the trained weights onto the CPU.
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model = YOLOv3(num_classes=20)
model.load_state_dict(torch.load("yolo3_model_trained1.pth", map_location="cpu"))
model.eval()
print("Yolov3 Model Loaded..")

# Anchors rescaled per output scale: ANCHORS times the grid size S of each scale
# (S is broadcast to ANCHORS' shape via unsqueeze/repeat; presumably ANCHORS is
# (3 scales, 3 anchors, 2) — TODO confirm against config).
# They are placed on the SAME device as the model. The model is loaded with
# map_location="cpu" and inputs are never moved to CUDA. Sending the anchors to
# modelConfig.DEVICE (possibly "cuda") would therefore mismatch the CPU
# predictions they are combined with in cells_to_bboxes.
_model_device = next(model.parameters()).device
scaled_anchors = (
    torch.tensor(modelConfig.ANCHORS)
    * torch.tensor(modelConfig.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to(_model_device)

# Grad-CAM hooked on the second-to-last layer; kept on the CPU like the model.
yolo_grad_cam = YoloGradCAM(model=model, target_layers=[model.layers[-2]], use_cuda=False)
@torch.inference_mode()
def detect_objects(
    image: np.ndarray,
    iou_thresh: float = 0.5,
    thresh: float = 0.4,
    enable_grad_cam: bool = False,
    transparency: float = 0.5,
) -> List[np.ndarray]:
    """Run YOLOv3 on one image and return annotated visualisations.

    How it works:
    - The image is transformed with ``modelConfig.transforms``, batched as a
      single image, and passed through the module-level ``model``.
    - Each of the three output scales is decoded with ``cells_to_bboxes``
      using the matching row of ``scaled_anchors``.
    - The decoded boxes are merged and filtered with ``non_max_suppression``
      (``box_format="midpoint"``).

    Args:
        image: Input image array. Presumably HxWx3 RGB uint8, since the Grad-CAM
            overlay is rendered with ``use_rgb=True`` — TODO confirm.
        iou_thresh: Passed to ``non_max_suppression`` as ``iou_threshold``.
        thresh: Passed to ``non_max_suppression`` as ``threshold``.
        enable_grad_cam: When True, also compute a Grad-CAM overlay.
        transparency: ``image_weight`` for ``show_cam_on_image``, i.e. how much
            of the original image shows through the heat-map blend.

    Returns:
        ``[plot_img]``: the boxes drawn on a copy of ``image``. When
        ``enable_grad_cam`` is True, returns ``[plot_img, grad_cam_image]``.
    """
    transformed_image = modelConfig.transforms(image=image)["image"].unsqueeze(0)
    output = model(transformed_image)

    # One image in the batch, so a single list collects boxes from all scales.
    bboxes = [[]]
    for i in range(3):
        _, _, grid_size, _, _ = output[i].shape
        boxes_scale_i = cells_to_bboxes(
            output[i], scaled_anchors[i], S=grid_size, is_preds=True
        )
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box

    nms_boxes = non_max_suppression(
        bboxes[0], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
    )
    plot_img = draw_prediction_boxes(image.copy(), nms_boxes, class_labels=modelConfig.PASCAL_CLASSES)
    if not enable_grad_cam:
        return [plot_img]

    grayscale_cam = yolo_grad_cam(transformed_image, scaled_anchors)[0, :, :]
    # show_cam_on_image requires the image and the CAM mask to have the same HxW.
    # Size the image from the CAM instead of hard-coding 416 so any model input
    # size works. cv2.resize takes dsize as (width, height).
    cam_height, cam_width = grayscale_cam.shape[:2]
    img = cv2.resize(image, (cam_width, cam_height))
    img = np.float32(img) / 255
    grad_cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True, image_weight=transparency)
    return [plot_img, grad_cam_image]
if __name__ == "__main__":
    # Manual smoke test: detect objects in a sample image and display the result.
    bgr_image = cv2.imread("images/001155.jpg")
    if bgr_image is None:
        raise SystemExit("Could not read images/001155.jpg")
    # cv2.imread returns BGR. The detection path presumably works in RGB, since
    # the Grad-CAM overlay uses use_rgb=True (TODO confirm vs
    # modelConfig.transforms), so convert on the way in and back for display.
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    # detect_objects returns a list; the first entry is the box-annotated image.
    plot_img = detect_objects(rgb_image)[0]
    cv2.imshow("image", cv2.cvtColor(plot_img, cv2.COLOR_RGB2BGR))
    cv2.waitKey(0)
    cv2.destroyAllWindows()