HemaAM commited on
Commit
010350c
·
1 Parent(s): fa1bc1f

Removing this file as the functionality is put in application file

Browse files
Files changed (1) hide show
  1. detection.py +0 -62
detection.py DELETED
@@ -1,62 +0,0 @@
1
- from typing import List
2
- import cv2
3
- import torch
4
- import numpy as np
5
- import config as modelConfig
6
- from pytorch_grad_cam.utils.image import show_cam_on_image
7
-
8
- from yolo3 import YOLOv3
9
- from utils import cells_to_bboxes, non_max_suppression, draw_prediction_boxes, YoloGradCAM
10
-
11
-
12
- model = YOLOv3(num_classes=20)
13
-
14
- model.load_state_dict(torch.load("yolo3_model_trained1.pth", map_location="cpu"))
15
- model.eval()
16
- print("Yolov3 Model Loaded..")
17
-
18
- scaled_anchors = (
19
- torch.tensor(modelConfig.ANCHORS)
20
- * torch.tensor(modelConfig.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
21
- ).to(modelConfig.DEVICE)
22
-
23
- yolo_grad_cam = YoloGradCAM(model=model, target_layers=[model.layers[-2]], use_cuda=False)
24
-
25
-
26
- @torch.inference_mode()
27
- def detect_objects(image: np.ndarray, iou_thresh: float = 0.5, thresh: float = 0.4, enable_grad_cam: bool = False, transparency: float = 0.5) -> List[np.ndarray]:
28
- transformed_image = modelConfig.transforms(image=image)["image"].unsqueeze(0)
29
- #transformed_image = transformed_image.cuda()
30
- output = model(transformed_image)
31
-
32
- bboxes = [[] for _ in range(1)]
33
- for i in range(3):
34
- batch_size, A, S, _, _ = output[i].shape
35
- anchor = scaled_anchors[i]
36
- boxes_scale_i = cells_to_bboxes(
37
- output[i], anchor, S=S, is_preds=True
38
- )
39
- for idx, (box) in enumerate(boxes_scale_i):
40
- bboxes[idx] += box
41
-
42
- nms_boxes = non_max_suppression(
43
- bboxes[0], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
44
- )
45
- plot_img = draw_prediction_boxes(image.copy(), nms_boxes, class_labels=modelConfig.PASCAL_CLASSES)
46
- if not enable_grad_cam:
47
- return [plot_img]
48
-
49
- grayscale_cam = yolo_grad_cam(transformed_image, scaled_anchors)[0, :, :]
50
- img = cv2.resize(image, (416, 416))
51
- img = np.float32(img) / 255
52
- grad_cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True, image_weight=transparency)
53
- return [plot_img, grad_cam_image]
54
-
55
-
56
- if __name__=="__main__":
57
- image = cv2.imread("images/001155.jpg")
58
- image = predict(image)
59
-
60
- cv2.imshow("image", image)
61
- cv2.waitKey(0)
62
- cv2.destroyAllWindows()