joonhyun23452 committed
Commit 8075387 · 1 Parent(s): 5f752ee

open proxydet demo

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. app.py +56 -39
  2. assets/beach.jpg +0 -0
  3. assets/desk.jpg +0 -0
  4. assets/pikachu.jpg +0 -0
  5. configs/Base-C2_L_R5021k_640b64_4x.yaml +83 -0
  6. configs/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.yaml +7 -0
  7. configs/ProxyDet_R50_Lbase_INL.yaml +59 -0
  8. configs/ProxyDet_SwinB_Lbase_INL.yaml +51 -0
  9. datasets/metadata/__init__.py +0 -0
  10. demo.py +245 -0
  11. packages.txt +3 -0
  12. proxydet/__init__.py +18 -0
  13. proxydet/cat_names.py +1 -0
  14. proxydet/config.py +156 -0
  15. proxydet/custom_solver.py +78 -0
  16. proxydet/data/custom_build_augmentation.py +51 -0
  17. proxydet/data/custom_dataset_dataloader.py +331 -0
  18. proxydet/data/custom_dataset_mapper.py +280 -0
  19. proxydet/data/datasets/cc.py +23 -0
  20. proxydet/data/datasets/coco_zeroshot.py +121 -0
  21. proxydet/data/datasets/imagenet.py +41 -0
  22. proxydet/data/datasets/lvis_22k_categories.py +0 -0
  23. proxydet/data/datasets/lvis_v1.py +155 -0
  24. proxydet/data/datasets/objects365.py +770 -0
  25. proxydet/data/datasets/oid.py +535 -0
  26. proxydet/data/datasets/register_oid.py +122 -0
  27. proxydet/data/tar_dataset.py +138 -0
  28. proxydet/data/transforms/custom_augmentation_impl.py +60 -0
  29. proxydet/data/transforms/custom_transform.py +114 -0
  30. proxydet/evaluation/custom_coco_eval.py +124 -0
  31. proxydet/evaluation/oideval.py +699 -0
  32. proxydet/modeling/backbone/swintransformer.py +750 -0
  33. proxydet/modeling/backbone/timm.py +221 -0
  34. proxydet/modeling/debug.py +334 -0
  35. proxydet/modeling/meta_arch/custom_rcnn.py +232 -0
  36. proxydet/modeling/meta_arch/d2_deformable_detr.py +308 -0
  37. proxydet/modeling/roi_heads/proxydet_fast_rcnn.py +618 -0
  38. proxydet/modeling/roi_heads/proxydet_roi_heads.py +556 -0
  39. proxydet/modeling/roi_heads/zero_shot_classifier.py +111 -0
  40. proxydet/modeling/text/text_encoder.py +189 -0
  41. proxydet/modeling/utils.py +54 -0
  42. proxydet/predictor.py +295 -0
  43. requirements.txt +0 -6
  44. third_party/CenterNet2/.github/CODE_OF_CONDUCT.md +5 -0
  45. third_party/CenterNet2/.github/CONTRIBUTING.md +68 -0
  46. third_party/CenterNet2/.github/Detectron2-Logo-Horz.svg +1 -0
  47. third_party/CenterNet2/.github/ISSUE_TEMPLATE.md +5 -0
  48. third_party/CenterNet2/.github/ISSUE_TEMPLATE/bugs.md +38 -0
  49. third_party/CenterNet2/.github/ISSUE_TEMPLATE/config.yml +17 -0
  50. third_party/CenterNet2/.github/ISSUE_TEMPLATE/documentation.md +14 -0
app.py CHANGED
@@ -3,10 +3,14 @@ import cv2
 import os
 import gradio as gr
 import numpy as np
-from transformers import OwlViTProcessor, OwlViTForObjectDetection
-
-def setup():
+from argparse import Namespace
+try:
+    import detectron2
+except:
     os.system("python3 -m pip install 'git+https://github.com/facebookresearch/detectron2.git'")
+    import detectron2
+from demo import setup_cfg
+from proxydet.predictor import VisualizationDemo
 
 # Use GPU if available
 if torch.cuda.is_available():
@@ -14,26 +18,42 @@ if torch.cuda.is_available():
 else:
     device = torch.device("cpu")
 
-model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
-model.eval()
-processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
-
-
-def query_image(img, text_queries, score_threshold):
-    text_queries = text_queries
-    text_queries = text_queries.split(",")
-
-    target_sizes = torch.Tensor([img.shape[:2]])
-    inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)
-
+# download metadata
+zs_weight_path = 'datasets/metadata/lvis_v1_clip_a+cname.npy'
+if not os.path.exists(zs_weight_path):
+    wget.download("https://github.com/facebookresearch/Detic/raw/main/datasets/metadata/lvis_v1_clip_a+cname.npy", out=zs_weight_path)
+
+args = Namespace(
+    base_cat_threshold=0.9,
+    confidence_threshold=0.0,
+    config_file='configs/ProxyDet_SwinB_Lbase_INL.yaml',
+    cpu=not torch.cuda.is_available(),
+    custom_vocabulary='headphone,webcam,paper,coffe',
+    input=['.assets/desk.jpg'],
+    opts=['MODEL.WEIGHTS', 'models/proxydet_swinb_w_inl.pth'],
+    output='out.jpg',
+    pred_all_class=False,
+    video_input=None,
+    vocabulary='custom',
+    webcam=None,
+    zeroshot_weight_path='datasets/metadata/lvis_v1_clip_a+cname.npy'
+)
+cfg = setup_cfg(args)
+ovd_demo = VisualizationDemo(cfg, args)
+
+def query_image(img, text_queries, score_threshold, base_alpha, novel_beta):
+    text_queries_split = text_queries.split(",")
+    ovd_demo.reset_classifier(text_queries)
+    ovd_demo.reset_base_cat_mask()
+    ovd_demo.predictor.model.roi_heads.cmm_base_alpha = base_alpha
+    ovd_demo.predictor.model.roi_heads.cmm_novel_beta = novel_beta
+    img_bgr = img[:, :, ::-1]
     with torch.no_grad():
-        outputs = model(**inputs)
-
-    outputs.logits = outputs.logits.cpu()
-    outputs.pred_boxes = outputs.pred_boxes.cpu()
-    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
-    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
-
+        predictions, visualized_output = ovd_demo.run_on_image(img_bgr)
+    output_instances = predictions["instances"].to(device)
+    boxes = output_instances.pred_boxes.tensor
+    scores = output_instances.scores
+    labels = output_instances.pred_classes.tolist()
+
     font = cv2.FONT_HERSHEY_SIMPLEX
 
     for box, score, label in zip(boxes, scores, labels):
@@ -47,35 +67,32 @@ def query_image(img, text_queries, score_threshold):
         y = box[3] + 25
 
         img = cv2.putText(
-            img, text_queries[label], (box[0], y), font, 1, (255,0,0), 2, cv2.LINE_AA
+            img, text_queries_split[label], (box[0], y), font, 1, (255,0,0), 2, cv2.LINE_AA
         )
     return img
 
 if __name__ == "__main__":
-    setup()
-
     description = """
-    Gradio demo for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/owlvit">OWL-ViT</a>,
-    introduced in <a href="https://arxiv.org/abs/2205.06230">Simple Open-Vocabulary Object Detection
-    with Vision Transformers</a>.
-    \n\nYou can use OWL-ViT to query images with text descriptions of any object.
-    To use it, simply upload an image and enter comma separated text descriptions of objects you want to query the image for. You
-    can also use the score threshold slider to set a threshold to filter out low probability predictions.
-    \n\nOWL-ViT is trained on text templates,
-    hence you can get better predictions by querying the image with text templates used in training the original model: *"photo of a star-spangled banner"*,
-    *"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data.
-    \n\n<a href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/zeroshot_object_detection_with_owlvit.ipynb">Colab demo</a>
+    Gradio demo for ProxyDet, introduced in <a href="https://arxiv.org/abs/2312.07266">ProxyDet: Synthesizing Proxy Novel Classes via Classwise Mixup for Open-Vocabulary Object Detection</a>.
+    \n\nYou can use ProxyDet to query images with text descriptions of any object.
+    Simply upload an image and enter comma separated objects (e.g., "dog,cat,headphone") which you want to detect within the image.
+    You can also use the score threshold slider to set a threshold to filter out low probability predictions.
     """
     demo = gr.Interface(
         query_image,
-        inputs=[gr.Image(), "text", gr.Slider(0, 1, value=0.1)],
+        inputs=[gr.Image(), "text", gr.Slider(0, 1, value=0.1), gr.Slider(0, 1, value=0.15), gr.Slider(0, 1, value=0.35)],
         outputs="image",
-        title="Zero-Shot Object Detection with OWL-ViT",
+        title="Open-Vocabulary Object Detection with ProxyDet",
         description=description,
         examples=[
-            ["assets/astronaut.png", "human face, rocket, star-spangled banner, nasa badge", 0.11],
-            ["assets/coffee.png", "coffee mug, spoon, plate", 0.1],
-            ["assets/butterflies.jpeg", "orange butterfly", 0.3],
+            ["assets/desk.jpg", "headphone,webcam,paper,coffee", 0.11, 0.15, 0.35],
+            ["assets/beach.jpg", "person,kite", 0.1, 0.15, 0.35],
+            ["assets/pikachu.jpg", "pikachu,person", 0.15, 0.15, 0.35],
         ],
     )
-    demo.launch()
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=int(os.environ["NSML_PORT1"]),
+        share=False,
+        debug=True,
+    )
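For reference, the new query_image() can also be exercised outside Gradio. The snippet below is an illustrative sketch and not part of the commit; it assumes app.py's module-level setup (which builds ovd_demo and defines query_image) runs on import, and it reuses the example image and slider values from the Interface definition above.

    import cv2
    from app import query_image  # importing app runs the module-level model setup

    img_bgr = cv2.imread("assets/desk.jpg")            # OpenCV loads BGR
    img_rgb = img_bgr[:, :, ::-1]                      # query_image expects RGB, like gr.Image
    out_rgb = query_image(img_rgb, "headphone,webcam,paper,coffee",
                          score_threshold=0.11, base_alpha=0.15, novel_beta=0.35)
    cv2.imwrite("out.jpg", out_rgb[:, :, ::-1])        # back to BGR for imwrite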
assets/beach.jpg ADDED
assets/desk.jpg ADDED
assets/pikachu.jpg ADDED
configs/Base-C2_L_R5021k_640b64_4x.yaml ADDED
@@ -0,0 +1,83 @@
# Code Copied from https://github.com/facebookresearch/Detic/blob/main/configs/Base-C2_L_R5021k_640b64_4x.yaml
MODEL:
  META_ARCHITECTURE: "CustomRCNN"
  MASK_ON: True
  PROPOSAL_GENERATOR:
    NAME: "CenterNet"
  WEIGHTS: "models/resnet50_miil_21k.pkl"
  BACKBONE:
    NAME: build_p67_timm_fpn_backbone
  TIMM:
    BASE_NAME: resnet50_in21k
  FPN:
    IN_FEATURES: ["layer3", "layer4", "layer5"]
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.12, 57.375]
  ROI_HEADS:
    NAME: ProxydetCascadeROIHeads
    IN_FEATURES: ["p3", "p4", "p5"]
    IOU_THRESHOLDS: [0.6]
    NUM_CLASSES: 1203
    SCORE_THRESH_TEST: 0.02
    NMS_THRESH_TEST: 0.5
  ROI_BOX_CASCADE_HEAD:
    IOUS: [0.6, 0.7, 0.8]
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
    CLS_AGNOSTIC_BBOX_REG: True
    MULT_PROPOSAL_SCORE: True

    USE_SIGMOID_CE: True
    USE_FED_LOSS: True
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
    CLS_AGNOSTIC_MASK: True
  CENTERNET:
    NUM_CLASSES: 1203
    REG_WEIGHT: 1.
    NOT_NORM_REG: True
    ONLY_PROPOSAL: True
    WITH_AGN_HM: True
    INFERENCE_TH: 0.0001
    PRE_NMS_TOPK_TRAIN: 4000
    POST_NMS_TOPK_TRAIN: 2000
    PRE_NMS_TOPK_TEST: 1000
    POST_NMS_TOPK_TEST: 256
    NMS_TH_TRAIN: 0.9
    NMS_TH_TEST: 0.9
    POS_WEIGHT: 0.5
    NEG_WEIGHT: 0.5
    IGNORE_HIGH_FP: 0.85
DATASETS:
  TRAIN: ("lvis_v1_train",)
  TEST: ("lvis_v1_val",)
DATALOADER:
  SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
  REPEAT_THRESHOLD: 0.001
  NUM_WORKERS: 8
TEST:
  DETECTIONS_PER_IMAGE: 300
SOLVER:
  LR_SCHEDULER_NAME: "WarmupCosineLR"
  CHECKPOINT_PERIOD: 1000000000
  WARMUP_ITERS: 10000
  WARMUP_FACTOR: 0.0001
  USE_CUSTOM_SOLVER: True
  OPTIMIZER: "ADAMW"
  MAX_ITER: 90000
  IMS_PER_BATCH: 64
  BASE_LR: 0.0002
  CLIP_GRADIENTS:
    ENABLED: True
INPUT:
  FORMAT: RGB
  CUSTOM_AUG: EfficientDetResizeCrop
  TRAIN_SIZE: 640
OUTPUT_DIR: "./output/ProxyDet/auto"
EVAL_PROPOSAL_AR: False
VERSION: 2
FP16: True
configs/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.yaml ADDED
@@ -0,0 +1,7 @@
# Code Copied from https://github.com/facebookresearch/Detic/blob/main/configs/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.yaml
_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml"
MODEL:
  ROI_BOX_HEAD:
    USE_ZEROSHOT_CLS: True
DATASETS:
  TRAIN: ("lvis_v1_train_norare",)
configs/ProxyDet_R50_Lbase_INL.yaml ADDED
@@ -0,0 +1,59 @@
# Code Adapted from https://github.com/facebookresearch/Detic/blob/main/configs/Detic_LbaseI_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml
_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml"
MODEL:
  ROI_BOX_HEAD:
    USE_ZEROSHOT_CLS: True
    IMAGE_LABEL_LOSS: 'max_size'
    USE_REGIONAL_EMBEDDING: True
  ROI_HEADS:
    BASE_CAT_MASK: "datasets/metadata/lvis_v1_base_cat_mask.npy"
    CMM:
      MIXUP_STAGE: [2]
      MIXUP_STAGE_TEST: [2]
      MIXUP_BETA: 1.0
      LOSS: "l1"
      LOSS_WEIGHT: 256.0
      SEPARATED_BRANCH: True
      BASE_ALPHA: 0.15
      NOVEL_BETA: 0.35
      USE_INL: False
      PROTOTYPE: "obj_score"
      PROTOTYPE_TEMP: 1.0
      CLASSIFIER_TEMP: 1.0
      USE_SIGMOID_CE: True
  WEIGHTS: "models/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.pth"
SOLVER:
  MAX_ITER: 90000
  IMS_PER_BATCH: 64
  BASE_LR: 0.0002
  WARMUP_ITERS: 1000
  WARMUP_FACTOR: 0.001
DATASETS:
  TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1")
DATALOADER:
  SAMPLER_TRAIN: "MultiDatasetSampler"
  DATASET_RATIO: [1, 4]
  USE_DIFF_BS_SIZE: True
  DATASET_BS: [8, 32]
  DATASET_INPUT_SIZE: [640, 320]
  USE_RFS: [True, False]
  DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]]
  FILTER_EMPTY_ANNOTATIONS: False
  MULTI_DATASET_GROUPING: True
  DATASET_ANN: ['box', 'image']
  NUM_WORKERS: 8
WITH_IMAGE_LABELS: True
configs/ProxyDet_SwinB_Lbase_INL.yaml ADDED
@@ -0,0 +1,51 @@
# Code Adapted from https://github.com/facebookresearch/Detic/blob/main/configs/Detic_LbaseI_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml
_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml"
MODEL:
  ROI_BOX_HEAD:
    USE_ZEROSHOT_CLS: True
    IMAGE_LABEL_LOSS: 'max_size'
    USE_REGIONAL_EMBEDDING: True
  ROI_HEADS:
    BASE_CAT_MASK: "datasets/metadata/lvis_v1_base_cat_mask.npy"
    CMM:
      MIXUP_STAGE: [2]
      MIXUP_STAGE_TEST: [2]
      MIXUP_BETA: 1.0
      LOSS: "l1"
      LOSS_WEIGHT: 256.0
      SEPARATED_BRANCH: True
      BASE_ALPHA: 0.15
      NOVEL_BETA: 0.35
      USE_INL: False
      PROTOTYPE: "obj_score"
      PROTOTYPE_TEMP: 1.0
      CLASSIFIER_TEMP: 1.0
      USE_SIGMOID_CE: True
  BACKBONE:
    NAME: build_swintransformer_fpn_backbone
  SWIN:
    SIZE: B-22k
  FPN:
    IN_FEATURES: ["swin1", "swin2", "swin3"]
  WEIGHTS: "models/BoxSup-C2_Lbase_CLIP_SwinB_896b32_4x.pth"
SOLVER:
  MAX_ITER: 180000
  IMS_PER_BATCH: 32
  BASE_LR: 0.0001
  WARMUP_ITERS: 1000
  WARMUP_FACTOR: 0.001
DATASETS:
  TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1")
DATALOADER:
  SAMPLER_TRAIN: "MultiDatasetSampler"
  DATASET_RATIO: [1, 4]
  USE_DIFF_BS_SIZE: True
  DATASET_BS: [4, 16]
  DATASET_INPUT_SIZE: [896, 448]
  USE_RFS: [True, False]
  DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]]
  FILTER_EMPTY_ANNOTATIONS: False
  MULTI_DATASET_GROUPING: True
  DATASET_ANN: ['box', 'image']
  NUM_WORKERS: 8
WITH_IMAGE_LABELS: True
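As a usage note (not part of the commit): this config is consumed through setup_cfg() in demo.py below, which layers the CenterNet2 and ProxyDet keys onto detectron2's defaults before merging the YAML. A minimal sketch:

    import sys
    sys.path.insert(0, 'third_party/CenterNet2/')
    from detectron2.config import get_cfg
    from centernet.config import add_centernet_config
    from proxydet.config import add_proxydet_config

    cfg = get_cfg()
    add_centernet_config(cfg)   # CenterNet2 proposal-generator keys
    add_proxydet_config(cfg)    # ProxyDet keys (ROI_HEADS.CMM, SWIN, TIMM, ...)
    cfg.merge_from_file('configs/ProxyDet_SwinB_Lbase_INL.yaml')
    cfg.merge_from_list(['MODEL.WEIGHTS', 'models/proxydet_swinb_w_inl.pth'])  # weight path as used in app.py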
datasets/metadata/__init__.py ADDED
File without changes
demo.py ADDED
@@ -0,0 +1,245 @@
# Copyright (c) Facebook, Inc. and its affiliates.
'''
Modifications Copyright (c) 2024-present NAVER Corp, Apache License v2.0
original source: https://github.com/facebookresearch/Detic/blob/main/demo.py
'''
import argparse
import glob
import multiprocessing as mp
import numpy as np
import os
import tempfile
import time
import warnings
import cv2
import tqdm
import sys
import mss

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger
from detectron2.engine.defaults import _highlight

sys.path.insert(0, 'third_party/CenterNet2/')
from centernet.config import add_centernet_config
from proxydet.config import add_proxydet_config

from proxydet.predictor import VisualizationDemo

# Fake a video capture object OpenCV style - half width, half height of first screen using MSS
class ScreenGrab:
    def __init__(self):
        self.sct = mss.mss()
        m0 = self.sct.monitors[0]
        self.monitor = {'top': 0, 'left': 0, 'width': m0['width'] / 2, 'height': m0['height'] / 2}

    def read(self):
        img = np.array(self.sct.grab(self.monitor))
        nf = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        return (True, nf)

    def isOpened(self):
        return True
    def release(self):
        return True


# constants
WINDOW_NAME = "ProxyDet"

def setup_cfg(args):
    cfg = get_cfg()
    if args.cpu:
        cfg.MODEL.DEVICE = "cpu"
    add_centernet_config(cfg)
    add_proxydet_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # Set score_threshold for builtin models
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
    cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'rand'  # load later
    if not args.pred_all_class:
        cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True
    cfg.freeze()
    return cfg


def get_parser():
    parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    parser.add_argument(
        "--config-file",
        default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--webcam", help="Take inputs from webcam.")
    parser.add_argument("--cpu", action='store_true', help="Use CPU only.")
    parser.add_argument("--video-input", help="Path to video file.")
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        help="A file or directory to save output visualizations. "
        "If not given, will show output in an OpenCV window.",
    )
    parser.add_argument(
        "--vocabulary",
        default="lvis",
        choices=['lvis', 'openimages', 'objects365', 'coco', 'custom'],
        help="",
    )
    parser.add_argument(
        "--custom_vocabulary",
        default="",
        help="",
    )
    parser.add_argument(
        "--zeroshot_weight_path",
        default=None,
        help="zeroshot text embedding path used during training",
    )
    parser.add_argument("--pred_all_class", action='store_true')
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.5,
        help="Minimum score for instance predictions to be shown",
    )
    parser.add_argument(
        "--base-cat-threshold",
        type=float,
        default=0.9,
        help="Minimum score for similarity with trained base categories",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser


def test_opencv_video_format(codec, file_ext):
    with tempfile.TemporaryDirectory(prefix="video_format_test") as dir:
        filename = os.path.join(dir, "test_file" + file_ext)
        writer = cv2.VideoWriter(
            filename=filename,
            fourcc=cv2.VideoWriter_fourcc(*codec),
            fps=float(30),
            frameSize=(10, 10),
            isColor=True,
        )
        [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)]
        writer.release()
        if os.path.isfile(filename):
            return True
        return False


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    args = get_parser().parse_args()
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)
    print(_highlight(cfg.dump(), ".yaml"))

    demo = VisualizationDemo(cfg, args)

    if args.input:
        if len(args.input) == 1:
            args.input = glob.glob(os.path.expanduser(args.input[0]))
            assert args.input, "The input path(s) was not found"
        for path in tqdm.tqdm(args.input, disable=not args.output):
            img = read_image(path, format="BGR")
            start_time = time.time()
            predictions, visualized_output = demo.run_on_image(img)
            logger.info(
                "{}: {} in {:.2f}s".format(
                    path,
                    "detected {} instances".format(len(predictions["instances"]))
                    if "instances" in predictions
                    else "finished",
                    time.time() - start_time,
                )
            )

            if args.output:
                if os.path.isdir(args.output):
                    assert os.path.isdir(args.output), args.output
                    out_filename = os.path.join(args.output, os.path.basename(path))
                else:
                    assert len(args.input) == 1, "Please specify a directory with args.output"
                    out_filename = args.output
                visualized_output.save(out_filename)
            else:
                cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
                cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
                if cv2.waitKey(0) == 27:
                    break  # esc to quit
    elif args.webcam:
        assert args.input is None, "Cannot have both --input and --webcam!"
        assert args.output is None, "output not yet supported with --webcam!"
        if args.webcam == "screen":
            cam = ScreenGrab()
        else:
            cam = cv2.VideoCapture(int(args.webcam))
        for vis in tqdm.tqdm(demo.run_on_video(cam)):
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
            cv2.imshow(WINDOW_NAME, vis)
            if cv2.waitKey(1) == 27:
                break  # esc to quit
        cam.release()
        cv2.destroyAllWindows()
    elif args.video_input:
        video = cv2.VideoCapture(args.video_input)
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames_per_second = video.get(cv2.CAP_PROP_FPS)
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        basename = os.path.basename(args.video_input)
        codec, file_ext = (
            ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
        )
        if codec == ".mp4v":
            warnings.warn("x264 codec not available, switching to mp4v")
        if args.output:
            if os.path.isdir(args.output):
                output_fname = os.path.join(args.output, basename)
                output_fname = os.path.splitext(output_fname)[0] + file_ext
            else:
                output_fname = args.output
            assert not os.path.isfile(output_fname), output_fname
            output_file = cv2.VideoWriter(
                filename=output_fname,
                # some installation of opencv may not support x264 (due to its license),
                # you can try other format (e.g. MPEG)
                fourcc=cv2.VideoWriter_fourcc(*codec),
                fps=float(frames_per_second),
                frameSize=(width, height),
                isColor=True,
            )
        assert os.path.isfile(args.video_input)
        for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
            if args.output:
                output_file.write(vis_frame)
            else:
                cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
                cv2.imshow(basename, vis_frame)
                if cv2.waitKey(1) == 27:
                    break  # esc to quit
        video.release()
        if args.output:
            output_file.release()
        else:
            cv2.destroyAllWindows()
packages.txt ADDED
@@ -0,0 +1,3 @@
ffmpeg
libsm6
libxext6
proxydet/__init__.py ADDED
@@ -0,0 +1,18 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from .modeling.meta_arch import custom_rcnn
from .modeling.roi_heads import proxydet_roi_heads
from .modeling.backbone import swintransformer
from .modeling.backbone import timm


from .data.datasets import lvis_v1
from .data.datasets import imagenet
from .data.datasets import cc
from .data.datasets import objects365
from .data.datasets import oid
from .data.datasets import coco_zeroshot

try:
    from .modeling.meta_arch import d2_deformable_detr
except:
    pass
proxydet/cat_names.py ADDED
@@ -0,0 +1 @@
+ lvis_cat_names = ['aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', 'antenna', 'apple', 'applesauce', 'apricot', 'apron', 'aquarium', 'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor', 'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy', 'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap', 'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath', 'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card', 'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket', 'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry', 'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase', 'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle', 'bottle_opener', 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie', 'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'box', 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'bread', 'breechcloth', 'bridal_gown', 'briefcase', 'broccoli', 'broach', 'broom', 'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull', 'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board', 'bulletproof_vest', 'bullhorn', 'bun', 'bunk_bed', 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf', 'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)', 'can', 'can_opener', 'candle', 'candle_holder', 'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast', 'cat', 'cauliflower', 'cayenne_(spice)', 'CD_player', 'celery', 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue', 'chalice', 'chandelier', 'chap', 'checkbook', 'checkerboard', 'cherry', 'chessboard', 'chicken_(animal)', 'chickpea', 'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker', 'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent', 'cleat_(for_securing_rope)', 'clementine', 'clip', 'clipboard', 'clippers_(for_plants)', 'cloak', 'clock', 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat', 'coat_hanger', 
'coatrack', 'cock', 'cockroach', 'cocoa_(beverage)', 'coconut', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil', 'coin', 'colander', 'coleslaw', 'coloring_material', 'combination_lock', 'pacifier', 'comic_book', 'compass', 'computer_keyboard', 'condiment', 'cone', 'control', 'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie', 'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)', 'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall', 'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker', 'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib', 'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown', 'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain', 'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard', 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk', 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher', 'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup', 'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin', 'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly', 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit', 'dresser', 'drill', 'drone', 'dropper', 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling', 'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'eagle', 'earphone', 'earplug', 'earring', 'easel', 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk', 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan', 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm', 'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace', 'fireplug', 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', 'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flap', 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal', 'folding_chair', 'food_processor', 'football_(American)', 'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car', 'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice', 'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage', 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic', 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'generator', 'giant_panda', 'gift_wrap', 'ginger', 'giraffe', 'cincture', 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles', 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'grape', 'grater', 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle', 'grill', 'grits', 'grizzly', 'grocery_bag', 'guitar', 'gull', 'gun', 'hairbrush', 'hairnet', 'hairpin', 'halter_top', 'ham', 'hamburger', 'hammer', 'hammock', 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel', 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw', 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil', 'headband', 'headboard', 'headlight', 'headscarf', 'headset', 'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog', 'home_plate_(baseball)', 
'honey', 'fume_hood', 'hook', 'hookah', 'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce', 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear', 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate', 'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit', 'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor', 'lizard', 'log', 'lollipop', 'speaker_(stero_equipment)', 'loveseat', 'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini', 'mascot', 'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup', 'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone', 'microscope', 'microwave_oven', 'milestone', 'milk', 'milk_can', 'milkshake', 'minivan', 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money', 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor', 'motor_scooter', 'motor_vehicle', 'motorcycle', 'mound_(baseball)', 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom', 'music_stool', 'musical_instrument', 'nailfile', 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newspaper', 'newsstand', 'nightshirt', 'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'ostrich', 'ottoman', 'oven', 'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle', 'padlock', 'paintbrush', 'painting', 'pajamas', 'palette', 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose', 'papaya', 'paper_plate', 'paper_towel', 'paperback_book', 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', 'parasol', 'parchment', 'parka', 'parking_meter', 'parrot', 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport', 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter', 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg', 'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet', 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano', 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow', 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball', 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)', 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat', 'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)', 'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)', 'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)', 'postbox_(public)', 
'postcard', 'poster', 'pot', 'flowerpot', 'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'pretzel', 'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune', 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish', 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat', 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt', 'recliner', 'record_player', 'reflector', 'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map', 'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade', 'rolling_pin', 'root_beer', 'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)', 'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin', 'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver', 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane', 'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass', 'shoulder_bag', 'shovel', 'shower_head', 'shower_cap', 'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink', 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole', 'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)', 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman', 'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball', 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon', 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)', 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish', 'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)', 'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish', 'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel', 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer', 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove', 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry', 'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer', 'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', 'sunglasses', 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)', 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure', 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup', 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth', 'telephone_pole', 'telephoto_lens', 'television_camera', 'television_set', 'tennis_ball', 'tennis_racket', 'tequila', 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread', 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven', 'toilet', 'toilet_tissue', 
'tomato', 'tongs', 'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray', 'trench_coat', 'triangle_(musical_instrument)', 'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', 'turban', 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest', 'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe', 'washbasin', 'automatic_washer', 'watch', 'water_bottle', 'water_cooler', 'water_faucet', 'water_heater', 'water_jug', 'water_gun', 'water_scooter', 'water_ski', 'water_tower', 'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake', 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream', 'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)', 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket', 'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon', 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt', 'yoke_(animal_equipment)', 'zebra', 'zucchini']
proxydet/config.py ADDED
@@ -0,0 +1,156 @@
# Copyright (c) Facebook, Inc. and its affiliates.
'''
Modifications Copyright (c) 2024-present NAVER Corp, Apache License v2.0
original source: https://github.com/facebookresearch/Detic/blob/main/detic/config.py
'''
from detectron2.config import CfgNode as CN

def add_proxydet_config(cfg):
    _C = cfg

    _C.WITH_IMAGE_LABELS = False  # Turn on co-training with classification data

    # Open-vocabulary classifier
    _C.MODEL.ROI_BOX_HEAD.USE_ZEROSHOT_CLS = False  # Use fixed classifier for open-vocabulary detection
    _C.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'datasets/metadata/lvis_v1_clip_a+cname.npy'
    _C.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_DIM = 512
    _C.MODEL.ROI_BOX_HEAD.NORM_WEIGHT = True
    _C.MODEL.ROI_BOX_HEAD.NORM_TEMP = 50.0
    _C.MODEL.ROI_BOX_HEAD.IGNORE_ZERO_CATS = False
    _C.MODEL.ROI_BOX_HEAD.USE_BIAS = 0.0  # >= 0: not use

    _C.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False  # CenterNet2
    _C.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE = False
    _C.MODEL.ROI_BOX_HEAD.PRIOR_PROB = 0.01
    _C.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False  # Federated Loss
    _C.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = \
        'datasets/metadata/lvis_v1_train_cat_info.json'
    _C.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT = 50
    _C.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT = 0.5

    # Classification data configs
    _C.MODEL.ROI_BOX_HEAD.IMAGE_LABEL_LOSS = 'max_size'  # max, softmax, sum
    _C.MODEL.ROI_BOX_HEAD.IMAGE_LOSS_WEIGHT = 0.1
    _C.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE = 1.0
    _C.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX = False  # Used for image-box loss and caption loss
    _C.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS = 128  # num proposals for image-labeled data
    _C.MODEL.ROI_BOX_HEAD.WITH_SOFTMAX_PROP = False  # Used for WSDDN
    _C.MODEL.ROI_BOX_HEAD.CAPTION_WEIGHT = 1.0  # Caption loss weight
    _C.MODEL.ROI_BOX_HEAD.NEG_CAP_WEIGHT = 0.125  # Caption loss hyper-parameter
    _C.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP = False  # Used for WSDDN
    _C.MODEL.ROI_BOX_HEAD.SOFTMAX_WEAK_LOSS = False  # Used when USE_SIGMOID_CE is False

    _C.MODEL.ROI_HEADS.MASK_WEIGHT = 1.0
    _C.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = False  # For demo only

    # Class-wise Multi-Modal Mixup
    _C.MODEL.ROI_BOX_HEAD.USE_REGIONAL_EMBEDDING = False
    _C.MODEL.ROI_HEADS.BASE_CAT_MASK = "datasets/metadata/lvis_v1_base_cat_mask.npy"
    _C.MODEL.ROI_HEADS.CMM = CN()
    _C.MODEL.ROI_HEADS.CMM.MIXUP_STAGE = []
    _C.MODEL.ROI_HEADS.CMM.MIXUP_STAGE_TEST = None
    _C.MODEL.ROI_HEADS.CMM.MIXUP_BETA = 1.0
    _C.MODEL.ROI_HEADS.CMM.LOSS = "l1"
    _C.MODEL.ROI_HEADS.CMM.LOSS_WEIGHT = 1.0
    _C.MODEL.ROI_HEADS.CMM.SEPARATED_BRANCH = False
    _C.MODEL.ROI_HEADS.CMM.BASE_ALPHA = 0.5
    _C.MODEL.ROI_HEADS.CMM.NOVEL_BETA = 0.5
    _C.MODEL.ROI_HEADS.CMM.USE_INL = False
    _C.MODEL.ROI_HEADS.CMM.PROTOTYPE = "center"
    _C.MODEL.ROI_HEADS.CMM.PROTOTYPE_TEMP = 1.0
    _C.MODEL.ROI_HEADS.CMM.CLASSIFIER_TEMP = None
    _C.MODEL.ROI_HEADS.CMM.USE_SIGMOID_CE = True

    # Caption losses
    _C.MODEL.CAP_BATCH_RATIO = 4  # Ratio between detection data and caption data
    _C.MODEL.WITH_CAPTION = False
    _C.MODEL.SYNC_CAPTION_BATCH = False  # synchronize across GPUs to enlarge # "classes"

    # dynamic class sampling when training with 21K classes
    _C.MODEL.DYNAMIC_CLASSIFIER = False
    _C.MODEL.NUM_SAMPLE_CATS = 50

    # Different classifiers in testing, used in cross-dataset evaluation
    _C.MODEL.RESET_CLS_TESTS = False
    _C.MODEL.TEST_CLASSIFIERS = []
    _C.MODEL.TEST_NUM_CLASSES = []

    # Backbones
    _C.MODEL.SWIN = CN()
    _C.MODEL.SWIN.SIZE = 'T'  # 'T', 'S', 'B'
    _C.MODEL.SWIN.USE_CHECKPOINT = False
    _C.MODEL.SWIN.OUT_FEATURES = (1, 2, 3)  # FPN stride 8 - 32

    _C.MODEL.TIMM = CN()
    _C.MODEL.TIMM.BASE_NAME = 'resnet50'
    _C.MODEL.TIMM.OUT_LEVELS = (3, 4, 5)
    _C.MODEL.TIMM.NORM = 'FrozenBN'
    _C.MODEL.TIMM.FREEZE_AT = 0
    _C.MODEL.TIMM.PRETRAINED = False
    _C.MODEL.DATASET_LOSS_WEIGHT = []

    # Multi-dataset dataloader
    _C.DATALOADER.DATASET_RATIO = [1, 1]  # sample ratio
    _C.DATALOADER.USE_RFS = [False, False]
    _C.DATALOADER.MULTI_DATASET_GROUPING = False  # Always true when multi-dataset is enabled
    _C.DATALOADER.DATASET_ANN = ['box', 'box']  # Annotation type of each dataset
    _C.DATALOADER.USE_DIFF_BS_SIZE = False  # Use different batchsize for each dataset
    _C.DATALOADER.DATASET_BS = [8, 32]  # Used when USE_DIFF_BS_SIZE is on
    _C.DATALOADER.DATASET_INPUT_SIZE = [896, 384]  # Used when USE_DIFF_BS_SIZE is on
    _C.DATALOADER.DATASET_INPUT_SCALE = [(0.1, 2.0), (0.5, 1.5)]  # Used when USE_DIFF_BS_SIZE is on
    _C.DATALOADER.DATASET_MIN_SIZES = [(640, 800), (320, 400)]  # Used when USE_DIFF_BS_SIZE is on
    _C.DATALOADER.DATASET_MAX_SIZES = [1333, 667]  # Used when USE_DIFF_BS_SIZE is on
    _C.DATALOADER.USE_TAR_DATASET = False  # for ImageNet-21K, directly reading from unziped files
    _C.DATALOADER.TARFILE_PATH = 'datasets/imagenet/metadata-22k/tar_files.npy'
    _C.DATALOADER.TAR_INDEX_DIR = 'datasets/imagenet/metadata-22k/tarindex_npy'

    _C.SOLVER.USE_CUSTOM_SOLVER = False
    _C.SOLVER.OPTIMIZER = 'SGD'
    _C.SOLVER.BACKBONE_MULTIPLIER = 1.0  # Used in DETR
    _C.SOLVER.CUSTOM_MULTIPLIER = 1.0  # Used in DETR
    _C.SOLVER.CUSTOM_MULTIPLIER_NAME = []  # Used in DETR

    # Deformable DETR
    _C.MODEL.DETR = CN()
    _C.MODEL.DETR.NUM_CLASSES = 80
    _C.MODEL.DETR.FROZEN_WEIGHTS = ''  # For Segmentation
    _C.MODEL.DETR.GIOU_WEIGHT = 2.0
    _C.MODEL.DETR.L1_WEIGHT = 5.0
    _C.MODEL.DETR.DEEP_SUPERVISION = True
    _C.MODEL.DETR.NO_OBJECT_WEIGHT = 0.1
    _C.MODEL.DETR.CLS_WEIGHT = 2.0
    _C.MODEL.DETR.NUM_FEATURE_LEVELS = 4
    _C.MODEL.DETR.TWO_STAGE = False
    _C.MODEL.DETR.WITH_BOX_REFINE = False
    _C.MODEL.DETR.FOCAL_ALPHA = 0.25
    _C.MODEL.DETR.NHEADS = 8
    _C.MODEL.DETR.DROPOUT = 0.1
    _C.MODEL.DETR.DIM_FEEDFORWARD = 2048
    _C.MODEL.DETR.ENC_LAYERS = 6
    _C.MODEL.DETR.DEC_LAYERS = 6
    _C.MODEL.DETR.PRE_NORM = False
    _C.MODEL.DETR.HIDDEN_DIM = 256
    _C.MODEL.DETR.NUM_OBJECT_QUERIES = 100

    _C.MODEL.DETR.USE_FED_LOSS = False
    _C.MODEL.DETR.WEAK_WEIGHT = 0.1

    _C.INPUT.CUSTOM_AUG = ''
    _C.INPUT.TRAIN_SIZE = 640
    _C.INPUT.TEST_SIZE = 640
    _C.INPUT.SCALE_RANGE = (0.1, 2.)
    # 'default' for fixed short/ long edge, 'square' for max size=INPUT.SIZE
    _C.INPUT.TEST_INPUT_TYPE = 'default'

    _C.FIND_UNUSED_PARAM = True
    _C.EVAL_PRED_AR = False
    _C.EVAL_PROPOSAL_AR = False
    _C.EVAL_CAT_SPEC_AR = False
    _C.IS_DEBUG = False
    _C.QUICK_DEBUG = False
    _C.FP16 = False
    _C.EVAL_AP_FIX = False
    _C.GEN_PSEDO_LABELS = False
    _C.SAVE_DEBUG_PATH = 'output/save_debug/'

    _C.EVAL_START = 0
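Note how the CMM mixup weights default here to BASE_ALPHA/NOVEL_BETA of 0.5, are overridden to 0.15/0.35 in the ProxyDet configs, and are overridden again at runtime by the demo sliders. The snippet below is illustrative only and assumes the roi_heads attribute names used by app.py in this commit:

    # Mirror the slider override from app.py on a built VisualizationDemo (hypothetical usage).
    roi_heads = ovd_demo.predictor.model.roi_heads
    roi_heads.cmm_base_alpha = cfg.MODEL.ROI_HEADS.CMM.BASE_ALPHA  # 0.15 in the ProxyDet configs
    roi_heads.cmm_novel_beta = cfg.MODEL.ROI_HEADS.CMM.NOVEL_BETA  # 0.35 in the ProxyDet configs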
proxydet/custom_solver.py ADDED
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from enum import Enum
import itertools
from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union
import torch

from detectron2.config import CfgNode

from detectron2.solver.build import maybe_add_gradient_clipping

def match_name_keywords(n, name_keywords):
    out = False
    for b in name_keywords:
        if b in n:
            out = True
            break
    return out

def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an optimizer from config.
    """
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    custom_multiplier_name = cfg.SOLVER.CUSTOM_MULTIPLIER_NAME
    optimizer_type = cfg.SOLVER.OPTIMIZER
    for key, value in model.named_parameters(recurse=True):
        if not value.requires_grad:
            continue
        # Avoid duplicating parameters
        if value in memo:
            continue
        memo.add(value)
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "backbone" in key:
            lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER
        if match_name_keywords(key, custom_multiplier_name):
            lr = lr * cfg.SOLVER.CUSTOM_MULTIPLIER
            print('Costum LR', key, lr)
        param = {"params": [value], "lr": lr}
        if optimizer_type != 'ADAMW':
            param['weight_decay'] = weight_decay
        params += [param]

    def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
        # detectron2 doesn't have full model gradient clipping now
        clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
        enable = (
            cfg.SOLVER.CLIP_GRADIENTS.ENABLED
            and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
            and clip_norm_val > 0.0
        )

        class FullModelGradientClippingOptimizer(optim):
            def step(self, closure=None):
                all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                super().step(closure=closure)

        return FullModelGradientClippingOptimizer if enable else optim


    if optimizer_type == 'SGD':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
            params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV
        )
    elif optimizer_type == 'ADAMW':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
            params, cfg.SOLVER.BASE_LR,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY
        )
    else:
        raise NotImplementedError(f"no optimizer type {optimizer_type}")
    if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
        optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer
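For context (not part of the commit), build_custom_optimizer() replaces detectron2's default optimizer builder when SOLVER.USE_CUSTOM_SOLVER is on. A minimal sketch, assuming cfg was assembled as in demo.setup_cfg():

    from detectron2.modeling import build_model
    from proxydet.custom_solver import build_custom_optimizer

    model = build_model(cfg)                        # cfg.MODEL.* defines backbone and heads
    optimizer = build_custom_optimizer(cfg, model)  # SGD or AdamW per cfg.SOLVER.OPTIMIZER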
proxydet/data/custom_build_augmentation.py ADDED
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import numpy as np
import pycocotools.mask as mask_util
import torch
from fvcore.common.file_io import PathManager
from PIL import Image


from detectron2.data import transforms as T
from .transforms.custom_augmentation_impl import EfficientDetResizeCrop

def build_custom_augmentation(cfg, is_train, scale=None, size=None, \
    min_size=None, max_size=None):
    """
    Create a list of default :class:`Augmentation` from config.
    Now it includes resizing and flipping.

    Returns:
        list[Augmentation]
    """
    if cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge':
        if is_train:
            min_size = cfg.INPUT.MIN_SIZE_TRAIN if min_size is None else min_size
            max_size = cfg.INPUT.MAX_SIZE_TRAIN if max_size is None else max_size
            sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
        else:
            min_size = cfg.INPUT.MIN_SIZE_TEST
            max_size = cfg.INPUT.MAX_SIZE_TEST
            sample_style = "choice"
        augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
    elif cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop':
        if is_train:
            scale = cfg.INPUT.SCALE_RANGE if scale is None else scale
            size = cfg.INPUT.TRAIN_SIZE if size is None else size
        else:
            scale = (1, 1)
            size = cfg.INPUT.TEST_SIZE
        augmentation = [EfficientDetResizeCrop(size, scale)]
    else:
        assert 0, cfg.INPUT.CUSTOM_AUG

    if is_train:
        augmentation.append(T.RandomFlip())
    return augmentation


build_custom_transform_gen = build_custom_augmentation
"""
Alias for backward-compatibility.
"""
proxydet/data/custom_dataset_dataloader.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/multi_dataset_dataloader.py (Apache-2.0 License)
3
+ import copy
4
+ import logging
5
+ import numpy as np
6
+ import operator
7
+ import torch
8
+ import torch.utils.data
9
+ import json
10
+ from detectron2.utils.comm import get_world_size
11
+ from detectron2.utils.logger import _log_api_usage, log_first_n
12
+
13
+ from detectron2.config import configurable
14
+ from detectron2.data import samplers
15
+ from torch.utils.data.sampler import BatchSampler, Sampler
16
+ from detectron2.data.common import DatasetFromList, MapDataset
17
+ from detectron2.data.dataset_mapper import DatasetMapper
18
+ from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader
19
+ from detectron2.data.samplers import TrainingSampler, RepeatFactorTrainingSampler
20
+ from detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram
21
+ from detectron2.data.build import filter_images_with_only_crowd_annotations
22
+ from detectron2.data.build import filter_images_with_few_keypoints
23
+ from detectron2.data.build import check_metadata_consistency
24
+ from detectron2.data.catalog import MetadataCatalog, DatasetCatalog
25
+ from detectron2.utils import comm
26
+ import itertools
27
+ import math
28
+ from collections import defaultdict
29
+ from typing import Optional
30
+
31
+
32
+ def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
33
+ sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
34
+ if 'MultiDataset' in sampler_name:
35
+ dataset_dicts = get_detection_dataset_dicts_with_source(
36
+ cfg.DATASETS.TRAIN,
37
+ filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
38
+ min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
39
+ if cfg.MODEL.KEYPOINT_ON else 0,
40
+ proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
41
+ )
42
+ else:
43
+ dataset_dicts = get_detection_dataset_dicts(
44
+ cfg.DATASETS.TRAIN,
45
+ filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
46
+ min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
47
+ if cfg.MODEL.KEYPOINT_ON else 0,
48
+ proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
49
+ )
50
+
51
+ if mapper is None:
52
+ mapper = DatasetMapper(cfg, True)
53
+
54
+ if sampler is not None:
55
+ pass
56
+ elif sampler_name == "TrainingSampler":
57
+ sampler = TrainingSampler(len(dataset))
58
+ elif sampler_name == "MultiDatasetSampler":
59
+ sampler = MultiDatasetSampler(
60
+ dataset_dicts,
61
+ dataset_ratio = cfg.DATALOADER.DATASET_RATIO,
62
+ use_rfs = cfg.DATALOADER.USE_RFS,
63
+ dataset_ann = cfg.DATALOADER.DATASET_ANN,
64
+ repeat_threshold = cfg.DATALOADER.REPEAT_THRESHOLD,
65
+ )
66
+ elif sampler_name == "RepeatFactorTrainingSampler":
67
+ repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
68
+ dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
69
+ )
70
+ sampler = RepeatFactorTrainingSampler(repeat_factors)
71
+ else:
72
+ raise ValueError("Unknown training sampler: {}".format(sampler_name))
73
+
74
+ return {
75
+ "dataset": dataset_dicts,
76
+ "sampler": sampler,
77
+ "mapper": mapper,
78
+ "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
79
+ "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
80
+ "num_workers": cfg.DATALOADER.NUM_WORKERS,
81
+ 'multi_dataset_grouping': cfg.DATALOADER.MULTI_DATASET_GROUPING,
82
+ 'use_diff_bs_size': cfg.DATALOADER.USE_DIFF_BS_SIZE,
83
+ 'dataset_bs': cfg.DATALOADER.DATASET_BS,
84
+ 'num_datasets': len(cfg.DATASETS.TRAIN)
85
+ }
86
+
87
+
88
+ @configurable(from_config=_custom_train_loader_from_config)
89
+ def build_custom_train_loader(
90
+ dataset, *, mapper, sampler,
91
+ total_batch_size=16,
92
+ aspect_ratio_grouping=True,
93
+ num_workers=0,
94
+ num_datasets=1,
95
+ multi_dataset_grouping=False,
96
+ use_diff_bs_size=False,
97
+ dataset_bs=[]
98
+ ):
99
+ """
100
+ Modified from detectron2.data.build.build_detection_train_loader, but supports
101
+ different samplers
102
+ """
103
+ if isinstance(dataset, list):
104
+ dataset = DatasetFromList(dataset, copy=False)
105
+ if mapper is not None:
106
+ dataset = MapDataset(dataset, mapper)
107
+ if sampler is None:
108
+ sampler = TrainingSampler(len(dataset))
109
+ assert isinstance(sampler, torch.utils.data.sampler.Sampler)
110
+ if multi_dataset_grouping:
111
+ return build_multi_dataset_batch_data_loader(
112
+ use_diff_bs_size,
113
+ dataset_bs,
114
+ dataset,
115
+ sampler,
116
+ total_batch_size,
117
+ num_datasets=num_datasets,
118
+ num_workers=num_workers,
119
+ )
120
+ else:
121
+ return build_batch_data_loader(
122
+ dataset,
123
+ sampler,
124
+ total_batch_size,
125
+ aspect_ratio_grouping=aspect_ratio_grouping,
126
+ num_workers=num_workers,
127
+ )
128
+
129
+
130
+ def build_multi_dataset_batch_data_loader(
131
+ use_diff_bs_size, dataset_bs,
132
+ dataset, sampler, total_batch_size, num_datasets, num_workers=0
133
+ ):
134
+ """Build a per-GPU dataloader whose batches are grouped by dataset source and aspect ratio.
135
+ """
136
+ world_size = get_world_size()
137
+ assert (
138
+ total_batch_size > 0 and total_batch_size % world_size == 0
139
+ ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
140
+ total_batch_size, world_size
141
+ )
142
+
143
+ batch_size = total_batch_size // world_size
144
+ data_loader = torch.utils.data.DataLoader(
145
+ dataset,
146
+ sampler=sampler,
147
+ num_workers=num_workers,
148
+ batch_sampler=None,
149
+ collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements
150
+ worker_init_fn=worker_init_reset_seed,
151
+ ) # yield individual mapped dict
152
+ if use_diff_bs_size:
153
+ return DIFFMDAspectRatioGroupedDataset(
154
+ data_loader, dataset_bs, num_datasets)
155
+ else:
156
+ return MDAspectRatioGroupedDataset(
157
+ data_loader, batch_size, num_datasets)
158
+
159
+
160
+ def get_detection_dataset_dicts_with_source(
161
+ dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
162
+ ):
163
+ assert len(dataset_names)
164
+ dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
165
+ for dataset_name, dicts in zip(dataset_names, dataset_dicts):
166
+ assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
167
+
168
+ for source_id, (dataset_name, dicts) in \
169
+ enumerate(zip(dataset_names, dataset_dicts)):
170
+ assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
171
+ for d in dicts:
172
+ d['dataset_source'] = source_id
173
+
174
+ if "annotations" in dicts[0]:
175
+ try:
176
+ class_names = MetadataCatalog.get(dataset_name).thing_classes
177
+ check_metadata_consistency("thing_classes", dataset_name)
178
+ print_instances_class_histogram(dicts, class_names)
179
+ except AttributeError: # class names are not available for this dataset
180
+ pass
181
+
182
+ assert proposal_files is None
183
+
184
+ dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
185
+
186
+ has_instances = "annotations" in dataset_dicts[0]
187
+ if filter_empty and has_instances:
188
+ dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
189
+ if min_keypoints > 0 and has_instances:
190
+ dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
191
+
192
+ return dataset_dicts
193
+
194
+
195
+ class MultiDatasetSampler(Sampler):
196
+ def __init__(
197
+ self,
198
+ dataset_dicts,
199
+ dataset_ratio,
200
+ use_rfs,
201
+ dataset_ann,
202
+ repeat_threshold=0.001,
203
+ seed: Optional[int] = None,
204
+ ):
205
+ """Infinite sampler that draws indices using per-dataset ratios and optional repeat-factor (RFS) re-weighting.
206
+ """
207
+ sizes = [0 for _ in range(len(dataset_ratio))]
208
+ for d in dataset_dicts:
209
+ sizes[d['dataset_source']] += 1
210
+ print('dataset sizes', sizes)
211
+ self.sizes = sizes
212
+ assert len(dataset_ratio) == len(sizes), \
213
+ 'length of dataset ratio {} should be equal to the number of datasets {}'.format(
214
+ len(dataset_ratio), len(sizes)
215
+ )
216
+ if seed is None:
217
+ seed = comm.shared_random_seed()
218
+ self._seed = int(seed)
219
+ self._rank = comm.get_rank()
220
+ self._world_size = comm.get_world_size()
221
+
222
+ self.dataset_ids = torch.tensor(
223
+ [d['dataset_source'] for d in dataset_dicts], dtype=torch.long)
224
+
225
+ dataset_weight = [torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) \
226
+ for i, (r, s) in enumerate(zip(dataset_ratio, sizes))]
227
+ dataset_weight = torch.cat(dataset_weight)
228
+
229
+ rfs_factors = []
230
+ st = 0
231
+ for i, s in enumerate(sizes):
232
+ if use_rfs[i]:
233
+ if dataset_ann[i] == 'box':
234
+ rfs_func = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency
235
+ else:
236
+ rfs_func = repeat_factors_from_tag_frequency
237
+ rfs_factor = rfs_func(
238
+ dataset_dicts[st: st + s],
239
+ repeat_thresh=repeat_threshold)
240
+ rfs_factor = rfs_factor * (s / rfs_factor.sum())
241
+ else:
242
+ rfs_factor = torch.ones(s)
243
+ rfs_factors.append(rfs_factor)
244
+ st = st + s
245
+ rfs_factors = torch.cat(rfs_factors)
246
+
247
+ self.weights = dataset_weight * rfs_factors
248
+ self.sample_epoch_size = len(self.weights)
249
+
250
+ def __iter__(self):
251
+ start = self._rank
252
+ yield from itertools.islice(
253
+ self._infinite_indices(), start, None, self._world_size)
254
+
255
+
256
+ def _infinite_indices(self):
257
+ g = torch.Generator()
258
+ g.manual_seed(self._seed)
259
+ while True:
260
+ ids = torch.multinomial(
261
+ self.weights, self.sample_epoch_size, generator=g,
262
+ replacement=True)
263
+ nums = [(self.dataset_ids[ids] == i).sum().int().item() \
264
+ for i in range(len(self.sizes))]
265
+ yield from ids
266
+
267
+
268
+ class MDAspectRatioGroupedDataset(torch.utils.data.IterableDataset):
269
+ def __init__(self, dataset, batch_size, num_datasets):
270
+ """Batch consecutive samples that share a dataset source and aspect-ratio bucket.
271
+ """
272
+ self.dataset = dataset
273
+ self.batch_size = batch_size
274
+ self._buckets = [[] for _ in range(2 * num_datasets)]
275
+
276
+ def __iter__(self):
277
+ for d in self.dataset:
278
+ w, h = d["width"], d["height"]
279
+ aspect_ratio_bucket_id = 0 if w > h else 1
280
+ bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id
281
+ bucket = self._buckets[bucket_id]
282
+ bucket.append(d)
283
+ if len(bucket) == self.batch_size:
284
+ yield bucket[:]
285
+ del bucket[:]
286
+
287
+
288
+ class DIFFMDAspectRatioGroupedDataset(torch.utils.data.IterableDataset):
289
+ def __init__(self, dataset, batch_sizes, num_datasets):
290
+ """Same as MDAspectRatioGroupedDataset, but with a per-dataset batch size.
291
+ """
292
+ self.dataset = dataset
293
+ self.batch_sizes = batch_sizes
294
+ self._buckets = [[] for _ in range(2 * num_datasets)]
295
+
296
+ def __iter__(self):
297
+ for d in self.dataset:
298
+ w, h = d["width"], d["height"]
299
+ aspect_ratio_bucket_id = 0 if w > h else 1
300
+ bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id
301
+ bucket = self._buckets[bucket_id]
302
+ bucket.append(d)
303
+ if len(bucket) == self.batch_sizes[d['dataset_source']]:
304
+ yield bucket[:]
305
+ del bucket[:]
306
+
307
+
308
+ def repeat_factors_from_tag_frequency(dataset_dicts, repeat_thresh):
309
+ """Compute LVIS-style repeat factors from image-level tag (pos_category_ids) frequencies.
310
+ """
311
+ category_freq = defaultdict(int)
312
+ for dataset_dict in dataset_dicts:
313
+ cat_ids = dataset_dict['pos_category_ids']
314
+ for cat_id in cat_ids:
315
+ category_freq[cat_id] += 1
316
+ num_images = len(dataset_dicts)
317
+ for k, v in category_freq.items():
318
+ category_freq[k] = v / num_images
319
+
320
+ category_rep = {
321
+ cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq))
322
+ for cat_id, cat_freq in category_freq.items()
323
+ }
324
+
325
+ rep_factors = []
326
+ for dataset_dict in dataset_dicts:
327
+ cat_ids = dataset_dict['pos_category_ids']
328
+ rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0)
329
+ rep_factors.append(rep_factor)
330
+
331
+ return torch.tensor(rep_factors, dtype=torch.float32)
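
A note on the sampler added above: each example is weighted by max(sizes) / size * ratio / sum(ratio) for its source dataset and then multiplied by LVIS-style repeat factors, so the expected share of each dataset follows DATALOADER.DATASET_RATIO regardless of raw dataset sizes. A minimal sketch of that weighting with made-up sizes and ratios (illustrative values, not taken from this commit's configs):

    import torch

    sizes = [100000, 1200]      # e.g. a box-labelled set vs. an image-labelled set (made up)
    dataset_ratio = [1.0, 4.0]  # plays the role of cfg.DATALOADER.DATASET_RATIO

    # same per-example weight as MultiDatasetSampler.__init__ (without the RFS term)
    weights = torch.cat([
        torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio)
        for r, s in zip(dataset_ratio, sizes)
    ])

    g = torch.Generator().manual_seed(0)
    ids = torch.multinomial(weights, len(weights), generator=g, replacement=True)
    frac = (ids >= sizes[0]).float().mean().item()
    print(f"share of samples drawn from the small dataset: {frac:.2f}")  # ~0.80, i.e. the 1:4 ratio
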
proxydet/data/custom_dataset_mapper.py ADDED
@@ -0,0 +1,280 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import copy
3
+ import logging
4
+ import numpy as np
5
+ from typing import List, Optional, Union
6
+ import torch
7
+ import pycocotools.mask as mask_util
8
+
9
+ from detectron2.config import configurable
10
+
11
+ from detectron2.data import detection_utils as utils
12
+ from detectron2.data.detection_utils import transform_keypoint_annotations
13
+ from detectron2.data import transforms as T
14
+ from detectron2.data.dataset_mapper import DatasetMapper
15
+ from detectron2.structures import Boxes, BoxMode, Instances
16
+ from detectron2.structures import Keypoints, PolygonMasks, BitMasks
17
+ from fvcore.transforms.transform import TransformList
18
+ from .custom_build_augmentation import build_custom_augmentation
19
+ from .tar_dataset import DiskTarDataset
20
+
21
+ __all__ = ["CustomDatasetMapper"]
22
+
23
+ class CustomDatasetMapper(DatasetMapper):
24
+ @configurable
25
+ def __init__(self, is_train: bool,
26
+ with_ann_type=False,
27
+ dataset_ann=[],
28
+ use_diff_bs_size=False,
29
+ dataset_augs=[],
30
+ is_debug=False,
31
+ use_tar_dataset=False,
32
+ tarfile_path='',
33
+ tar_index_dir='',
34
+ **kwargs):
35
+ """
36
+ add image labels
37
+ """
38
+ self.with_ann_type = with_ann_type
39
+ self.dataset_ann = dataset_ann
40
+ self.use_diff_bs_size = use_diff_bs_size
41
+ if self.use_diff_bs_size and is_train:
42
+ self.dataset_augs = [T.AugmentationList(x) for x in dataset_augs]
43
+ self.is_debug = is_debug
44
+ self.use_tar_dataset = use_tar_dataset
45
+ if self.use_tar_dataset:
46
+ print('Using tar dataset')
47
+ self.tar_dataset = DiskTarDataset(tarfile_path, tar_index_dir)
48
+ super().__init__(is_train, **kwargs)
49
+
50
+
51
+ @classmethod
52
+ def from_config(cls, cfg, is_train: bool = True):
53
+ ret = super().from_config(cfg, is_train)
54
+ ret.update({
55
+ 'with_ann_type': cfg.WITH_IMAGE_LABELS,
56
+ 'dataset_ann': cfg.DATALOADER.DATASET_ANN,
57
+ 'use_diff_bs_size': cfg.DATALOADER.USE_DIFF_BS_SIZE,
58
+ 'is_debug': cfg.IS_DEBUG,
59
+ 'use_tar_dataset': cfg.DATALOADER.USE_TAR_DATASET,
60
+ 'tarfile_path': cfg.DATALOADER.TARFILE_PATH,
61
+ 'tar_index_dir': cfg.DATALOADER.TAR_INDEX_DIR,
62
+ })
63
+ if ret['use_diff_bs_size'] and is_train:
64
+ if cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop':
65
+ dataset_scales = cfg.DATALOADER.DATASET_INPUT_SCALE
66
+ dataset_sizes = cfg.DATALOADER.DATASET_INPUT_SIZE
67
+ ret['dataset_augs'] = [
68
+ build_custom_augmentation(cfg, True, scale, size) \
69
+ for scale, size in zip(dataset_scales, dataset_sizes)]
70
+ else:
71
+ assert cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge'
72
+ min_sizes = cfg.DATALOADER.DATASET_MIN_SIZES
73
+ max_sizes = cfg.DATALOADER.DATASET_MAX_SIZES
74
+ ret['dataset_augs'] = [
75
+ build_custom_augmentation(
76
+ cfg, True, min_size=mi, max_size=ma) \
77
+ for mi, ma in zip(min_sizes, max_sizes)]
78
+ else:
79
+ ret['dataset_augs'] = []
80
+
81
+ return ret
82
+
83
+ def __call__(self, dataset_dict):
84
+ """
85
+ include image labels
86
+ """
87
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
88
+ # USER: Write your own image loading if it's not from a file
89
+ if 'file_name' in dataset_dict:
90
+ ori_image = utils.read_image(
91
+ dataset_dict["file_name"], format=self.image_format)
92
+ else:
93
+ ori_image, _, _ = self.tar_dataset[dataset_dict["tar_index"]]
94
+ ori_image = utils._apply_exif_orientation(ori_image)
95
+ ori_image = utils.convert_PIL_to_numpy(ori_image, self.image_format)
96
+ utils.check_image_size(dataset_dict, ori_image)
97
+
98
+ # USER: Remove if you don't do semantic/panoptic segmentation.
99
+ if "sem_seg_file_name" in dataset_dict:
100
+ sem_seg_gt = utils.read_image(
101
+ dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
102
+ else:
103
+ sem_seg_gt = None
104
+
105
+ if self.is_debug:
106
+ dataset_dict['dataset_source'] = 0
107
+
108
+ not_full_labeled = 'dataset_source' in dataset_dict and \
109
+ self.with_ann_type and \
110
+ self.dataset_ann[dataset_dict['dataset_source']] != 'box'
111
+
112
+ aug_input = T.AugInput(copy.deepcopy(ori_image), sem_seg=sem_seg_gt)
113
+ if self.use_diff_bs_size and self.is_train:
114
+ transforms = \
115
+ self.dataset_augs[dataset_dict['dataset_source']](aug_input)
116
+ else:
117
+ transforms = self.augmentations(aug_input)
118
+ image, sem_seg_gt = aug_input.image, aug_input.sem_seg
119
+
120
+ image_shape = image.shape[:2] # h, w
121
+ dataset_dict["image"] = torch.as_tensor(
122
+ np.ascontiguousarray(image.transpose(2, 0, 1)))
123
+
124
+ if sem_seg_gt is not None:
125
+ dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
126
+
127
+ # USER: Remove if you don't use pre-computed proposals.
128
+ # Most users would not need this feature.
129
+ if self.proposal_topk is not None:
130
+ utils.transform_proposals(
131
+ dataset_dict, image_shape, transforms,
132
+ proposal_topk=self.proposal_topk
133
+ )
134
+
135
+ if not self.is_train:
136
+ # USER: Modify this if you want to keep them for some reason.
137
+ dataset_dict.pop("annotations", None)
138
+ dataset_dict.pop("sem_seg_file_name", None)
139
+ return dataset_dict
140
+
141
+ if "annotations" in dataset_dict:
142
+ # USER: Modify this if you want to keep them for some reason.
143
+ for anno in dataset_dict["annotations"]:
144
+ if not self.use_instance_mask:
145
+ anno.pop("segmentation", None)
146
+ if not self.use_keypoint:
147
+ anno.pop("keypoints", None)
148
+
149
+ # USER: Implement additional transformations if you have other types of data
150
+ all_annos = [
151
+ (utils.transform_instance_annotations(
152
+ obj, transforms, image_shape,
153
+ keypoint_hflip_indices=self.keypoint_hflip_indices,
154
+ ), obj.get("iscrowd", 0))
155
+ for obj in dataset_dict.pop("annotations")
156
+ ]
157
+ annos = [ann[0] for ann in all_annos if ann[1] == 0]
158
+ instances = utils.annotations_to_instances(
159
+ annos, image_shape, mask_format=self.instance_mask_format
160
+ )
161
+
162
+ del all_annos
163
+ if self.recompute_boxes:
164
+ instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
165
+ dataset_dict["instances"] = utils.filter_empty_instances(instances)
166
+ if self.with_ann_type:
167
+ dataset_dict["pos_category_ids"] = dataset_dict.get(
168
+ 'pos_category_ids', [])
169
+ dataset_dict["ann_type"] = \
170
+ self.dataset_ann[dataset_dict['dataset_source']]
171
+ if self.is_debug and (('pos_category_ids' not in dataset_dict) or \
172
+ (dataset_dict['pos_category_ids'] == [])):
173
+ dataset_dict['pos_category_ids'] = [x for x in sorted(set(
174
+ dataset_dict['instances'].gt_classes.tolist()
175
+ ))]
176
+ return dataset_dict
177
+
178
+ # DETR augmentation
179
+ def build_transform_gen(cfg, is_train):
180
+ """Create the default DETR-style transform gens (random flip + resize) from config.
181
+ """
182
+ if is_train:
183
+ min_size = cfg.INPUT.MIN_SIZE_TRAIN
184
+ max_size = cfg.INPUT.MAX_SIZE_TRAIN
185
+ sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
186
+ else:
187
+ min_size = cfg.INPUT.MIN_SIZE_TEST
188
+ max_size = cfg.INPUT.MAX_SIZE_TEST
189
+ sample_style = "choice"
190
+ if sample_style == "range":
191
+ assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size))
192
+
193
+ logger = logging.getLogger(__name__)
194
+ tfm_gens = []
195
+ if is_train:
196
+ tfm_gens.append(T.RandomFlip())
197
+ tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
198
+ if is_train:
199
+ logger.info("TransformGens used in training: " + str(tfm_gens))
200
+ return tfm_gens
201
+
202
+
203
+ class DetrDatasetMapper:
204
+ """
205
+ A callable which takes a dataset dict in Detectron2 Dataset format,
206
+ and maps it into a format used by DETR.
207
+ The callable currently does the following:
208
+ 1. Reads the image from "file_name"
209
+ 2. Applies geometric transforms to the image and annotation
210
+ 3. Finds and applies suitable cropping to the image and annotation
211
+ 4. Prepares the image and annotation as Tensors
212
+ """
213
+
214
+ def __init__(self, cfg, is_train=True):
215
+ if cfg.INPUT.CROP.ENABLED and is_train:
216
+ self.crop_gen = [
217
+ T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
218
+ T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
219
+ ]
220
+ else:
221
+ self.crop_gen = None
222
+
223
+ self.mask_on = cfg.MODEL.MASK_ON
224
+ self.tfm_gens = build_transform_gen(cfg, is_train)
225
+ logging.getLogger(__name__).info(
226
+ "Full TransformGens used in training: {}, crop: {}".format(str(self.tfm_gens), str(self.crop_gen))
227
+ )
228
+
229
+ self.img_format = cfg.INPUT.FORMAT
230
+ self.is_train = is_train
231
+
232
+ def __call__(self, dataset_dict):
233
+ """
234
+ Args:
235
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
236
+ Returns:
237
+ dict: a format that builtin models in detectron2 accept
238
+ """
239
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
240
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
241
+ utils.check_image_size(dataset_dict, image)
242
+
243
+ if self.crop_gen is None:
244
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
245
+ else:
246
+ if np.random.rand() > 0.5:
247
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
248
+ else:
249
+ image, transforms = T.apply_transform_gens(
250
+ self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image
251
+ )
252
+
253
+ image_shape = image.shape[:2] # h, w
254
+
255
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
256
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
257
+ # Therefore it's important to use torch.Tensor.
258
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
259
+
260
+ if not self.is_train:
261
+ # USER: Modify this if you want to keep them for some reason.
262
+ dataset_dict.pop("annotations", None)
263
+ return dataset_dict
264
+
265
+ if "annotations" in dataset_dict:
266
+ # USER: Modify this if you want to keep them for some reason.
267
+ for anno in dataset_dict["annotations"]:
268
+ if not self.mask_on:
269
+ anno.pop("segmentation", None)
270
+ anno.pop("keypoints", None)
271
+
272
+ # USER: Implement additional transformations if you have other types of data
273
+ annos = [
274
+ utils.transform_instance_annotations(obj, transforms, image_shape)
275
+ for obj in dataset_dict.pop("annotations")
276
+ if obj.get("iscrowd", 0) == 0
277
+ ]
278
+ instances = utils.annotations_to_instances(annos, image_shape)
279
+ dataset_dict["instances"] = utils.filter_empty_instances(instances)
280
+ return dataset_dict
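
DetrDatasetMapper above applies the crop generators to a random half of the images by splicing them between the flip and the final resize. A small standalone sketch of that branch on a random image (assumes detectron2 is installed; the crop type and size below are illustrative stand-ins for cfg.INPUT.CROP.TYPE / cfg.INPUT.CROP.SIZE):

    import numpy as np
    from detectron2.data import transforms as T

    tfm_gens = [T.RandomFlip(), T.ResizeShortestEdge([800], 1333, "choice")]
    crop_gen = [
        T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
        T.RandomCrop("absolute_range", (384, 600)),
    ]

    image = (np.random.rand(480, 640, 3) * 255).astype("uint8")
    if np.random.rand() > 0.5:
        image, _ = T.apply_transform_gens(tfm_gens, image)
    else:
        # resize-then-crop is inserted just before the final ResizeShortestEdge
        image, _ = T.apply_transform_gens(tfm_gens[:-1] + crop_gen + tfm_gens[-1:], image)
    print(image.shape)
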
proxydet/data/datasets/cc.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ import os
4
+
5
+ from detectron2.data.datasets.builtin_meta import _get_builtin_metadata
6
+ from detectron2.data.datasets.lvis import get_lvis_instances_meta
7
+ from .lvis_v1 import custom_register_lvis_instances
8
+
9
+ _CUSTOM_SPLITS = {
10
+ "cc3m_v1_val": ("cc3m/validation/", "cc3m/val_image_info.json"),
11
+ "cc3m_v1_train": ("cc3m/training/", "cc3m/train_image_info.json"),
12
+ "cc3m_v1_train_tags": ("cc3m/training/", "cc3m/train_image_info_tags.json"),
13
+
14
+ }
15
+
16
+ for key, (image_root, json_file) in _CUSTOM_SPLITS.items():
17
+ custom_register_lvis_instances(
18
+ key,
19
+ get_lvis_instances_meta('lvis_v1'),
20
+ os.path.join("datasets", json_file) if "://" not in json_file else json_file,
21
+ os.path.join("datasets", image_root),
22
+ )
23
+
proxydet/data/datasets/coco_zeroshot.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import os
3
+
4
+ from detectron2.data.datasets.register_coco import register_coco_instances
5
+ from detectron2.data.datasets.builtin_meta import _get_builtin_metadata
6
+ from .lvis_v1 import custom_register_lvis_instances
7
+
8
+ categories_seen = [
9
+ {'id': 1, 'name': 'person'},
10
+ {'id': 2, 'name': 'bicycle'},
11
+ {'id': 3, 'name': 'car'},
12
+ {'id': 4, 'name': 'motorcycle'},
13
+ {'id': 7, 'name': 'train'},
14
+ {'id': 8, 'name': 'truck'},
15
+ {'id': 9, 'name': 'boat'},
16
+ {'id': 15, 'name': 'bench'},
17
+ {'id': 16, 'name': 'bird'},
18
+ {'id': 19, 'name': 'horse'},
19
+ {'id': 20, 'name': 'sheep'},
20
+ {'id': 23, 'name': 'bear'},
21
+ {'id': 24, 'name': 'zebra'},
22
+ {'id': 25, 'name': 'giraffe'},
23
+ {'id': 27, 'name': 'backpack'},
24
+ {'id': 31, 'name': 'handbag'},
25
+ {'id': 33, 'name': 'suitcase'},
26
+ {'id': 34, 'name': 'frisbee'},
27
+ {'id': 35, 'name': 'skis'},
28
+ {'id': 38, 'name': 'kite'},
29
+ {'id': 42, 'name': 'surfboard'},
30
+ {'id': 44, 'name': 'bottle'},
31
+ {'id': 48, 'name': 'fork'},
32
+ {'id': 50, 'name': 'spoon'},
33
+ {'id': 51, 'name': 'bowl'},
34
+ {'id': 52, 'name': 'banana'},
35
+ {'id': 53, 'name': 'apple'},
36
+ {'id': 54, 'name': 'sandwich'},
37
+ {'id': 55, 'name': 'orange'},
38
+ {'id': 56, 'name': 'broccoli'},
39
+ {'id': 57, 'name': 'carrot'},
40
+ {'id': 59, 'name': 'pizza'},
41
+ {'id': 60, 'name': 'donut'},
42
+ {'id': 62, 'name': 'chair'},
43
+ {'id': 65, 'name': 'bed'},
44
+ {'id': 70, 'name': 'toilet'},
45
+ {'id': 72, 'name': 'tv'},
46
+ {'id': 73, 'name': 'laptop'},
47
+ {'id': 74, 'name': 'mouse'},
48
+ {'id': 75, 'name': 'remote'},
49
+ {'id': 78, 'name': 'microwave'},
50
+ {'id': 79, 'name': 'oven'},
51
+ {'id': 80, 'name': 'toaster'},
52
+ {'id': 82, 'name': 'refrigerator'},
53
+ {'id': 84, 'name': 'book'},
54
+ {'id': 85, 'name': 'clock'},
55
+ {'id': 86, 'name': 'vase'},
56
+ {'id': 90, 'name': 'toothbrush'},
57
+ ]
58
+
59
+ categories_unseen = [
60
+ {'id': 5, 'name': 'airplane'},
61
+ {'id': 6, 'name': 'bus'},
62
+ {'id': 17, 'name': 'cat'},
63
+ {'id': 18, 'name': 'dog'},
64
+ {'id': 21, 'name': 'cow'},
65
+ {'id': 22, 'name': 'elephant'},
66
+ {'id': 28, 'name': 'umbrella'},
67
+ {'id': 32, 'name': 'tie'},
68
+ {'id': 36, 'name': 'snowboard'},
69
+ {'id': 41, 'name': 'skateboard'},
70
+ {'id': 47, 'name': 'cup'},
71
+ {'id': 49, 'name': 'knife'},
72
+ {'id': 61, 'name': 'cake'},
73
+ {'id': 63, 'name': 'couch'},
74
+ {'id': 76, 'name': 'keyboard'},
75
+ {'id': 81, 'name': 'sink'},
76
+ {'id': 87, 'name': 'scissors'},
77
+ ]
78
+
79
+ def _get_metadata(cat):
80
+ if cat == 'all':
81
+ return _get_builtin_metadata('coco')
82
+ elif cat == 'seen':
83
+ id_to_name = {x['id']: x['name'] for x in categories_seen}
84
+ else:
85
+ assert cat == 'unseen'
86
+ id_to_name = {x['id']: x['name'] for x in categories_unseen}
87
+
88
+ thing_dataset_id_to_contiguous_id = {
89
+ x: i for i, x in enumerate(sorted(id_to_name))}
90
+ thing_classes = [id_to_name[k] for k in sorted(id_to_name)]
91
+ return {
92
+ "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
93
+ "thing_classes": thing_classes}
94
+
95
+ _PREDEFINED_SPLITS_COCO = {
96
+ "coco_zeroshot_train": ("coco/train2017", "coco/zero-shot/instances_train2017_seen_2.json", 'seen'),
97
+ "coco_zeroshot_val": ("coco/val2017", "coco/zero-shot/instances_val2017_unseen_2.json", 'unseen'),
98
+ "coco_not_zeroshot_val": ("coco/val2017", "coco/zero-shot/instances_val2017_seen_2.json", 'seen'),
99
+ "coco_generalized_zeroshot_val": ("coco/val2017", "coco/zero-shot/instances_val2017_all_2_oriorder.json", 'all'),
100
+ "coco_zeroshot_train_oriorder": ("coco/train2017", "coco/zero-shot/instances_train2017_seen_2_oriorder.json", 'all'),
101
+ }
102
+
103
+ for key, (image_root, json_file, cat) in _PREDEFINED_SPLITS_COCO.items():
104
+ register_coco_instances(
105
+ key,
106
+ _get_metadata(cat),
107
+ os.path.join("datasets", json_file) if "://" not in json_file else json_file,
108
+ os.path.join("datasets", image_root),
109
+ )
110
+
111
+ _CUSTOM_SPLITS_COCO = {
112
+ "cc3m_coco_train_tags": ("cc3m/training/", "cc3m/coco_train_image_info_tags.json"),
113
+ "coco_caption_train_tags": ("coco/train2017/", "coco/annotations/captions_train2017_tags_allcaps.json"),}
114
+
115
+ for key, (image_root, json_file) in _CUSTOM_SPLITS_COCO.items():
116
+ custom_register_lvis_instances(
117
+ key,
118
+ _get_builtin_metadata('coco'),
119
+ os.path.join("datasets", json_file) if "://" not in json_file else json_file,
120
+ os.path.join("datasets", image_root),
121
+ )
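
The seen/unseen splits above share one convention: raw COCO category ids are sorted and remapped to contiguous ids 0..N-1, and thing_classes follows that order. A tiny worked example on the first three unseen categories:

    cats = [{'id': 5, 'name': 'airplane'}, {'id': 6, 'name': 'bus'}, {'id': 17, 'name': 'cat'}]
    id_to_name = {x['id']: x['name'] for x in cats}
    thing_dataset_id_to_contiguous_id = {x: i for i, x in enumerate(sorted(id_to_name))}
    thing_classes = [id_to_name[k] for k in sorted(id_to_name)]
    print(thing_dataset_id_to_contiguous_id)  # {5: 0, 6: 1, 17: 2}
    print(thing_classes)                      # ['airplane', 'bus', 'cat']
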
proxydet/data/datasets/imagenet.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ import os
4
+
5
+ from detectron2.data import DatasetCatalog, MetadataCatalog
6
+ from detectron2.data.datasets.lvis import get_lvis_instances_meta
7
+ from .lvis_v1 import custom_load_lvis_json, get_lvis_22k_meta
8
+ def custom_register_imagenet_instances(name, metadata, json_file, image_root):
9
+ """Register an image-labelled (ImageNet) split stored in LVIS json format, with the "imagenet" evaluator type.
10
+ """
11
+ DatasetCatalog.register(name, lambda: custom_load_lvis_json(
12
+ json_file, image_root, name))
13
+ MetadataCatalog.get(name).set(
14
+ json_file=json_file, image_root=image_root,
15
+ evaluator_type="imagenet", **metadata
16
+ )
17
+
18
+ _CUSTOM_SPLITS_IMAGENET = {
19
+ "imagenet_lvis_v1": ("imagenet/ImageNet-LVIS/", "imagenet/annotations/imagenet_lvis_image_info.json"),
20
+ }
21
+
22
+ for key, (image_root, json_file) in _CUSTOM_SPLITS_IMAGENET.items():
23
+ custom_register_imagenet_instances(
24
+ key,
25
+ get_lvis_instances_meta('lvis_v1'),
26
+ os.path.join("datasets", json_file) if "://" not in json_file else json_file,
27
+ os.path.join("datasets", image_root),
28
+ )
29
+
30
+
31
+ _CUSTOM_SPLITS_IMAGENET_22K = {
32
+ "imagenet_lvis-22k": ("imagenet/ImageNet-LVIS/", "imagenet/annotations/imagenet-22k_image_info_lvis-22k.json"),
33
+ }
34
+
35
+ for key, (image_root, json_file) in _CUSTOM_SPLITS_IMAGENET_22K.items():
36
+ custom_register_imagenet_instances(
37
+ key,
38
+ get_lvis_22k_meta(),
39
+ os.path.join("datasets", json_file) if "://" not in json_file else json_file,
40
+ os.path.join("datasets", image_root),
41
+ )
proxydet/data/datasets/lvis_22k_categories.py ADDED
The diff for this file is too large to render. See raw diff
 
proxydet/data/datasets/lvis_v1.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ import os
4
+
5
+ from fvcore.common.timer import Timer
6
+ from detectron2.structures import BoxMode
7
+ from fvcore.common.file_io import PathManager
8
+ from detectron2.data import DatasetCatalog, MetadataCatalog
9
+ from detectron2.data.datasets.lvis import get_lvis_instances_meta
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ __all__ = ["custom_load_lvis_json", "custom_register_lvis_instances"]
14
+
15
+
16
+ def custom_register_lvis_instances(name, metadata, json_file, image_root):
17
+ """Register an LVIS-format json so DatasetCatalog/MetadataCatalog can load it lazily.
18
+ """
19
+ DatasetCatalog.register(name, lambda: custom_load_lvis_json(
20
+ json_file, image_root, name))
21
+ MetadataCatalog.get(name).set(
22
+ json_file=json_file, image_root=image_root,
23
+ evaluator_type="lvis", **metadata
24
+ )
25
+
26
+
27
+ def custom_load_lvis_json(json_file, image_root, dataset_name=None):
28
+ '''
29
+ Modifications:
30
+ use `file_name`
31
+ convert neg_category_ids
32
+ add pos_category_ids
33
+ '''
34
+ from lvis import LVIS
35
+
36
+ json_file = PathManager.get_local_path(json_file)
37
+
38
+ timer = Timer()
39
+ lvis_api = LVIS(json_file)
40
+ if timer.seconds() > 1:
41
+ logger.info("Loading {} takes {:.2f} seconds.".format(
42
+ json_file, timer.seconds()))
43
+
44
+ catid2contid = {x['id']: i for i, x in enumerate(
45
+ sorted(lvis_api.dataset['categories'], key=lambda x: x['id']))}
46
+ if len(lvis_api.dataset['categories']) == 1203:
47
+ for x in lvis_api.dataset['categories']:
48
+ assert catid2contid[x['id']] == x['id'] - 1
49
+ img_ids = sorted(lvis_api.imgs.keys())
50
+ imgs = lvis_api.load_imgs(img_ids)
51
+ anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
52
+
53
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
54
+ assert len(set(ann_ids)) == len(ann_ids), \
55
+ "Annotation ids in '{}' are not unique".format(json_file)
56
+
57
+ imgs_anns = list(zip(imgs, anns))
58
+ logger.info("Loaded {} images in the LVIS v1 format from {}".format(
59
+ len(imgs_anns), json_file))
60
+
61
+ dataset_dicts = []
62
+
63
+ for (img_dict, anno_dict_list) in imgs_anns:
64
+ record = {}
65
+ if "file_name" in img_dict:
66
+ file_name = img_dict["file_name"]
67
+ if img_dict["file_name"].startswith("COCO"):
68
+ file_name = file_name[-16:]
69
+ record["file_name"] = os.path.join(image_root, file_name)
70
+ elif 'coco_url' in img_dict:
71
+ # e.g., http://images.cocodataset.org/train2017/000000391895.jpg
72
+ file_name = img_dict["coco_url"][30:]
73
+ record["file_name"] = os.path.join(image_root, file_name)
74
+ elif 'tar_index' in img_dict:
75
+ record['tar_index'] = img_dict['tar_index']
76
+
77
+ record["height"] = img_dict["height"]
78
+ record["width"] = img_dict["width"]
79
+ record["not_exhaustive_category_ids"] = img_dict.get(
80
+ "not_exhaustive_category_ids", [])
81
+ record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
82
+ # NOTE: modified by Xingyi: convert to 0-based
83
+ record["neg_category_ids"] = [
84
+ catid2contid[x] for x in record["neg_category_ids"]]
85
+ if 'pos_category_ids' in img_dict:
86
+ record['pos_category_ids'] = [
87
+ catid2contid[x] for x in img_dict.get("pos_category_ids", [])]
88
+ if 'captions' in img_dict:
89
+ record['captions'] = img_dict['captions']
90
+ if 'caption_features' in img_dict:
91
+ record['caption_features'] = img_dict['caption_features']
92
+ image_id = record["image_id"] = img_dict["id"]
93
+
94
+ objs = []
95
+ for anno in anno_dict_list:
96
+ assert anno["image_id"] == image_id
97
+ if anno.get('iscrowd', 0) > 0:
98
+ continue
99
+ obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
100
+ obj["category_id"] = catid2contid[anno['category_id']]
101
+ if 'segmentation' in anno:
102
+ segm = anno["segmentation"]
103
+ valid_segm = [poly for poly in segm \
104
+ if len(poly) % 2 == 0 and len(poly) >= 6]
105
+ # assert len(segm) == len(
106
+ # valid_segm
107
+ # ), "Annotation contains an invalid polygon with < 3 points"
108
+ if not len(segm) == len(valid_segm):
109
+ print('Annotation contains an invalid polygon with < 3 points')
110
+ assert len(segm) > 0
111
+ obj["segmentation"] = segm
112
+ objs.append(obj)
113
+ record["annotations"] = objs
114
+ dataset_dicts.append(record)
115
+
116
+ return dataset_dicts
117
+
118
+ _CUSTOM_SPLITS_LVIS = {
119
+ "lvis_v1_train+coco": ("coco/", "lvis/lvis_v1_train+coco_mask.json"),
120
+ "lvis_v1_train_norare": ("coco/", "lvis/lvis_v1_train_norare.json"),
121
+ }
122
+
123
+
124
+ for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items():
125
+ custom_register_lvis_instances(
126
+ key,
127
+ get_lvis_instances_meta(key),
128
+ os.path.join("datasets", json_file) if "://" not in json_file else json_file,
129
+ os.path.join("datasets", image_root),
130
+ )
131
+
132
+
133
+ def get_lvis_22k_meta():
134
+ from .lvis_22k_categories import CATEGORIES
135
+ cat_ids = [k["id"] for k in CATEGORIES]
136
+ assert min(cat_ids) == 1 and max(cat_ids) == len(
137
+ cat_ids
138
+ ), "Category ids are not in [1, #categories], as expected"
139
+ # Ensure that the category list is sorted by id
140
+ lvis_categories = sorted(CATEGORIES, key=lambda x: x["id"])
141
+ thing_classes = [k["name"] for k in lvis_categories]
142
+ meta = {"thing_classes": thing_classes}
143
+ return meta
144
+
145
+ _CUSTOM_SPLITS_LVIS_22K = {
146
+ "lvis_v1_train_22k": ("coco/", "lvis/lvis_v1_train_lvis-22k.json"),
147
+ }
148
+
149
+ for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS_22K.items():
150
+ custom_register_lvis_instances(
151
+ key,
152
+ get_lvis_22k_meta(),
153
+ os.path.join("datasets", json_file) if "://" not in json_file else json_file,
154
+ os.path.join("datasets", image_root),
155
+ )
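
Registration above is lazy: only a loader function and metadata are recorded, and nothing is read from disk until a split is requested. A short sketch of looking one of these splits up (assumes detectron2 and the proxydet package from this commit import cleanly; DatasetCatalog.get would additionally need the json and images under datasets/):

    import proxydet.data.datasets.lvis_v1  # noqa: F401  (runs the registration loops above)
    from detectron2.data import DatasetCatalog, MetadataCatalog

    name = "lvis_v1_train_norare"
    meta = MetadataCatalog.get(name)
    print(meta.json_file, meta.image_root, meta.evaluator_type)
    # dicts = DatasetCatalog.get(name)  # calls custom_load_lvis_json; needs the data on disk
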
proxydet/data/datasets/objects365.py ADDED
@@ -0,0 +1,770 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from detectron2.data.datasets.register_coco import register_coco_instances
3
+ import os
4
+
5
+ # categories_v2 = [
6
+ # {'id': 1, 'name': 'Person'},
7
+ # {'id': 2, 'name': 'Sneakers'},
8
+ # {'id': 3, 'name': 'Chair'},
9
+ # {'id': 4, 'name': 'Other Shoes'},
10
+ # {'id': 5, 'name': 'Hat'},
11
+ # {'id': 6, 'name': 'Car'},
12
+ # {'id': 7, 'name': 'Lamp'},
13
+ # {'id': 8, 'name': 'Glasses'},
14
+ # {'id': 9, 'name': 'Bottle'},
15
+ # {'id': 10, 'name': 'Desk'},
16
+ # {'id': 11, 'name': 'Cup'},
17
+ # {'id': 12, 'name': 'Street Lights'},
18
+ # {'id': 13, 'name': 'Cabinet/shelf'},
19
+ # {'id': 14, 'name': 'Handbag/Satchel'},
20
+ # {'id': 15, 'name': 'Bracelet'},
21
+ # {'id': 16, 'name': 'Plate'},
22
+ # {'id': 17, 'name': 'Picture/Frame'},
23
+ # {'id': 18, 'name': 'Helmet'},
24
+ # {'id': 19, 'name': 'Book'},
25
+ # {'id': 20, 'name': 'Gloves'},
26
+ # {'id': 21, 'name': 'Storage box'},
27
+ # {'id': 22, 'name': 'Boat'},
28
+ # {'id': 23, 'name': 'Leather Shoes'},
29
+ # {'id': 24, 'name': 'Flower'},
30
+ # {'id': 25, 'name': 'Bench'},
31
+ # {'id': 26, 'name': 'Potted Plant'},
32
+ # {'id': 27, 'name': 'Bowl/Basin'},
33
+ # {'id': 28, 'name': 'Flag'},
34
+ # {'id': 29, 'name': 'Pillow'},
35
+ # {'id': 30, 'name': 'Boots'},
36
+ # {'id': 31, 'name': 'Vase'},
37
+ # {'id': 32, 'name': 'Microphone'},
38
+ # {'id': 33, 'name': 'Necklace'},
39
+ # {'id': 34, 'name': 'Ring'},
40
+ # {'id': 35, 'name': 'SUV'},
41
+ # {'id': 36, 'name': 'Wine Glass'},
42
+ # {'id': 37, 'name': 'Belt'},
43
+ # {'id': 38, 'name': 'Moniter/TV'},
44
+ # {'id': 39, 'name': 'Backpack'},
45
+ # {'id': 40, 'name': 'Umbrella'},
46
+ # {'id': 41, 'name': 'Traffic Light'},
47
+ # {'id': 42, 'name': 'Speaker'},
48
+ # {'id': 43, 'name': 'Watch'},
49
+ # {'id': 44, 'name': 'Tie'},
50
+ # {'id': 45, 'name': 'Trash bin Can'},
51
+ # {'id': 46, 'name': 'Slippers'},
52
+ # {'id': 47, 'name': 'Bicycle'},
53
+ # {'id': 48, 'name': 'Stool'},
54
+ # {'id': 49, 'name': 'Barrel/bucket'},
55
+ # {'id': 50, 'name': 'Van'},
56
+ # {'id': 51, 'name': 'Couch'},
57
+ # {'id': 52, 'name': 'Sandals'},
58
+ # {'id': 53, 'name': 'Bakset'},
59
+ # {'id': 54, 'name': 'Drum'},
60
+ # {'id': 55, 'name': 'Pen/Pencil'},
61
+ # {'id': 56, 'name': 'Bus'},
62
+ # {'id': 57, 'name': 'Wild Bird'},
63
+ # {'id': 58, 'name': 'High Heels'},
64
+ # {'id': 59, 'name': 'Motorcycle'},
65
+ # {'id': 60, 'name': 'Guitar'},
66
+ # {'id': 61, 'name': 'Carpet'},
67
+ # {'id': 62, 'name': 'Cell Phone'},
68
+ # {'id': 63, 'name': 'Bread'},
69
+ # {'id': 64, 'name': 'Camera'},
70
+ # {'id': 65, 'name': 'Canned'},
71
+ # {'id': 66, 'name': 'Truck'},
72
+ # {'id': 67, 'name': 'Traffic cone'},
73
+ # {'id': 68, 'name': 'Cymbal'},
74
+ # {'id': 69, 'name': 'Lifesaver'},
75
+ # {'id': 70, 'name': 'Towel'},
76
+ # {'id': 71, 'name': 'Stuffed Toy'},
77
+ # {'id': 72, 'name': 'Candle'},
78
+ # {'id': 73, 'name': 'Sailboat'},
79
+ # {'id': 74, 'name': 'Laptop'},
80
+ # {'id': 75, 'name': 'Awning'},
81
+ # {'id': 76, 'name': 'Bed'},
82
+ # {'id': 77, 'name': 'Faucet'},
83
+ # {'id': 78, 'name': 'Tent'},
84
+ # {'id': 79, 'name': 'Horse'},
85
+ # {'id': 80, 'name': 'Mirror'},
86
+ # {'id': 81, 'name': 'Power outlet'},
87
+ # {'id': 82, 'name': 'Sink'},
88
+ # {'id': 83, 'name': 'Apple'},
89
+ # {'id': 84, 'name': 'Air Conditioner'},
90
+ # {'id': 85, 'name': 'Knife'},
91
+ # {'id': 86, 'name': 'Hockey Stick'},
92
+ # {'id': 87, 'name': 'Paddle'},
93
+ # {'id': 88, 'name': 'Pickup Truck'},
94
+ # {'id': 89, 'name': 'Fork'},
95
+ # {'id': 90, 'name': 'Traffic Sign'},
96
+ # {'id': 91, 'name': 'Ballon'},
97
+ # {'id': 92, 'name': 'Tripod'},
98
+ # {'id': 93, 'name': 'Dog'},
99
+ # {'id': 94, 'name': 'Spoon'},
100
+ # {'id': 95, 'name': 'Clock'},
101
+ # {'id': 96, 'name': 'Pot'},
102
+ # {'id': 97, 'name': 'Cow'},
103
+ # {'id': 98, 'name': 'Cake'},
104
+ # {'id': 99, 'name': 'Dinning Table'},
105
+ # {'id': 100, 'name': 'Sheep'},
106
+ # {'id': 101, 'name': 'Hanger'},
107
+ # {'id': 102, 'name': 'Blackboard/Whiteboard'},
108
+ # {'id': 103, 'name': 'Napkin'},
109
+ # {'id': 104, 'name': 'Other Fish'},
110
+ # {'id': 105, 'name': 'Orange/Tangerine'},
111
+ # {'id': 106, 'name': 'Toiletry'},
112
+ # {'id': 107, 'name': 'Keyboard'},
113
+ # {'id': 108, 'name': 'Tomato'},
114
+ # {'id': 109, 'name': 'Lantern'},
115
+ # {'id': 110, 'name': 'Machinery Vehicle'},
116
+ # {'id': 111, 'name': 'Fan'},
117
+ # {'id': 112, 'name': 'Green Vegetables'},
118
+ # {'id': 113, 'name': 'Banana'},
119
+ # {'id': 114, 'name': 'Baseball Glove'},
120
+ # {'id': 115, 'name': 'Airplane'},
121
+ # {'id': 116, 'name': 'Mouse'},
122
+ # {'id': 117, 'name': 'Train'},
123
+ # {'id': 118, 'name': 'Pumpkin'},
124
+ # {'id': 119, 'name': 'Soccer'},
125
+ # {'id': 120, 'name': 'Skiboard'},
126
+ # {'id': 121, 'name': 'Luggage'},
127
+ # {'id': 122, 'name': 'Nightstand'},
128
+ # {'id': 123, 'name': 'Tea pot'},
129
+ # {'id': 124, 'name': 'Telephone'},
130
+ # {'id': 125, 'name': 'Trolley'},
131
+ # {'id': 126, 'name': 'Head Phone'},
132
+ # {'id': 127, 'name': 'Sports Car'},
133
+ # {'id': 128, 'name': 'Stop Sign'},
134
+ # {'id': 129, 'name': 'Dessert'},
135
+ # {'id': 130, 'name': 'Scooter'},
136
+ # {'id': 131, 'name': 'Stroller'},
137
+ # {'id': 132, 'name': 'Crane'},
138
+ # {'id': 133, 'name': 'Remote'},
139
+ # {'id': 134, 'name': 'Refrigerator'},
140
+ # {'id': 135, 'name': 'Oven'},
141
+ # {'id': 136, 'name': 'Lemon'},
142
+ # {'id': 137, 'name': 'Duck'},
143
+ # {'id': 138, 'name': 'Baseball Bat'},
144
+ # {'id': 139, 'name': 'Surveillance Camera'},
145
+ # {'id': 140, 'name': 'Cat'},
146
+ # {'id': 141, 'name': 'Jug'},
147
+ # {'id': 142, 'name': 'Broccoli'},
148
+ # {'id': 143, 'name': 'Piano'},
149
+ # {'id': 144, 'name': 'Pizza'},
150
+ # {'id': 145, 'name': 'Elephant'},
151
+ # {'id': 146, 'name': 'Skateboard'},
152
+ # {'id': 147, 'name': 'Surfboard'},
153
+ # {'id': 148, 'name': 'Gun'},
154
+ # {'id': 149, 'name': 'Skating and Skiing shoes'},
155
+ # {'id': 150, 'name': 'Gas stove'},
156
+ # {'id': 151, 'name': 'Donut'},
157
+ # {'id': 152, 'name': 'Bow Tie'},
158
+ # {'id': 153, 'name': 'Carrot'},
159
+ # {'id': 154, 'name': 'Toilet'},
160
+ # {'id': 155, 'name': 'Kite'},
161
+ # {'id': 156, 'name': 'Strawberry'},
162
+ # {'id': 157, 'name': 'Other Balls'},
163
+ # {'id': 158, 'name': 'Shovel'},
164
+ # {'id': 159, 'name': 'Pepper'},
165
+ # {'id': 160, 'name': 'Computer Box'},
166
+ # {'id': 161, 'name': 'Toilet Paper'},
167
+ # {'id': 162, 'name': 'Cleaning Products'},
168
+ # {'id': 163, 'name': 'Chopsticks'},
169
+ # {'id': 164, 'name': 'Microwave'},
170
+ # {'id': 165, 'name': 'Pigeon'},
171
+ # {'id': 166, 'name': 'Baseball'},
172
+ # {'id': 167, 'name': 'Cutting/chopping Board'},
173
+ # {'id': 168, 'name': 'Coffee Table'},
174
+ # {'id': 169, 'name': 'Side Table'},
175
+ # {'id': 170, 'name': 'Scissors'},
176
+ # {'id': 171, 'name': 'Marker'},
177
+ # {'id': 172, 'name': 'Pie'},
178
+ # {'id': 173, 'name': 'Ladder'},
179
+ # {'id': 174, 'name': 'Snowboard'},
180
+ # {'id': 175, 'name': 'Cookies'},
181
+ # {'id': 176, 'name': 'Radiator'},
182
+ # {'id': 177, 'name': 'Fire Hydrant'},
183
+ # {'id': 178, 'name': 'Basketball'},
184
+ # {'id': 179, 'name': 'Zebra'},
185
+ # {'id': 180, 'name': 'Grape'},
186
+ # {'id': 181, 'name': 'Giraffe'},
187
+ # {'id': 182, 'name': 'Potato'},
188
+ # {'id': 183, 'name': 'Sausage'},
189
+ # {'id': 184, 'name': 'Tricycle'},
190
+ # {'id': 185, 'name': 'Violin'},
191
+ # {'id': 186, 'name': 'Egg'},
192
+ # {'id': 187, 'name': 'Fire Extinguisher'},
193
+ # {'id': 188, 'name': 'Candy'},
194
+ # {'id': 189, 'name': 'Fire Truck'},
195
+ # {'id': 190, 'name': 'Billards'},
196
+ # {'id': 191, 'name': 'Converter'},
197
+ # {'id': 192, 'name': 'Bathtub'},
198
+ # {'id': 193, 'name': 'Wheelchair'},
199
+ # {'id': 194, 'name': 'Golf Club'},
200
+ # {'id': 195, 'name': 'Briefcase'},
201
+ # {'id': 196, 'name': 'Cucumber'},
202
+ # {'id': 197, 'name': 'Cigar/Cigarette '},
203
+ # {'id': 198, 'name': 'Paint Brush'},
204
+ # {'id': 199, 'name': 'Pear'},
205
+ # {'id': 200, 'name': 'Heavy Truck'},
206
+ # {'id': 201, 'name': 'Hamburger'},
207
+ # {'id': 202, 'name': 'Extractor'},
208
+ # {'id': 203, 'name': 'Extention Cord'},
209
+ # {'id': 204, 'name': 'Tong'},
210
+ # {'id': 205, 'name': 'Tennis Racket'},
211
+ # {'id': 206, 'name': 'Folder'},
212
+ # {'id': 207, 'name': 'American Football'},
213
+ # {'id': 208, 'name': 'earphone'},
214
+ # {'id': 209, 'name': 'Mask'},
215
+ # {'id': 210, 'name': 'Kettle'},
216
+ # {'id': 211, 'name': 'Tennis'},
217
+ # {'id': 212, 'name': 'Ship'},
218
+ # {'id': 213, 'name': 'Swing'},
219
+ # {'id': 214, 'name': 'Coffee Machine'},
220
+ # {'id': 215, 'name': 'Slide'},
221
+ # {'id': 216, 'name': 'Carriage'},
222
+ # {'id': 217, 'name': 'Onion'},
223
+ # {'id': 218, 'name': 'Green beans'},
224
+ # {'id': 219, 'name': 'Projector'},
225
+ # {'id': 220, 'name': 'Frisbee'},
226
+ # {'id': 221, 'name': 'Washing Machine/Drying Machine'},
227
+ # {'id': 222, 'name': 'Chicken'},
228
+ # {'id': 223, 'name': 'Printer'},
229
+ # {'id': 224, 'name': 'Watermelon'},
230
+ # {'id': 225, 'name': 'Saxophone'},
231
+ # {'id': 226, 'name': 'Tissue'},
232
+ # {'id': 227, 'name': 'Toothbrush'},
233
+ # {'id': 228, 'name': 'Ice cream'},
234
+ # {'id': 229, 'name': 'Hotair ballon'},
235
+ # {'id': 230, 'name': 'Cello'},
236
+ # {'id': 231, 'name': 'French Fries'},
237
+ # {'id': 232, 'name': 'Scale'},
238
+ # {'id': 233, 'name': 'Trophy'},
239
+ # {'id': 234, 'name': 'Cabbage'},
240
+ # {'id': 235, 'name': 'Hot dog'},
241
+ # {'id': 236, 'name': 'Blender'},
242
+ # {'id': 237, 'name': 'Peach'},
243
+ # {'id': 238, 'name': 'Rice'},
244
+ # {'id': 239, 'name': 'Wallet/Purse'},
245
+ # {'id': 240, 'name': 'Volleyball'},
246
+ # {'id': 241, 'name': 'Deer'},
247
+ # {'id': 242, 'name': 'Goose'},
248
+ # {'id': 243, 'name': 'Tape'},
249
+ # {'id': 244, 'name': 'Tablet'},
250
+ # {'id': 245, 'name': 'Cosmetics'},
251
+ # {'id': 246, 'name': 'Trumpet'},
252
+ # {'id': 247, 'name': 'Pineapple'},
253
+ # {'id': 248, 'name': 'Golf Ball'},
254
+ # {'id': 249, 'name': 'Ambulance'},
255
+ # {'id': 250, 'name': 'Parking meter'},
256
+ # {'id': 251, 'name': 'Mango'},
257
+ # {'id': 252, 'name': 'Key'},
258
+ # {'id': 253, 'name': 'Hurdle'},
259
+ # {'id': 254, 'name': 'Fishing Rod'},
260
+ # {'id': 255, 'name': 'Medal'},
261
+ # {'id': 256, 'name': 'Flute'},
262
+ # {'id': 257, 'name': 'Brush'},
263
+ # {'id': 258, 'name': 'Penguin'},
264
+ # {'id': 259, 'name': 'Megaphone'},
265
+ # {'id': 260, 'name': 'Corn'},
266
+ # {'id': 261, 'name': 'Lettuce'},
267
+ # {'id': 262, 'name': 'Garlic'},
268
+ # {'id': 263, 'name': 'Swan'},
269
+ # {'id': 264, 'name': 'Helicopter'},
270
+ # {'id': 265, 'name': 'Green Onion'},
271
+ # {'id': 266, 'name': 'Sandwich'},
272
+ # {'id': 267, 'name': 'Nuts'},
273
+ # {'id': 268, 'name': 'Speed Limit Sign'},
274
+ # {'id': 269, 'name': 'Induction Cooker'},
275
+ # {'id': 270, 'name': 'Broom'},
276
+ # {'id': 271, 'name': 'Trombone'},
277
+ # {'id': 272, 'name': 'Plum'},
278
+ # {'id': 273, 'name': 'Rickshaw'},
279
+ # {'id': 274, 'name': 'Goldfish'},
280
+ # {'id': 275, 'name': 'Kiwi fruit'},
281
+ # {'id': 276, 'name': 'Router/modem'},
282
+ # {'id': 277, 'name': 'Poker Card'},
283
+ # {'id': 278, 'name': 'Toaster'},
284
+ # {'id': 279, 'name': 'Shrimp'},
285
+ # {'id': 280, 'name': 'Sushi'},
286
+ # {'id': 281, 'name': 'Cheese'},
287
+ # {'id': 282, 'name': 'Notepaper'},
288
+ # {'id': 283, 'name': 'Cherry'},
289
+ # {'id': 284, 'name': 'Pliers'},
290
+ # {'id': 285, 'name': 'CD'},
291
+ # {'id': 286, 'name': 'Pasta'},
292
+ # {'id': 287, 'name': 'Hammer'},
293
+ # {'id': 288, 'name': 'Cue'},
294
+ # {'id': 289, 'name': 'Avocado'},
295
+ # {'id': 290, 'name': 'Hamimelon'},
296
+ # {'id': 291, 'name': 'Flask'},
297
+ # {'id': 292, 'name': 'Mushroon'},
298
+ # {'id': 293, 'name': 'Screwdriver'},
299
+ # {'id': 294, 'name': 'Soap'},
300
+ # {'id': 295, 'name': 'Recorder'},
301
+ # {'id': 296, 'name': 'Bear'},
302
+ # {'id': 297, 'name': 'Eggplant'},
303
+ # {'id': 298, 'name': 'Board Eraser'},
304
+ # {'id': 299, 'name': 'Coconut'},
305
+ # {'id': 300, 'name': 'Tape Measur/ Ruler'},
306
+ # {'id': 301, 'name': 'Pig'},
307
+ # {'id': 302, 'name': 'Showerhead'},
308
+ # {'id': 303, 'name': 'Globe'},
309
+ # {'id': 304, 'name': 'Chips'},
310
+ # {'id': 305, 'name': 'Steak'},
311
+ # {'id': 306, 'name': 'Crosswalk Sign'},
312
+ # {'id': 307, 'name': 'Stapler'},
313
+ # {'id': 308, 'name': 'Campel'},
314
+ # {'id': 309, 'name': 'Formula 1 '},
315
+ # {'id': 310, 'name': 'Pomegranate'},
316
+ # {'id': 311, 'name': 'Dishwasher'},
317
+ # {'id': 312, 'name': 'Crab'},
318
+ # {'id': 313, 'name': 'Hoverboard'},
319
+ # {'id': 314, 'name': 'Meat ball'},
320
+ # {'id': 315, 'name': 'Rice Cooker'},
321
+ # {'id': 316, 'name': 'Tuba'},
322
+ # {'id': 317, 'name': 'Calculator'},
323
+ # {'id': 318, 'name': 'Papaya'},
324
+ # {'id': 319, 'name': 'Antelope'},
325
+ # {'id': 320, 'name': 'Parrot'},
326
+ # {'id': 321, 'name': 'Seal'},
327
+ # {'id': 322, 'name': 'Buttefly'},
328
+ # {'id': 323, 'name': 'Dumbbell'},
329
+ # {'id': 324, 'name': 'Donkey'},
330
+ # {'id': 325, 'name': 'Lion'},
331
+ # {'id': 326, 'name': 'Urinal'},
332
+ # {'id': 327, 'name': 'Dolphin'},
333
+ # {'id': 328, 'name': 'Electric Drill'},
334
+ # {'id': 329, 'name': 'Hair Dryer'},
335
+ # {'id': 330, 'name': 'Egg tart'},
336
+ # {'id': 331, 'name': 'Jellyfish'},
337
+ # {'id': 332, 'name': 'Treadmill'},
338
+ # {'id': 333, 'name': 'Lighter'},
339
+ # {'id': 334, 'name': 'Grapefruit'},
340
+ # {'id': 335, 'name': 'Game board'},
341
+ # {'id': 336, 'name': 'Mop'},
342
+ # {'id': 337, 'name': 'Radish'},
343
+ # {'id': 338, 'name': 'Baozi'},
344
+ # {'id': 339, 'name': 'Target'},
345
+ # {'id': 340, 'name': 'French'},
346
+ # {'id': 341, 'name': 'Spring Rolls'},
347
+ # {'id': 342, 'name': 'Monkey'},
348
+ # {'id': 343, 'name': 'Rabbit'},
349
+ # {'id': 344, 'name': 'Pencil Case'},
350
+ # {'id': 345, 'name': 'Yak'},
351
+ # {'id': 346, 'name': 'Red Cabbage'},
352
+ # {'id': 347, 'name': 'Binoculars'},
353
+ # {'id': 348, 'name': 'Asparagus'},
354
+ # {'id': 349, 'name': 'Barbell'},
355
+ # {'id': 350, 'name': 'Scallop'},
356
+ # {'id': 351, 'name': 'Noddles'},
357
+ # {'id': 352, 'name': 'Comb'},
358
+ # {'id': 353, 'name': 'Dumpling'},
359
+ # {'id': 354, 'name': 'Oyster'},
360
+ # {'id': 355, 'name': 'Table Teniis paddle'},
361
+ # {'id': 356, 'name': 'Cosmetics Brush/Eyeliner Pencil'},
362
+ # {'id': 357, 'name': 'Chainsaw'},
363
+ # {'id': 358, 'name': 'Eraser'},
364
+ # {'id': 359, 'name': 'Lobster'},
365
+ # {'id': 360, 'name': 'Durian'},
366
+ # {'id': 361, 'name': 'Okra'},
367
+ # {'id': 362, 'name': 'Lipstick'},
368
+ # {'id': 363, 'name': 'Cosmetics Mirror'},
369
+ # {'id': 364, 'name': 'Curling'},
370
+ # {'id': 365, 'name': 'Table Tennis '},
371
+ # ]
372
+
373
+ '''
374
+ The official Objects365 category names contain typos.
375
+ Below is a manual fix.
376
+ '''
377
+ categories_v2_fix = [
378
+ {'id': 1, 'name': 'Person'},
379
+ {'id': 2, 'name': 'Sneakers'},
380
+ {'id': 3, 'name': 'Chair'},
381
+ {'id': 4, 'name': 'Other Shoes'},
382
+ {'id': 5, 'name': 'Hat'},
383
+ {'id': 6, 'name': 'Car'},
384
+ {'id': 7, 'name': 'Lamp'},
385
+ {'id': 8, 'name': 'Glasses'},
386
+ {'id': 9, 'name': 'Bottle'},
387
+ {'id': 10, 'name': 'Desk'},
388
+ {'id': 11, 'name': 'Cup'},
389
+ {'id': 12, 'name': 'Street Lights'},
390
+ {'id': 13, 'name': 'Cabinet/shelf'},
391
+ {'id': 14, 'name': 'Handbag/Satchel'},
392
+ {'id': 15, 'name': 'Bracelet'},
393
+ {'id': 16, 'name': 'Plate'},
394
+ {'id': 17, 'name': 'Picture/Frame'},
395
+ {'id': 18, 'name': 'Helmet'},
396
+ {'id': 19, 'name': 'Book'},
397
+ {'id': 20, 'name': 'Gloves'},
398
+ {'id': 21, 'name': 'Storage box'},
399
+ {'id': 22, 'name': 'Boat'},
400
+ {'id': 23, 'name': 'Leather Shoes'},
401
+ {'id': 24, 'name': 'Flower'},
402
+ {'id': 25, 'name': 'Bench'},
403
+ {'id': 26, 'name': 'Potted Plant'},
404
+ {'id': 27, 'name': 'Bowl/Basin'},
405
+ {'id': 28, 'name': 'Flag'},
406
+ {'id': 29, 'name': 'Pillow'},
407
+ {'id': 30, 'name': 'Boots'},
408
+ {'id': 31, 'name': 'Vase'},
409
+ {'id': 32, 'name': 'Microphone'},
410
+ {'id': 33, 'name': 'Necklace'},
411
+ {'id': 34, 'name': 'Ring'},
412
+ {'id': 35, 'name': 'SUV'},
413
+ {'id': 36, 'name': 'Wine Glass'},
414
+ {'id': 37, 'name': 'Belt'},
415
+ {'id': 38, 'name': 'Monitor/TV'},
416
+ {'id': 39, 'name': 'Backpack'},
417
+ {'id': 40, 'name': 'Umbrella'},
418
+ {'id': 41, 'name': 'Traffic Light'},
419
+ {'id': 42, 'name': 'Speaker'},
420
+ {'id': 43, 'name': 'Watch'},
421
+ {'id': 44, 'name': 'Tie'},
422
+ {'id': 45, 'name': 'Trash bin Can'},
423
+ {'id': 46, 'name': 'Slippers'},
424
+ {'id': 47, 'name': 'Bicycle'},
425
+ {'id': 48, 'name': 'Stool'},
426
+ {'id': 49, 'name': 'Barrel/bucket'},
427
+ {'id': 50, 'name': 'Van'},
428
+ {'id': 51, 'name': 'Couch'},
429
+ {'id': 52, 'name': 'Sandals'},
430
+ {'id': 53, 'name': 'Basket'},
431
+ {'id': 54, 'name': 'Drum'},
432
+ {'id': 55, 'name': 'Pen/Pencil'},
433
+ {'id': 56, 'name': 'Bus'},
434
+ {'id': 57, 'name': 'Wild Bird'},
435
+ {'id': 58, 'name': 'High Heels'},
436
+ {'id': 59, 'name': 'Motorcycle'},
437
+ {'id': 60, 'name': 'Guitar'},
438
+ {'id': 61, 'name': 'Carpet'},
439
+ {'id': 62, 'name': 'Cell Phone'},
440
+ {'id': 63, 'name': 'Bread'},
441
+ {'id': 64, 'name': 'Camera'},
442
+ {'id': 65, 'name': 'Canned'},
443
+ {'id': 66, 'name': 'Truck'},
444
+ {'id': 67, 'name': 'Traffic cone'},
445
+ {'id': 68, 'name': 'Cymbal'},
446
+ {'id': 69, 'name': 'Lifesaver'},
447
+ {'id': 70, 'name': 'Towel'},
448
+ {'id': 71, 'name': 'Stuffed Toy'},
449
+ {'id': 72, 'name': 'Candle'},
450
+ {'id': 73, 'name': 'Sailboat'},
451
+ {'id': 74, 'name': 'Laptop'},
452
+ {'id': 75, 'name': 'Awning'},
453
+ {'id': 76, 'name': 'Bed'},
454
+ {'id': 77, 'name': 'Faucet'},
455
+ {'id': 78, 'name': 'Tent'},
456
+ {'id': 79, 'name': 'Horse'},
457
+ {'id': 80, 'name': 'Mirror'},
458
+ {'id': 81, 'name': 'Power outlet'},
459
+ {'id': 82, 'name': 'Sink'},
460
+ {'id': 83, 'name': 'Apple'},
461
+ {'id': 84, 'name': 'Air Conditioner'},
462
+ {'id': 85, 'name': 'Knife'},
463
+ {'id': 86, 'name': 'Hockey Stick'},
464
+ {'id': 87, 'name': 'Paddle'},
465
+ {'id': 88, 'name': 'Pickup Truck'},
466
+ {'id': 89, 'name': 'Fork'},
467
+ {'id': 90, 'name': 'Traffic Sign'},
468
+ {'id': 91, 'name': 'Ballon'},
469
+ {'id': 92, 'name': 'Tripod'},
470
+ {'id': 93, 'name': 'Dog'},
471
+ {'id': 94, 'name': 'Spoon'},
472
+ {'id': 95, 'name': 'Clock'},
473
+ {'id': 96, 'name': 'Pot'},
474
+ {'id': 97, 'name': 'Cow'},
475
+ {'id': 98, 'name': 'Cake'},
476
+ {'id': 99, 'name': 'Dining Table'},
477
+ {'id': 100, 'name': 'Sheep'},
478
+ {'id': 101, 'name': 'Hanger'},
479
+ {'id': 102, 'name': 'Blackboard/Whiteboard'},
480
+ {'id': 103, 'name': 'Napkin'},
481
+ {'id': 104, 'name': 'Other Fish'},
482
+ {'id': 105, 'name': 'Orange/Tangerine'},
483
+ {'id': 106, 'name': 'Toiletry'},
484
+ {'id': 107, 'name': 'Keyboard'},
485
+ {'id': 108, 'name': 'Tomato'},
486
+ {'id': 109, 'name': 'Lantern'},
487
+ {'id': 110, 'name': 'Machinery Vehicle'},
488
+ {'id': 111, 'name': 'Fan'},
489
+ {'id': 112, 'name': 'Green Vegetables'},
490
+ {'id': 113, 'name': 'Banana'},
491
+ {'id': 114, 'name': 'Baseball Glove'},
492
+ {'id': 115, 'name': 'Airplane'},
493
+ {'id': 116, 'name': 'Mouse'},
494
+ {'id': 117, 'name': 'Train'},
495
+ {'id': 118, 'name': 'Pumpkin'},
496
+ {'id': 119, 'name': 'Soccer'},
497
+ {'id': 120, 'name': 'Skiboard'},
498
+ {'id': 121, 'name': 'Luggage'},
499
+ {'id': 122, 'name': 'Nightstand'},
500
+ {'id': 123, 'name': 'Teapot'},
501
+ {'id': 124, 'name': 'Telephone'},
502
+ {'id': 125, 'name': 'Trolley'},
503
+ {'id': 126, 'name': 'Head Phone'},
504
+ {'id': 127, 'name': 'Sports Car'},
505
+ {'id': 128, 'name': 'Stop Sign'},
506
+ {'id': 129, 'name': 'Dessert'},
507
+ {'id': 130, 'name': 'Scooter'},
508
+ {'id': 131, 'name': 'Stroller'},
509
+ {'id': 132, 'name': 'Crane'},
510
+ {'id': 133, 'name': 'Remote'},
511
+ {'id': 134, 'name': 'Refrigerator'},
512
+ {'id': 135, 'name': 'Oven'},
513
+ {'id': 136, 'name': 'Lemon'},
514
+ {'id': 137, 'name': 'Duck'},
515
+ {'id': 138, 'name': 'Baseball Bat'},
516
+ {'id': 139, 'name': 'Surveillance Camera'},
517
+ {'id': 140, 'name': 'Cat'},
518
+ {'id': 141, 'name': 'Jug'},
519
+ {'id': 142, 'name': 'Broccoli'},
520
+ {'id': 143, 'name': 'Piano'},
521
+ {'id': 144, 'name': 'Pizza'},
522
+ {'id': 145, 'name': 'Elephant'},
523
+ {'id': 146, 'name': 'Skateboard'},
524
+ {'id': 147, 'name': 'Surfboard'},
525
+ {'id': 148, 'name': 'Gun'},
526
+ {'id': 149, 'name': 'Skating and Skiing shoes'},
527
+ {'id': 150, 'name': 'Gas stove'},
528
+ {'id': 151, 'name': 'Donut'},
529
+ {'id': 152, 'name': 'Bow Tie'},
530
+ {'id': 153, 'name': 'Carrot'},
531
+ {'id': 154, 'name': 'Toilet'},
532
+ {'id': 155, 'name': 'Kite'},
533
+ {'id': 156, 'name': 'Strawberry'},
534
+ {'id': 157, 'name': 'Other Balls'},
535
+ {'id': 158, 'name': 'Shovel'},
536
+ {'id': 159, 'name': 'Pepper'},
537
+ {'id': 160, 'name': 'Computer Box'},
538
+ {'id': 161, 'name': 'Toilet Paper'},
539
+ {'id': 162, 'name': 'Cleaning Products'},
540
+ {'id': 163, 'name': 'Chopsticks'},
541
+ {'id': 164, 'name': 'Microwave'},
542
+ {'id': 165, 'name': 'Pigeon'},
543
+ {'id': 166, 'name': 'Baseball'},
544
+ {'id': 167, 'name': 'Cutting/chopping Board'},
545
+ {'id': 168, 'name': 'Coffee Table'},
546
+ {'id': 169, 'name': 'Side Table'},
547
+ {'id': 170, 'name': 'Scissors'},
548
+ {'id': 171, 'name': 'Marker'},
549
+ {'id': 172, 'name': 'Pie'},
550
+ {'id': 173, 'name': 'Ladder'},
551
+ {'id': 174, 'name': 'Snowboard'},
552
+ {'id': 175, 'name': 'Cookies'},
553
+ {'id': 176, 'name': 'Radiator'},
554
+ {'id': 177, 'name': 'Fire Hydrant'},
555
+ {'id': 178, 'name': 'Basketball'},
556
+ {'id': 179, 'name': 'Zebra'},
557
+ {'id': 180, 'name': 'Grape'},
558
+ {'id': 181, 'name': 'Giraffe'},
559
+ {'id': 182, 'name': 'Potato'},
560
+ {'id': 183, 'name': 'Sausage'},
561
+ {'id': 184, 'name': 'Tricycle'},
562
+ {'id': 185, 'name': 'Violin'},
563
+ {'id': 186, 'name': 'Egg'},
564
+ {'id': 187, 'name': 'Fire Extinguisher'},
565
+ {'id': 188, 'name': 'Candy'},
566
+ {'id': 189, 'name': 'Fire Truck'},
567
+ {'id': 190, 'name': 'Billards'},
568
+ {'id': 191, 'name': 'Converter'},
569
+ {'id': 192, 'name': 'Bathtub'},
570
+ {'id': 193, 'name': 'Wheelchair'},
571
+ {'id': 194, 'name': 'Golf Club'},
572
+ {'id': 195, 'name': 'Briefcase'},
573
+ {'id': 196, 'name': 'Cucumber'},
574
+ {'id': 197, 'name': 'Cigar/Cigarette '},
575
+ {'id': 198, 'name': 'Paint Brush'},
576
+ {'id': 199, 'name': 'Pear'},
577
+ {'id': 200, 'name': 'Heavy Truck'},
578
+ {'id': 201, 'name': 'Hamburger'},
579
+ {'id': 202, 'name': 'Extractor'},
580
+ {'id': 203, 'name': 'Extension Cord'},
581
+ {'id': 204, 'name': 'Tong'},
582
+ {'id': 205, 'name': 'Tennis Racket'},
583
+ {'id': 206, 'name': 'Folder'},
584
+ {'id': 207, 'name': 'American Football'},
585
+ {'id': 208, 'name': 'earphone'},
586
+ {'id': 209, 'name': 'Mask'},
587
+ {'id': 210, 'name': 'Kettle'},
588
+ {'id': 211, 'name': 'Tennis'},
589
+ {'id': 212, 'name': 'Ship'},
590
+ {'id': 213, 'name': 'Swing'},
591
+ {'id': 214, 'name': 'Coffee Machine'},
592
+ {'id': 215, 'name': 'Slide'},
593
+ {'id': 216, 'name': 'Carriage'},
594
+ {'id': 217, 'name': 'Onion'},
595
+ {'id': 218, 'name': 'Green beans'},
596
+ {'id': 219, 'name': 'Projector'},
597
+ {'id': 220, 'name': 'Frisbee'},
598
+ {'id': 221, 'name': 'Washing Machine/Drying Machine'},
599
+ {'id': 222, 'name': 'Chicken'},
600
+ {'id': 223, 'name': 'Printer'},
601
+ {'id': 224, 'name': 'Watermelon'},
602
+ {'id': 225, 'name': 'Saxophone'},
603
+ {'id': 226, 'name': 'Tissue'},
604
+ {'id': 227, 'name': 'Toothbrush'},
605
+ {'id': 228, 'name': 'Ice cream'},
606
+ {'id': 229, 'name': 'Hot air balloon'},
607
+ {'id': 230, 'name': 'Cello'},
608
+ {'id': 231, 'name': 'French Fries'},
609
+ {'id': 232, 'name': 'Scale'},
610
+ {'id': 233, 'name': 'Trophy'},
611
+ {'id': 234, 'name': 'Cabbage'},
612
+ {'id': 235, 'name': 'Hot dog'},
613
+ {'id': 236, 'name': 'Blender'},
614
+ {'id': 237, 'name': 'Peach'},
615
+ {'id': 238, 'name': 'Rice'},
616
+ {'id': 239, 'name': 'Wallet/Purse'},
617
+ {'id': 240, 'name': 'Volleyball'},
618
+ {'id': 241, 'name': 'Deer'},
619
+ {'id': 242, 'name': 'Goose'},
620
+ {'id': 243, 'name': 'Tape'},
621
+ {'id': 244, 'name': 'Tablet'},
622
+ {'id': 245, 'name': 'Cosmetics'},
623
+ {'id': 246, 'name': 'Trumpet'},
624
+ {'id': 247, 'name': 'Pineapple'},
625
+ {'id': 248, 'name': 'Golf Ball'},
626
+ {'id': 249, 'name': 'Ambulance'},
627
+ {'id': 250, 'name': 'Parking meter'},
628
+ {'id': 251, 'name': 'Mango'},
629
+ {'id': 252, 'name': 'Key'},
630
+ {'id': 253, 'name': 'Hurdle'},
631
+ {'id': 254, 'name': 'Fishing Rod'},
632
+ {'id': 255, 'name': 'Medal'},
633
+ {'id': 256, 'name': 'Flute'},
634
+ {'id': 257, 'name': 'Brush'},
635
+ {'id': 258, 'name': 'Penguin'},
636
+ {'id': 259, 'name': 'Megaphone'},
637
+ {'id': 260, 'name': 'Corn'},
638
+ {'id': 261, 'name': 'Lettuce'},
639
+ {'id': 262, 'name': 'Garlic'},
640
+ {'id': 263, 'name': 'Swan'},
641
+ {'id': 264, 'name': 'Helicopter'},
642
+ {'id': 265, 'name': 'Green Onion'},
643
+ {'id': 266, 'name': 'Sandwich'},
644
+ {'id': 267, 'name': 'Nuts'},
645
+ {'id': 268, 'name': 'Speed Limit Sign'},
646
+ {'id': 269, 'name': 'Induction Cooker'},
647
+ {'id': 270, 'name': 'Broom'},
648
+ {'id': 271, 'name': 'Trombone'},
649
+ {'id': 272, 'name': 'Plum'},
650
+ {'id': 273, 'name': 'Rickshaw'},
651
+ {'id': 274, 'name': 'Goldfish'},
652
+ {'id': 275, 'name': 'Kiwi fruit'},
653
+ {'id': 276, 'name': 'Router/modem'},
654
+ {'id': 277, 'name': 'Poker Card'},
655
+ {'id': 278, 'name': 'Toaster'},
656
+ {'id': 279, 'name': 'Shrimp'},
657
+ {'id': 280, 'name': 'Sushi'},
658
+ {'id': 281, 'name': 'Cheese'},
659
+ {'id': 282, 'name': 'Notepaper'},
660
+ {'id': 283, 'name': 'Cherry'},
661
+ {'id': 284, 'name': 'Pliers'},
662
+ {'id': 285, 'name': 'CD'},
663
+ {'id': 286, 'name': 'Pasta'},
664
+ {'id': 287, 'name': 'Hammer'},
665
+ {'id': 288, 'name': 'Cue'},
666
+ {'id': 289, 'name': 'Avocado'},
667
+ {'id': 290, 'name': 'Hami melon'},
668
+ {'id': 291, 'name': 'Flask'},
669
+ {'id': 292, 'name': 'Mushroom'},
670
+ {'id': 293, 'name': 'Screwdriver'},
671
+ {'id': 294, 'name': 'Soap'},
672
+ {'id': 295, 'name': 'Recorder'},
673
+ {'id': 296, 'name': 'Bear'},
674
+ {'id': 297, 'name': 'Eggplant'},
675
+ {'id': 298, 'name': 'Board Eraser'},
676
+ {'id': 299, 'name': 'Coconut'},
677
+ {'id': 300, 'name': 'Tape Measure/ Ruler'},
678
+ {'id': 301, 'name': 'Pig'},
679
+ {'id': 302, 'name': 'Showerhead'},
680
+ {'id': 303, 'name': 'Globe'},
681
+ {'id': 304, 'name': 'Chips'},
682
+ {'id': 305, 'name': 'Steak'},
683
+ {'id': 306, 'name': 'Crosswalk Sign'},
684
+ {'id': 307, 'name': 'Stapler'},
685
+ {'id': 308, 'name': 'Camel'},
686
+ {'id': 309, 'name': 'Formula 1 '},
687
+ {'id': 310, 'name': 'Pomegranate'},
688
+ {'id': 311, 'name': 'Dishwasher'},
689
+ {'id': 312, 'name': 'Crab'},
690
+ {'id': 313, 'name': 'Hoverboard'},
691
+ {'id': 314, 'name': 'Meatball'},
692
+ {'id': 315, 'name': 'Rice Cooker'},
693
+ {'id': 316, 'name': 'Tuba'},
694
+ {'id': 317, 'name': 'Calculator'},
695
+ {'id': 318, 'name': 'Papaya'},
696
+ {'id': 319, 'name': 'Antelope'},
697
+ {'id': 320, 'name': 'Parrot'},
698
+ {'id': 321, 'name': 'Seal'},
699
+ {'id': 322, 'name': 'Butterfly'},
700
+ {'id': 323, 'name': 'Dumbbell'},
701
+ {'id': 324, 'name': 'Donkey'},
702
+ {'id': 325, 'name': 'Lion'},
703
+ {'id': 326, 'name': 'Urinal'},
704
+ {'id': 327, 'name': 'Dolphin'},
705
+ {'id': 328, 'name': 'Electric Drill'},
706
+ {'id': 329, 'name': 'Hair Dryer'},
707
+ {'id': 330, 'name': 'Egg tart'},
708
+ {'id': 331, 'name': 'Jellyfish'},
709
+ {'id': 332, 'name': 'Treadmill'},
710
+ {'id': 333, 'name': 'Lighter'},
711
+ {'id': 334, 'name': 'Grapefruit'},
712
+ {'id': 335, 'name': 'Game board'},
713
+ {'id': 336, 'name': 'Mop'},
714
+ {'id': 337, 'name': 'Radish'},
715
+ {'id': 338, 'name': 'Baozi'},
716
+ {'id': 339, 'name': 'Target'},
717
+ {'id': 340, 'name': 'French'},
718
+ {'id': 341, 'name': 'Spring Rolls'},
719
+ {'id': 342, 'name': 'Monkey'},
720
+ {'id': 343, 'name': 'Rabbit'},
721
+ {'id': 344, 'name': 'Pencil Case'},
722
+ {'id': 345, 'name': 'Yak'},
723
+ {'id': 346, 'name': 'Red Cabbage'},
724
+ {'id': 347, 'name': 'Binoculars'},
725
+ {'id': 348, 'name': 'Asparagus'},
726
+ {'id': 349, 'name': 'Barbell'},
727
+ {'id': 350, 'name': 'Scallop'},
728
+ {'id': 351, 'name': 'Noddles'},
729
+ {'id': 352, 'name': 'Comb'},
730
+ {'id': 353, 'name': 'Dumpling'},
731
+ {'id': 354, 'name': 'Oyster'},
732
+ {'id': 355, 'name': 'Table Tennis paddle'},
733
+ {'id': 356, 'name': 'Cosmetics Brush/Eyeliner Pencil'},
734
+ {'id': 357, 'name': 'Chainsaw'},
735
+ {'id': 358, 'name': 'Eraser'},
736
+ {'id': 359, 'name': 'Lobster'},
737
+ {'id': 360, 'name': 'Durian'},
738
+ {'id': 361, 'name': 'Okra'},
739
+ {'id': 362, 'name': 'Lipstick'},
740
+ {'id': 363, 'name': 'Cosmetics Mirror'},
741
+ {'id': 364, 'name': 'Curling'},
742
+ {'id': 365, 'name': 'Table Tennis '},
743
+ ]
744
+
745
+
746
+ def _get_builtin_metadata():
747
+ id_to_name = {x['id']: x['name'] for x in categories_v2_fix}
748
+ thing_dataset_id_to_contiguous_id = {
749
+ x['id']: i for i, x in enumerate(
750
+ sorted(categories_v2_fix, key=lambda x: x['id']))}
751
+ thing_classes = [id_to_name[k] for k in sorted(id_to_name)]
752
+ return {
753
+ "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
754
+ "thing_classes": thing_classes}
755
+
756
+
757
+ _PREDEFINED_SPLITS_OBJECTS365 = {
758
+ "objects365_v2_train": ("objects365/train", "objects365/annotations/zhiyuan_objv2_train_fixname_fixmiss.json"),
759
+ # 80,000 images, 1,240,587 annotations
760
+ "objects365_v2_val": ("objects365/val", "objects365/annotations/zhiyuan_objv2_val_fixname.json"),
761
+ "objects365_v2_val_rare": ("objects365/val", "objects365/annotations/zhiyuan_objv2_val_fixname_rare.json"),
762
+ }
763
+
764
+ for key, (image_root, json_file) in _PREDEFINED_SPLITS_OBJECTS365.items():
765
+ register_coco_instances(
766
+ key,
767
+ _get_builtin_metadata(),
768
+ os.path.join("datasets", json_file) if "://" not in json_file else json_file,
769
+ os.path.join("datasets", image_root),
770
+ )
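Once this module is imported, the three Objects365 v2 splits above are exposed through Detectron2's catalogs. A minimal usage sketch (assuming the `proxydet` package is importable and, for actually loading dicts, that the json/image paths listed above exist under `datasets/`):

    # Sketch: inspect the Objects365 registration performed above.
    from detectron2.data import DatasetCatalog, MetadataCatalog
    import proxydet.data.datasets.objects365  # noqa: F401  -- importing runs the registration loop

    meta = MetadataCatalog.get("objects365_v2_val")
    print(len(meta.thing_classes))                     # 365 class names
    print(meta.thing_dataset_id_to_contiguous_id[67])  # 66, assuming ids run 1..365

    # Loading the actual dicts needs the annotation file on disk:
    # dicts = DatasetCatalog.get("objects365_v2_val")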
proxydet/data/datasets/oid.py ADDED
@@ -0,0 +1,535 @@
1
+ # Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/datasets/oid.py
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ from .register_oid import register_oid_instances
4
+ import os
5
+
6
+ categories = [
7
+ {'id': 1, 'name': 'Infant bed', 'freebase_id': '/m/061hd_'},
8
+ {'id': 2, 'name': 'Rose', 'freebase_id': '/m/06m11'},
9
+ {'id': 3, 'name': 'Flag', 'freebase_id': '/m/03120'},
10
+ {'id': 4, 'name': 'Flashlight', 'freebase_id': '/m/01kb5b'},
11
+ {'id': 5, 'name': 'Sea turtle', 'freebase_id': '/m/0120dh'},
12
+ {'id': 6, 'name': 'Camera', 'freebase_id': '/m/0dv5r'},
13
+ {'id': 7, 'name': 'Animal', 'freebase_id': '/m/0jbk'},
14
+ {'id': 8, 'name': 'Glove', 'freebase_id': '/m/0174n1'},
15
+ {'id': 9, 'name': 'Crocodile', 'freebase_id': '/m/09f_2'},
16
+ {'id': 10, 'name': 'Cattle', 'freebase_id': '/m/01xq0k1'},
17
+ {'id': 11, 'name': 'House', 'freebase_id': '/m/03jm5'},
18
+ {'id': 12, 'name': 'Guacamole', 'freebase_id': '/m/02g30s'},
19
+ {'id': 13, 'name': 'Penguin', 'freebase_id': '/m/05z6w'},
20
+ {'id': 14, 'name': 'Vehicle registration plate', 'freebase_id': '/m/01jfm_'},
21
+ {'id': 15, 'name': 'Bench', 'freebase_id': '/m/076lb9'},
22
+ {'id': 16, 'name': 'Ladybug', 'freebase_id': '/m/0gj37'},
23
+ {'id': 17, 'name': 'Human nose', 'freebase_id': '/m/0k0pj'},
24
+ {'id': 18, 'name': 'Watermelon', 'freebase_id': '/m/0kpqd'},
25
+ {'id': 19, 'name': 'Flute', 'freebase_id': '/m/0l14j_'},
26
+ {'id': 20, 'name': 'Butterfly', 'freebase_id': '/m/0cyf8'},
27
+ {'id': 21, 'name': 'Washing machine', 'freebase_id': '/m/0174k2'},
28
+ {'id': 22, 'name': 'Raccoon', 'freebase_id': '/m/0dq75'},
29
+ {'id': 23, 'name': 'Segway', 'freebase_id': '/m/076bq'},
30
+ {'id': 24, 'name': 'Taco', 'freebase_id': '/m/07crc'},
31
+ {'id': 25, 'name': 'Jellyfish', 'freebase_id': '/m/0d8zb'},
32
+ {'id': 26, 'name': 'Cake', 'freebase_id': '/m/0fszt'},
33
+ {'id': 27, 'name': 'Pen', 'freebase_id': '/m/0k1tl'},
34
+ {'id': 28, 'name': 'Cannon', 'freebase_id': '/m/020kz'},
35
+ {'id': 29, 'name': 'Bread', 'freebase_id': '/m/09728'},
36
+ {'id': 30, 'name': 'Tree', 'freebase_id': '/m/07j7r'},
37
+ {'id': 31, 'name': 'Shellfish', 'freebase_id': '/m/0fbdv'},
38
+ {'id': 32, 'name': 'Bed', 'freebase_id': '/m/03ssj5'},
39
+ {'id': 33, 'name': 'Hamster', 'freebase_id': '/m/03qrc'},
40
+ {'id': 34, 'name': 'Hat', 'freebase_id': '/m/02dl1y'},
41
+ {'id': 35, 'name': 'Toaster', 'freebase_id': '/m/01k6s3'},
42
+ {'id': 36, 'name': 'Sombrero', 'freebase_id': '/m/02jfl0'},
43
+ {'id': 37, 'name': 'Tiara', 'freebase_id': '/m/01krhy'},
44
+ {'id': 38, 'name': 'Bowl', 'freebase_id': '/m/04kkgm'},
45
+ {'id': 39, 'name': 'Dragonfly', 'freebase_id': '/m/0ft9s'},
46
+ {'id': 40, 'name': 'Moths and butterflies', 'freebase_id': '/m/0d_2m'},
47
+ {'id': 41, 'name': 'Antelope', 'freebase_id': '/m/0czz2'},
48
+ {'id': 42, 'name': 'Vegetable', 'freebase_id': '/m/0f4s2w'},
49
+ {'id': 43, 'name': 'Torch', 'freebase_id': '/m/07dd4'},
50
+ {'id': 44, 'name': 'Building', 'freebase_id': '/m/0cgh4'},
51
+ {'id': 45, 'name': 'Power plugs and sockets', 'freebase_id': '/m/03bbps'},
52
+ {'id': 46, 'name': 'Blender', 'freebase_id': '/m/02pjr4'},
53
+ {'id': 47, 'name': 'Billiard table', 'freebase_id': '/m/04p0qw'},
54
+ {'id': 48, 'name': 'Cutting board', 'freebase_id': '/m/02pdsw'},
55
+ {'id': 49, 'name': 'Bronze sculpture', 'freebase_id': '/m/01yx86'},
56
+ {'id': 50, 'name': 'Turtle', 'freebase_id': '/m/09dzg'},
57
+ {'id': 51, 'name': 'Broccoli', 'freebase_id': '/m/0hkxq'},
58
+ {'id': 52, 'name': 'Tiger', 'freebase_id': '/m/07dm6'},
59
+ {'id': 53, 'name': 'Mirror', 'freebase_id': '/m/054_l'},
60
+ {'id': 54, 'name': 'Bear', 'freebase_id': '/m/01dws'},
61
+ {'id': 55, 'name': 'Zucchini', 'freebase_id': '/m/027pcv'},
62
+ {'id': 56, 'name': 'Dress', 'freebase_id': '/m/01d40f'},
63
+ {'id': 57, 'name': 'Volleyball', 'freebase_id': '/m/02rgn06'},
64
+ {'id': 58, 'name': 'Guitar', 'freebase_id': '/m/0342h'},
65
+ {'id': 59, 'name': 'Reptile', 'freebase_id': '/m/06bt6'},
66
+ {'id': 60, 'name': 'Golf cart', 'freebase_id': '/m/0323sq'},
67
+ {'id': 61, 'name': 'Tart', 'freebase_id': '/m/02zvsm'},
68
+ {'id': 62, 'name': 'Fedora', 'freebase_id': '/m/02fq_6'},
69
+ {'id': 63, 'name': 'Carnivore', 'freebase_id': '/m/01lrl'},
70
+ {'id': 64, 'name': 'Car', 'freebase_id': '/m/0k4j'},
71
+ {'id': 65, 'name': 'Lighthouse', 'freebase_id': '/m/04h7h'},
72
+ {'id': 66, 'name': 'Coffeemaker', 'freebase_id': '/m/07xyvk'},
73
+ {'id': 67, 'name': 'Food processor', 'freebase_id': '/m/03y6mg'},
74
+ {'id': 68, 'name': 'Truck', 'freebase_id': '/m/07r04'},
75
+ {'id': 69, 'name': 'Bookcase', 'freebase_id': '/m/03__z0'},
76
+ {'id': 70, 'name': 'Surfboard', 'freebase_id': '/m/019w40'},
77
+ {'id': 71, 'name': 'Footwear', 'freebase_id': '/m/09j5n'},
78
+ {'id': 72, 'name': 'Bench', 'freebase_id': '/m/0cvnqh'},
79
+ {'id': 73, 'name': 'Necklace', 'freebase_id': '/m/01llwg'},
80
+ {'id': 74, 'name': 'Flower', 'freebase_id': '/m/0c9ph5'},
81
+ {'id': 75, 'name': 'Radish', 'freebase_id': '/m/015x5n'},
82
+ {'id': 76, 'name': 'Marine mammal', 'freebase_id': '/m/0gd2v'},
83
+ {'id': 77, 'name': 'Frying pan', 'freebase_id': '/m/04v6l4'},
84
+ {'id': 78, 'name': 'Tap', 'freebase_id': '/m/02jz0l'},
85
+ {'id': 79, 'name': 'Peach', 'freebase_id': '/m/0dj6p'},
86
+ {'id': 80, 'name': 'Knife', 'freebase_id': '/m/04ctx'},
87
+ {'id': 81, 'name': 'Handbag', 'freebase_id': '/m/080hkjn'},
88
+ {'id': 82, 'name': 'Laptop', 'freebase_id': '/m/01c648'},
89
+ {'id': 83, 'name': 'Tent', 'freebase_id': '/m/01j61q'},
90
+ {'id': 84, 'name': 'Ambulance', 'freebase_id': '/m/012n7d'},
91
+ {'id': 85, 'name': 'Christmas tree', 'freebase_id': '/m/025nd'},
92
+ {'id': 86, 'name': 'Eagle', 'freebase_id': '/m/09csl'},
93
+ {'id': 87, 'name': 'Limousine', 'freebase_id': '/m/01lcw4'},
94
+ {'id': 88, 'name': 'Kitchen & dining room table', 'freebase_id': '/m/0h8n5zk'},
95
+ {'id': 89, 'name': 'Polar bear', 'freebase_id': '/m/0633h'},
96
+ {'id': 90, 'name': 'Tower', 'freebase_id': '/m/01fdzj'},
97
+ {'id': 91, 'name': 'Football', 'freebase_id': '/m/01226z'},
98
+ {'id': 92, 'name': 'Willow', 'freebase_id': '/m/0mw_6'},
99
+ {'id': 93, 'name': 'Human head', 'freebase_id': '/m/04hgtk'},
100
+ {'id': 94, 'name': 'Stop sign', 'freebase_id': '/m/02pv19'},
101
+ {'id': 95, 'name': 'Banana', 'freebase_id': '/m/09qck'},
102
+ {'id': 96, 'name': 'Mixer', 'freebase_id': '/m/063rgb'},
103
+ {'id': 97, 'name': 'Binoculars', 'freebase_id': '/m/0lt4_'},
104
+ {'id': 98, 'name': 'Dessert', 'freebase_id': '/m/0270h'},
105
+ {'id': 99, 'name': 'Bee', 'freebase_id': '/m/01h3n'},
106
+ {'id': 100, 'name': 'Chair', 'freebase_id': '/m/01mzpv'},
107
+ {'id': 101, 'name': 'Wood-burning stove', 'freebase_id': '/m/04169hn'},
108
+ {'id': 102, 'name': 'Flowerpot', 'freebase_id': '/m/0fm3zh'},
109
+ {'id': 103, 'name': 'Beaker', 'freebase_id': '/m/0d20w4'},
110
+ {'id': 104, 'name': 'Oyster', 'freebase_id': '/m/0_cp5'},
111
+ {'id': 105, 'name': 'Woodpecker', 'freebase_id': '/m/01dy8n'},
112
+ {'id': 106, 'name': 'Harp', 'freebase_id': '/m/03m5k'},
113
+ {'id': 107, 'name': 'Bathtub', 'freebase_id': '/m/03dnzn'},
114
+ {'id': 108, 'name': 'Wall clock', 'freebase_id': '/m/0h8mzrc'},
115
+ {'id': 109, 'name': 'Sports uniform', 'freebase_id': '/m/0h8mhzd'},
116
+ {'id': 110, 'name': 'Rhinoceros', 'freebase_id': '/m/03d443'},
117
+ {'id': 111, 'name': 'Beehive', 'freebase_id': '/m/01gllr'},
118
+ {'id': 112, 'name': 'Cupboard', 'freebase_id': '/m/0642b4'},
119
+ {'id': 113, 'name': 'Chicken', 'freebase_id': '/m/09b5t'},
120
+ {'id': 114, 'name': 'Man', 'freebase_id': '/m/04yx4'},
121
+ {'id': 115, 'name': 'Blue jay', 'freebase_id': '/m/01f8m5'},
122
+ {'id': 116, 'name': 'Cucumber', 'freebase_id': '/m/015x4r'},
123
+ {'id': 117, 'name': 'Balloon', 'freebase_id': '/m/01j51'},
124
+ {'id': 118, 'name': 'Kite', 'freebase_id': '/m/02zt3'},
125
+ {'id': 119, 'name': 'Fireplace', 'freebase_id': '/m/03tw93'},
126
+ {'id': 120, 'name': 'Lantern', 'freebase_id': '/m/01jfsr'},
127
+ {'id': 121, 'name': 'Missile', 'freebase_id': '/m/04ylt'},
128
+ {'id': 122, 'name': 'Book', 'freebase_id': '/m/0bt_c3'},
129
+ {'id': 123, 'name': 'Spoon', 'freebase_id': '/m/0cmx8'},
130
+ {'id': 124, 'name': 'Grapefruit', 'freebase_id': '/m/0hqkz'},
131
+ {'id': 125, 'name': 'Squirrel', 'freebase_id': '/m/071qp'},
132
+ {'id': 126, 'name': 'Orange', 'freebase_id': '/m/0cyhj_'},
133
+ {'id': 127, 'name': 'Coat', 'freebase_id': '/m/01xygc'},
134
+ {'id': 128, 'name': 'Punching bag', 'freebase_id': '/m/0420v5'},
135
+ {'id': 129, 'name': 'Zebra', 'freebase_id': '/m/0898b'},
136
+ {'id': 130, 'name': 'Billboard', 'freebase_id': '/m/01knjb'},
137
+ {'id': 131, 'name': 'Bicycle', 'freebase_id': '/m/0199g'},
138
+ {'id': 132, 'name': 'Door handle', 'freebase_id': '/m/03c7gz'},
139
+ {'id': 133, 'name': 'Mechanical fan', 'freebase_id': '/m/02x984l'},
140
+ {'id': 134, 'name': 'Ring binder', 'freebase_id': '/m/04zwwv'},
141
+ {'id': 135, 'name': 'Table', 'freebase_id': '/m/04bcr3'},
142
+ {'id': 136, 'name': 'Parrot', 'freebase_id': '/m/0gv1x'},
143
+ {'id': 137, 'name': 'Sock', 'freebase_id': '/m/01nq26'},
144
+ {'id': 138, 'name': 'Vase', 'freebase_id': '/m/02s195'},
145
+ {'id': 139, 'name': 'Weapon', 'freebase_id': '/m/083kb'},
146
+ {'id': 140, 'name': 'Shotgun', 'freebase_id': '/m/06nrc'},
147
+ {'id': 141, 'name': 'Glasses', 'freebase_id': '/m/0jyfg'},
148
+ {'id': 142, 'name': 'Seahorse', 'freebase_id': '/m/0nybt'},
149
+ {'id': 143, 'name': 'Belt', 'freebase_id': '/m/0176mf'},
150
+ {'id': 144, 'name': 'Watercraft', 'freebase_id': '/m/01rzcn'},
151
+ {'id': 145, 'name': 'Window', 'freebase_id': '/m/0d4v4'},
152
+ {'id': 146, 'name': 'Giraffe', 'freebase_id': '/m/03bk1'},
153
+ {'id': 147, 'name': 'Lion', 'freebase_id': '/m/096mb'},
154
+ {'id': 148, 'name': 'Tire', 'freebase_id': '/m/0h9mv'},
155
+ {'id': 149, 'name': 'Vehicle', 'freebase_id': '/m/07yv9'},
156
+ {'id': 150, 'name': 'Canoe', 'freebase_id': '/m/0ph39'},
157
+ {'id': 151, 'name': 'Tie', 'freebase_id': '/m/01rkbr'},
158
+ {'id': 152, 'name': 'Shelf', 'freebase_id': '/m/0gjbg72'},
159
+ {'id': 153, 'name': 'Picture frame', 'freebase_id': '/m/06z37_'},
160
+ {'id': 154, 'name': 'Printer', 'freebase_id': '/m/01m4t'},
161
+ {'id': 155, 'name': 'Human leg', 'freebase_id': '/m/035r7c'},
162
+ {'id': 156, 'name': 'Boat', 'freebase_id': '/m/019jd'},
163
+ {'id': 157, 'name': 'Slow cooker', 'freebase_id': '/m/02tsc9'},
164
+ {'id': 158, 'name': 'Croissant', 'freebase_id': '/m/015wgc'},
165
+ {'id': 159, 'name': 'Candle', 'freebase_id': '/m/0c06p'},
166
+ {'id': 160, 'name': 'Pancake', 'freebase_id': '/m/01dwwc'},
167
+ {'id': 161, 'name': 'Pillow', 'freebase_id': '/m/034c16'},
168
+ {'id': 162, 'name': 'Coin', 'freebase_id': '/m/0242l'},
169
+ {'id': 163, 'name': 'Stretcher', 'freebase_id': '/m/02lbcq'},
170
+ {'id': 164, 'name': 'Sandal', 'freebase_id': '/m/03nfch'},
171
+ {'id': 165, 'name': 'Woman', 'freebase_id': '/m/03bt1vf'},
172
+ {'id': 166, 'name': 'Stairs', 'freebase_id': '/m/01lynh'},
173
+ {'id': 167, 'name': 'Harpsichord', 'freebase_id': '/m/03q5t'},
174
+ {'id': 168, 'name': 'Stool', 'freebase_id': '/m/0fqt361'},
175
+ {'id': 169, 'name': 'Bus', 'freebase_id': '/m/01bjv'},
176
+ {'id': 170, 'name': 'Suitcase', 'freebase_id': '/m/01s55n'},
177
+ {'id': 171, 'name': 'Human mouth', 'freebase_id': '/m/0283dt1'},
178
+ {'id': 172, 'name': 'Juice', 'freebase_id': '/m/01z1kdw'},
179
+ {'id': 173, 'name': 'Skull', 'freebase_id': '/m/016m2d'},
180
+ {'id': 174, 'name': 'Door', 'freebase_id': '/m/02dgv'},
181
+ {'id': 175, 'name': 'Violin', 'freebase_id': '/m/07y_7'},
182
+ {'id': 176, 'name': 'Chopsticks', 'freebase_id': '/m/01_5g'},
183
+ {'id': 177, 'name': 'Digital clock', 'freebase_id': '/m/06_72j'},
184
+ {'id': 178, 'name': 'Sunflower', 'freebase_id': '/m/0ftb8'},
185
+ {'id': 179, 'name': 'Leopard', 'freebase_id': '/m/0c29q'},
186
+ {'id': 180, 'name': 'Bell pepper', 'freebase_id': '/m/0jg57'},
187
+ {'id': 181, 'name': 'Harbor seal', 'freebase_id': '/m/02l8p9'},
188
+ {'id': 182, 'name': 'Snake', 'freebase_id': '/m/078jl'},
189
+ {'id': 183, 'name': 'Sewing machine', 'freebase_id': '/m/0llzx'},
190
+ {'id': 184, 'name': 'Goose', 'freebase_id': '/m/0dbvp'},
191
+ {'id': 185, 'name': 'Helicopter', 'freebase_id': '/m/09ct_'},
192
+ {'id': 186, 'name': 'Seat belt', 'freebase_id': '/m/0dkzw'},
193
+ {'id': 187, 'name': 'Coffee cup', 'freebase_id': '/m/02p5f1q'},
194
+ {'id': 188, 'name': 'Microwave oven', 'freebase_id': '/m/0fx9l'},
195
+ {'id': 189, 'name': 'Hot dog', 'freebase_id': '/m/01b9xk'},
196
+ {'id': 190, 'name': 'Countertop', 'freebase_id': '/m/0b3fp9'},
197
+ {'id': 191, 'name': 'Serving tray', 'freebase_id': '/m/0h8n27j'},
198
+ {'id': 192, 'name': 'Dog bed', 'freebase_id': '/m/0h8n6f9'},
199
+ {'id': 193, 'name': 'Beer', 'freebase_id': '/m/01599'},
200
+ {'id': 194, 'name': 'Sunglasses', 'freebase_id': '/m/017ftj'},
201
+ {'id': 195, 'name': 'Golf ball', 'freebase_id': '/m/044r5d'},
202
+ {'id': 196, 'name': 'Waffle', 'freebase_id': '/m/01dwsz'},
203
+ {'id': 197, 'name': 'Palm tree', 'freebase_id': '/m/0cdl1'},
204
+ {'id': 198, 'name': 'Trumpet', 'freebase_id': '/m/07gql'},
205
+ {'id': 199, 'name': 'Ruler', 'freebase_id': '/m/0hdln'},
206
+ {'id': 200, 'name': 'Helmet', 'freebase_id': '/m/0zvk5'},
207
+ {'id': 201, 'name': 'Ladder', 'freebase_id': '/m/012w5l'},
208
+ {'id': 202, 'name': 'Office building', 'freebase_id': '/m/021sj1'},
209
+ {'id': 203, 'name': 'Tablet computer', 'freebase_id': '/m/0bh9flk'},
210
+ {'id': 204, 'name': 'Toilet paper', 'freebase_id': '/m/09gtd'},
211
+ {'id': 205, 'name': 'Pomegranate', 'freebase_id': '/m/0jwn_'},
212
+ {'id': 206, 'name': 'Skirt', 'freebase_id': '/m/02wv6h6'},
213
+ {'id': 207, 'name': 'Gas stove', 'freebase_id': '/m/02wv84t'},
214
+ {'id': 208, 'name': 'Cookie', 'freebase_id': '/m/021mn'},
215
+ {'id': 209, 'name': 'Cart', 'freebase_id': '/m/018p4k'},
216
+ {'id': 210, 'name': 'Raven', 'freebase_id': '/m/06j2d'},
217
+ {'id': 211, 'name': 'Egg', 'freebase_id': '/m/033cnk'},
218
+ {'id': 212, 'name': 'Burrito', 'freebase_id': '/m/01j3zr'},
219
+ {'id': 213, 'name': 'Goat', 'freebase_id': '/m/03fwl'},
220
+ {'id': 214, 'name': 'Kitchen knife', 'freebase_id': '/m/058qzx'},
221
+ {'id': 215, 'name': 'Skateboard', 'freebase_id': '/m/06_fw'},
222
+ {'id': 216, 'name': 'Salt and pepper shakers', 'freebase_id': '/m/02x8cch'},
223
+ {'id': 217, 'name': 'Lynx', 'freebase_id': '/m/04g2r'},
224
+ {'id': 218, 'name': 'Boot', 'freebase_id': '/m/01b638'},
225
+ {'id': 219, 'name': 'Platter', 'freebase_id': '/m/099ssp'},
226
+ {'id': 220, 'name': 'Ski', 'freebase_id': '/m/071p9'},
227
+ {'id': 221, 'name': 'Swimwear', 'freebase_id': '/m/01gkx_'},
228
+ {'id': 222, 'name': 'Swimming pool', 'freebase_id': '/m/0b_rs'},
229
+ {'id': 223, 'name': 'Drinking straw', 'freebase_id': '/m/03v5tg'},
230
+ {'id': 224, 'name': 'Wrench', 'freebase_id': '/m/01j5ks'},
231
+ {'id': 225, 'name': 'Drum', 'freebase_id': '/m/026t6'},
232
+ {'id': 226, 'name': 'Ant', 'freebase_id': '/m/0_k2'},
233
+ {'id': 227, 'name': 'Human ear', 'freebase_id': '/m/039xj_'},
234
+ {'id': 228, 'name': 'Headphones', 'freebase_id': '/m/01b7fy'},
235
+ {'id': 229, 'name': 'Fountain', 'freebase_id': '/m/0220r2'},
236
+ {'id': 230, 'name': 'Bird', 'freebase_id': '/m/015p6'},
237
+ {'id': 231, 'name': 'Jeans', 'freebase_id': '/m/0fly7'},
238
+ {'id': 232, 'name': 'Television', 'freebase_id': '/m/07c52'},
239
+ {'id': 233, 'name': 'Crab', 'freebase_id': '/m/0n28_'},
240
+ {'id': 234, 'name': 'Microphone', 'freebase_id': '/m/0hg7b'},
241
+ {'id': 235, 'name': 'Home appliance', 'freebase_id': '/m/019dx1'},
242
+ {'id': 236, 'name': 'Snowplow', 'freebase_id': '/m/04vv5k'},
243
+ {'id': 237, 'name': 'Beetle', 'freebase_id': '/m/020jm'},
244
+ {'id': 238, 'name': 'Artichoke', 'freebase_id': '/m/047v4b'},
245
+ {'id': 239, 'name': 'Jet ski', 'freebase_id': '/m/01xs3r'},
246
+ {'id': 240, 'name': 'Stationary bicycle', 'freebase_id': '/m/03kt2w'},
247
+ {'id': 241, 'name': 'Human hair', 'freebase_id': '/m/03q69'},
248
+ {'id': 242, 'name': 'Brown bear', 'freebase_id': '/m/01dxs'},
249
+ {'id': 243, 'name': 'Starfish', 'freebase_id': '/m/01h8tj'},
250
+ {'id': 244, 'name': 'Fork', 'freebase_id': '/m/0dt3t'},
251
+ {'id': 245, 'name': 'Lobster', 'freebase_id': '/m/0cjq5'},
252
+ {'id': 246, 'name': 'Corded phone', 'freebase_id': '/m/0h8lkj8'},
253
+ {'id': 247, 'name': 'Drink', 'freebase_id': '/m/0271t'},
254
+ {'id': 248, 'name': 'Saucer', 'freebase_id': '/m/03q5c7'},
255
+ {'id': 249, 'name': 'Carrot', 'freebase_id': '/m/0fj52s'},
256
+ {'id': 250, 'name': 'Insect', 'freebase_id': '/m/03vt0'},
257
+ {'id': 251, 'name': 'Clock', 'freebase_id': '/m/01x3z'},
258
+ {'id': 252, 'name': 'Castle', 'freebase_id': '/m/0d5gx'},
259
+ {'id': 253, 'name': 'Tennis racket', 'freebase_id': '/m/0h8my_4'},
260
+ {'id': 254, 'name': 'Ceiling fan', 'freebase_id': '/m/03ldnb'},
261
+ {'id': 255, 'name': 'Asparagus', 'freebase_id': '/m/0cjs7'},
262
+ {'id': 256, 'name': 'Jaguar', 'freebase_id': '/m/0449p'},
263
+ {'id': 257, 'name': 'Musical instrument', 'freebase_id': '/m/04szw'},
264
+ {'id': 258, 'name': 'Train', 'freebase_id': '/m/07jdr'},
265
+ {'id': 259, 'name': 'Cat', 'freebase_id': '/m/01yrx'},
266
+ {'id': 260, 'name': 'Rifle', 'freebase_id': '/m/06c54'},
267
+ {'id': 261, 'name': 'Dumbbell', 'freebase_id': '/m/04h8sr'},
268
+ {'id': 262, 'name': 'Mobile phone', 'freebase_id': '/m/050k8'},
269
+ {'id': 263, 'name': 'Taxi', 'freebase_id': '/m/0pg52'},
270
+ {'id': 264, 'name': 'Shower', 'freebase_id': '/m/02f9f_'},
271
+ {'id': 265, 'name': 'Pitcher', 'freebase_id': '/m/054fyh'},
272
+ {'id': 266, 'name': 'Lemon', 'freebase_id': '/m/09k_b'},
273
+ {'id': 267, 'name': 'Invertebrate', 'freebase_id': '/m/03xxp'},
274
+ {'id': 268, 'name': 'Turkey', 'freebase_id': '/m/0jly1'},
275
+ {'id': 269, 'name': 'High heels', 'freebase_id': '/m/06k2mb'},
276
+ {'id': 270, 'name': 'Bust', 'freebase_id': '/m/04yqq2'},
277
+ {'id': 271, 'name': 'Elephant', 'freebase_id': '/m/0bwd_0j'},
278
+ {'id': 272, 'name': 'Scarf', 'freebase_id': '/m/02h19r'},
279
+ {'id': 273, 'name': 'Barrel', 'freebase_id': '/m/02zn6n'},
280
+ {'id': 274, 'name': 'Trombone', 'freebase_id': '/m/07c6l'},
281
+ {'id': 275, 'name': 'Pumpkin', 'freebase_id': '/m/05zsy'},
282
+ {'id': 276, 'name': 'Box', 'freebase_id': '/m/025dyy'},
283
+ {'id': 277, 'name': 'Tomato', 'freebase_id': '/m/07j87'},
284
+ {'id': 278, 'name': 'Frog', 'freebase_id': '/m/09ld4'},
285
+ {'id': 279, 'name': 'Bidet', 'freebase_id': '/m/01vbnl'},
286
+ {'id': 280, 'name': 'Human face', 'freebase_id': '/m/0dzct'},
287
+ {'id': 281, 'name': 'Houseplant', 'freebase_id': '/m/03fp41'},
288
+ {'id': 282, 'name': 'Van', 'freebase_id': '/m/0h2r6'},
289
+ {'id': 283, 'name': 'Shark', 'freebase_id': '/m/0by6g'},
290
+ {'id': 284, 'name': 'Ice cream', 'freebase_id': '/m/0cxn2'},
291
+ {'id': 285, 'name': 'Swim cap', 'freebase_id': '/m/04tn4x'},
292
+ {'id': 286, 'name': 'Falcon', 'freebase_id': '/m/0f6wt'},
293
+ {'id': 287, 'name': 'Ostrich', 'freebase_id': '/m/05n4y'},
294
+ {'id': 288, 'name': 'Handgun', 'freebase_id': '/m/0gxl3'},
295
+ {'id': 289, 'name': 'Whiteboard', 'freebase_id': '/m/02d9qx'},
296
+ {'id': 290, 'name': 'Lizard', 'freebase_id': '/m/04m9y'},
297
+ {'id': 291, 'name': 'Pasta', 'freebase_id': '/m/05z55'},
298
+ {'id': 292, 'name': 'Snowmobile', 'freebase_id': '/m/01x3jk'},
299
+ {'id': 293, 'name': 'Light bulb', 'freebase_id': '/m/0h8l4fh'},
300
+ {'id': 294, 'name': 'Window blind', 'freebase_id': '/m/031b6r'},
301
+ {'id': 295, 'name': 'Muffin', 'freebase_id': '/m/01tcjp'},
302
+ {'id': 296, 'name': 'Pretzel', 'freebase_id': '/m/01f91_'},
303
+ {'id': 297, 'name': 'Computer monitor', 'freebase_id': '/m/02522'},
304
+ {'id': 298, 'name': 'Horn', 'freebase_id': '/m/0319l'},
305
+ {'id': 299, 'name': 'Furniture', 'freebase_id': '/m/0c_jw'},
306
+ {'id': 300, 'name': 'Sandwich', 'freebase_id': '/m/0l515'},
307
+ {'id': 301, 'name': 'Fox', 'freebase_id': '/m/0306r'},
308
+ {'id': 302, 'name': 'Convenience store', 'freebase_id': '/m/0crjs'},
309
+ {'id': 303, 'name': 'Fish', 'freebase_id': '/m/0ch_cf'},
310
+ {'id': 304, 'name': 'Fruit', 'freebase_id': '/m/02xwb'},
311
+ {'id': 305, 'name': 'Earrings', 'freebase_id': '/m/01r546'},
312
+ {'id': 306, 'name': 'Curtain', 'freebase_id': '/m/03rszm'},
313
+ {'id': 307, 'name': 'Grape', 'freebase_id': '/m/0388q'},
314
+ {'id': 308, 'name': 'Sofa bed', 'freebase_id': '/m/03m3pdh'},
315
+ {'id': 309, 'name': 'Horse', 'freebase_id': '/m/03k3r'},
316
+ {'id': 310, 'name': 'Luggage and bags', 'freebase_id': '/m/0hf58v5'},
317
+ {'id': 311, 'name': 'Desk', 'freebase_id': '/m/01y9k5'},
318
+ {'id': 312, 'name': 'Crutch', 'freebase_id': '/m/05441v'},
319
+ {'id': 313, 'name': 'Bicycle helmet', 'freebase_id': '/m/03p3bw'},
320
+ {'id': 314, 'name': 'Tick', 'freebase_id': '/m/0175cv'},
321
+ {'id': 315, 'name': 'Airplane', 'freebase_id': '/m/0cmf2'},
322
+ {'id': 316, 'name': 'Canary', 'freebase_id': '/m/0ccs93'},
323
+ {'id': 317, 'name': 'Spatula', 'freebase_id': '/m/02d1br'},
324
+ {'id': 318, 'name': 'Watch', 'freebase_id': '/m/0gjkl'},
325
+ {'id': 319, 'name': 'Lily', 'freebase_id': '/m/0jqgx'},
326
+ {'id': 320, 'name': 'Kitchen appliance', 'freebase_id': '/m/0h99cwc'},
327
+ {'id': 321, 'name': 'Filing cabinet', 'freebase_id': '/m/047j0r'},
328
+ {'id': 322, 'name': 'Aircraft', 'freebase_id': '/m/0k5j'},
329
+ {'id': 323, 'name': 'Cake stand', 'freebase_id': '/m/0h8n6ft'},
330
+ {'id': 324, 'name': 'Candy', 'freebase_id': '/m/0gm28'},
331
+ {'id': 325, 'name': 'Sink', 'freebase_id': '/m/0130jx'},
332
+ {'id': 326, 'name': 'Mouse', 'freebase_id': '/m/04rmv'},
333
+ {'id': 327, 'name': 'Wine', 'freebase_id': '/m/081qc'},
334
+ {'id': 328, 'name': 'Wheelchair', 'freebase_id': '/m/0qmmr'},
335
+ {'id': 329, 'name': 'Goldfish', 'freebase_id': '/m/03fj2'},
336
+ {'id': 330, 'name': 'Refrigerator', 'freebase_id': '/m/040b_t'},
337
+ {'id': 331, 'name': 'French fries', 'freebase_id': '/m/02y6n'},
338
+ {'id': 332, 'name': 'Drawer', 'freebase_id': '/m/0fqfqc'},
339
+ {'id': 333, 'name': 'Treadmill', 'freebase_id': '/m/030610'},
340
+ {'id': 334, 'name': 'Picnic basket', 'freebase_id': '/m/07kng9'},
341
+ {'id': 335, 'name': 'Dice', 'freebase_id': '/m/029b3'},
342
+ {'id': 336, 'name': 'Cabbage', 'freebase_id': '/m/0fbw6'},
343
+ {'id': 337, 'name': 'Football helmet', 'freebase_id': '/m/07qxg_'},
344
+ {'id': 338, 'name': 'Pig', 'freebase_id': '/m/068zj'},
345
+ {'id': 339, 'name': 'Person', 'freebase_id': '/m/01g317'},
346
+ {'id': 340, 'name': 'Shorts', 'freebase_id': '/m/01bfm9'},
347
+ {'id': 341, 'name': 'Gondola', 'freebase_id': '/m/02068x'},
348
+ {'id': 342, 'name': 'Honeycomb', 'freebase_id': '/m/0fz0h'},
349
+ {'id': 343, 'name': 'Doughnut', 'freebase_id': '/m/0jy4k'},
350
+ {'id': 344, 'name': 'Chest of drawers', 'freebase_id': '/m/05kyg_'},
351
+ {'id': 345, 'name': 'Land vehicle', 'freebase_id': '/m/01prls'},
352
+ {'id': 346, 'name': 'Bat', 'freebase_id': '/m/01h44'},
353
+ {'id': 347, 'name': 'Monkey', 'freebase_id': '/m/08pbxl'},
354
+ {'id': 348, 'name': 'Dagger', 'freebase_id': '/m/02gzp'},
355
+ {'id': 349, 'name': 'Tableware', 'freebase_id': '/m/04brg2'},
356
+ {'id': 350, 'name': 'Human foot', 'freebase_id': '/m/031n1'},
357
+ {'id': 351, 'name': 'Mug', 'freebase_id': '/m/02jvh9'},
358
+ {'id': 352, 'name': 'Alarm clock', 'freebase_id': '/m/046dlr'},
359
+ {'id': 353, 'name': 'Pressure cooker', 'freebase_id': '/m/0h8ntjv'},
360
+ {'id': 354, 'name': 'Human hand', 'freebase_id': '/m/0k65p'},
361
+ {'id': 355, 'name': 'Tortoise', 'freebase_id': '/m/011k07'},
362
+ {'id': 356, 'name': 'Baseball glove', 'freebase_id': '/m/03grzl'},
363
+ {'id': 357, 'name': 'Sword', 'freebase_id': '/m/06y5r'},
364
+ {'id': 358, 'name': 'Pear', 'freebase_id': '/m/061_f'},
365
+ {'id': 359, 'name': 'Miniskirt', 'freebase_id': '/m/01cmb2'},
366
+ {'id': 360, 'name': 'Traffic sign', 'freebase_id': '/m/01mqdt'},
367
+ {'id': 361, 'name': 'Girl', 'freebase_id': '/m/05r655'},
368
+ {'id': 362, 'name': 'Roller skates', 'freebase_id': '/m/02p3w7d'},
369
+ {'id': 363, 'name': 'Dinosaur', 'freebase_id': '/m/029tx'},
370
+ {'id': 364, 'name': 'Porch', 'freebase_id': '/m/04m6gz'},
371
+ {'id': 365, 'name': 'Human beard', 'freebase_id': '/m/015h_t'},
372
+ {'id': 366, 'name': 'Submarine sandwich', 'freebase_id': '/m/06pcq'},
373
+ {'id': 367, 'name': 'Screwdriver', 'freebase_id': '/m/01bms0'},
374
+ {'id': 368, 'name': 'Strawberry', 'freebase_id': '/m/07fbm7'},
375
+ {'id': 369, 'name': 'Wine glass', 'freebase_id': '/m/09tvcd'},
376
+ {'id': 370, 'name': 'Seafood', 'freebase_id': '/m/06nwz'},
377
+ {'id': 371, 'name': 'Racket', 'freebase_id': '/m/0dv9c'},
378
+ {'id': 372, 'name': 'Wheel', 'freebase_id': '/m/083wq'},
379
+ {'id': 373, 'name': 'Sea lion', 'freebase_id': '/m/0gd36'},
380
+ {'id': 374, 'name': 'Toy', 'freebase_id': '/m/0138tl'},
381
+ {'id': 375, 'name': 'Tea', 'freebase_id': '/m/07clx'},
382
+ {'id': 376, 'name': 'Tennis ball', 'freebase_id': '/m/05ctyq'},
383
+ {'id': 377, 'name': 'Waste container', 'freebase_id': '/m/0bjyj5'},
384
+ {'id': 378, 'name': 'Mule', 'freebase_id': '/m/0dbzx'},
385
+ {'id': 379, 'name': 'Cricket ball', 'freebase_id': '/m/02ctlc'},
386
+ {'id': 380, 'name': 'Pineapple', 'freebase_id': '/m/0fp6w'},
387
+ {'id': 381, 'name': 'Coconut', 'freebase_id': '/m/0djtd'},
388
+ {'id': 382, 'name': 'Doll', 'freebase_id': '/m/0167gd'},
389
+ {'id': 383, 'name': 'Coffee table', 'freebase_id': '/m/078n6m'},
390
+ {'id': 384, 'name': 'Snowman', 'freebase_id': '/m/0152hh'},
391
+ {'id': 385, 'name': 'Lavender', 'freebase_id': '/m/04gth'},
392
+ {'id': 386, 'name': 'Shrimp', 'freebase_id': '/m/0ll1f78'},
393
+ {'id': 387, 'name': 'Maple', 'freebase_id': '/m/0cffdh'},
394
+ {'id': 388, 'name': 'Cowboy hat', 'freebase_id': '/m/025rp__'},
395
+ {'id': 389, 'name': 'Goggles', 'freebase_id': '/m/02_n6y'},
396
+ {'id': 390, 'name': 'Rugby ball', 'freebase_id': '/m/0wdt60w'},
397
+ {'id': 391, 'name': 'Caterpillar', 'freebase_id': '/m/0cydv'},
398
+ {'id': 392, 'name': 'Poster', 'freebase_id': '/m/01n5jq'},
399
+ {'id': 393, 'name': 'Rocket', 'freebase_id': '/m/09rvcxw'},
400
+ {'id': 394, 'name': 'Organ', 'freebase_id': '/m/013y1f'},
401
+ {'id': 395, 'name': 'Saxophone', 'freebase_id': '/m/06ncr'},
402
+ {'id': 396, 'name': 'Traffic light', 'freebase_id': '/m/015qff'},
403
+ {'id': 397, 'name': 'Cocktail', 'freebase_id': '/m/024g6'},
404
+ {'id': 398, 'name': 'Plastic bag', 'freebase_id': '/m/05gqfk'},
405
+ {'id': 399, 'name': 'Squash', 'freebase_id': '/m/0dv77'},
406
+ {'id': 400, 'name': 'Mushroom', 'freebase_id': '/m/052sf'},
407
+ {'id': 401, 'name': 'Hamburger', 'freebase_id': '/m/0cdn1'},
408
+ {'id': 402, 'name': 'Light switch', 'freebase_id': '/m/03jbxj'},
409
+ {'id': 403, 'name': 'Parachute', 'freebase_id': '/m/0cyfs'},
410
+ {'id': 404, 'name': 'Teddy bear', 'freebase_id': '/m/0kmg4'},
411
+ {'id': 405, 'name': 'Winter melon', 'freebase_id': '/m/02cvgx'},
412
+ {'id': 406, 'name': 'Deer', 'freebase_id': '/m/09kx5'},
413
+ {'id': 407, 'name': 'Musical keyboard', 'freebase_id': '/m/057cc'},
414
+ {'id': 408, 'name': 'Plumbing fixture', 'freebase_id': '/m/02pkr5'},
415
+ {'id': 409, 'name': 'Scoreboard', 'freebase_id': '/m/057p5t'},
416
+ {'id': 410, 'name': 'Baseball bat', 'freebase_id': '/m/03g8mr'},
417
+ {'id': 411, 'name': 'Envelope', 'freebase_id': '/m/0frqm'},
418
+ {'id': 412, 'name': 'Adhesive tape', 'freebase_id': '/m/03m3vtv'},
419
+ {'id': 413, 'name': 'Briefcase', 'freebase_id': '/m/0584n8'},
420
+ {'id': 414, 'name': 'Paddle', 'freebase_id': '/m/014y4n'},
421
+ {'id': 415, 'name': 'Bow and arrow', 'freebase_id': '/m/01g3x7'},
422
+ {'id': 416, 'name': 'Telephone', 'freebase_id': '/m/07cx4'},
423
+ {'id': 417, 'name': 'Sheep', 'freebase_id': '/m/07bgp'},
424
+ {'id': 418, 'name': 'Jacket', 'freebase_id': '/m/032b3c'},
425
+ {'id': 419, 'name': 'Boy', 'freebase_id': '/m/01bl7v'},
426
+ {'id': 420, 'name': 'Pizza', 'freebase_id': '/m/0663v'},
427
+ {'id': 421, 'name': 'Otter', 'freebase_id': '/m/0cn6p'},
428
+ {'id': 422, 'name': 'Office supplies', 'freebase_id': '/m/02rdsp'},
429
+ {'id': 423, 'name': 'Couch', 'freebase_id': '/m/02crq1'},
430
+ {'id': 424, 'name': 'Cello', 'freebase_id': '/m/01xqw'},
431
+ {'id': 425, 'name': 'Bull', 'freebase_id': '/m/0cnyhnx'},
432
+ {'id': 426, 'name': 'Camel', 'freebase_id': '/m/01x_v'},
433
+ {'id': 427, 'name': 'Ball', 'freebase_id': '/m/018xm'},
434
+ {'id': 428, 'name': 'Duck', 'freebase_id': '/m/09ddx'},
435
+ {'id': 429, 'name': 'Whale', 'freebase_id': '/m/084zz'},
436
+ {'id': 430, 'name': 'Shirt', 'freebase_id': '/m/01n4qj'},
437
+ {'id': 431, 'name': 'Tank', 'freebase_id': '/m/07cmd'},
438
+ {'id': 432, 'name': 'Motorcycle', 'freebase_id': '/m/04_sv'},
439
+ {'id': 433, 'name': 'Accordion', 'freebase_id': '/m/0mkg'},
440
+ {'id': 434, 'name': 'Owl', 'freebase_id': '/m/09d5_'},
441
+ {'id': 435, 'name': 'Porcupine', 'freebase_id': '/m/0c568'},
442
+ {'id': 436, 'name': 'Sun hat', 'freebase_id': '/m/02wbtzl'},
443
+ {'id': 437, 'name': 'Nail', 'freebase_id': '/m/05bm6'},
444
+ {'id': 438, 'name': 'Scissors', 'freebase_id': '/m/01lsmm'},
445
+ {'id': 439, 'name': 'Swan', 'freebase_id': '/m/0dftk'},
446
+ {'id': 440, 'name': 'Lamp', 'freebase_id': '/m/0dtln'},
447
+ {'id': 441, 'name': 'Crown', 'freebase_id': '/m/0nl46'},
448
+ {'id': 442, 'name': 'Piano', 'freebase_id': '/m/05r5c'},
449
+ {'id': 443, 'name': 'Sculpture', 'freebase_id': '/m/06msq'},
450
+ {'id': 444, 'name': 'Cheetah', 'freebase_id': '/m/0cd4d'},
451
+ {'id': 445, 'name': 'Oboe', 'freebase_id': '/m/05kms'},
452
+ {'id': 446, 'name': 'Tin can', 'freebase_id': '/m/02jnhm'},
453
+ {'id': 447, 'name': 'Mango', 'freebase_id': '/m/0fldg'},
454
+ {'id': 448, 'name': 'Tripod', 'freebase_id': '/m/073bxn'},
455
+ {'id': 449, 'name': 'Oven', 'freebase_id': '/m/029bxz'},
456
+ {'id': 450, 'name': 'Mouse', 'freebase_id': '/m/020lf'},
457
+ {'id': 451, 'name': 'Barge', 'freebase_id': '/m/01btn'},
458
+ {'id': 452, 'name': 'Coffee', 'freebase_id': '/m/02vqfm'},
459
+ {'id': 453, 'name': 'Snowboard', 'freebase_id': '/m/06__v'},
460
+ {'id': 454, 'name': 'Common fig', 'freebase_id': '/m/043nyj'},
461
+ {'id': 455, 'name': 'Salad', 'freebase_id': '/m/0grw1'},
462
+ {'id': 456, 'name': 'Marine invertebrates', 'freebase_id': '/m/03hl4l9'},
463
+ {'id': 457, 'name': 'Umbrella', 'freebase_id': '/m/0hnnb'},
464
+ {'id': 458, 'name': 'Kangaroo', 'freebase_id': '/m/04c0y'},
465
+ {'id': 459, 'name': 'Human arm', 'freebase_id': '/m/0dzf4'},
466
+ {'id': 460, 'name': 'Measuring cup', 'freebase_id': '/m/07v9_z'},
467
+ {'id': 461, 'name': 'Snail', 'freebase_id': '/m/0f9_l'},
468
+ {'id': 462, 'name': 'Loveseat', 'freebase_id': '/m/0703r8'},
469
+ {'id': 463, 'name': 'Suit', 'freebase_id': '/m/01xyhv'},
470
+ {'id': 464, 'name': 'Teapot', 'freebase_id': '/m/01fh4r'},
471
+ {'id': 465, 'name': 'Bottle', 'freebase_id': '/m/04dr76w'},
472
+ {'id': 466, 'name': 'Alpaca', 'freebase_id': '/m/0pcr'},
473
+ {'id': 467, 'name': 'Kettle', 'freebase_id': '/m/03s_tn'},
474
+ {'id': 468, 'name': 'Trousers', 'freebase_id': '/m/07mhn'},
475
+ {'id': 469, 'name': 'Popcorn', 'freebase_id': '/m/01hrv5'},
476
+ {'id': 470, 'name': 'Centipede', 'freebase_id': '/m/019h78'},
477
+ {'id': 471, 'name': 'Spider', 'freebase_id': '/m/09kmb'},
478
+ {'id': 472, 'name': 'Sparrow', 'freebase_id': '/m/0h23m'},
479
+ {'id': 473, 'name': 'Plate', 'freebase_id': '/m/050gv4'},
480
+ {'id': 474, 'name': 'Bagel', 'freebase_id': '/m/01fb_0'},
481
+ {'id': 475, 'name': 'Personal care', 'freebase_id': '/m/02w3_ws'},
482
+ {'id': 476, 'name': 'Apple', 'freebase_id': '/m/014j1m'},
483
+ {'id': 477, 'name': 'Brassiere', 'freebase_id': '/m/01gmv2'},
484
+ {'id': 478, 'name': 'Bathroom cabinet', 'freebase_id': '/m/04y4h8h'},
485
+ {'id': 479, 'name': 'studio couch', 'freebase_id': '/m/026qbn5'},
486
+ {'id': 480, 'name': 'Computer keyboard', 'freebase_id': '/m/01m2v'},
487
+ {'id': 481, 'name': 'Table tennis racket', 'freebase_id': '/m/05_5p_0'},
488
+ {'id': 482, 'name': 'Sushi', 'freebase_id': '/m/07030'},
489
+ {'id': 483, 'name': 'Cabinetry', 'freebase_id': '/m/01s105'},
490
+ {'id': 484, 'name': 'Street light', 'freebase_id': '/m/033rq4'},
491
+ {'id': 485, 'name': 'Towel', 'freebase_id': '/m/0162_1'},
492
+ {'id': 486, 'name': 'Nightstand', 'freebase_id': '/m/02z51p'},
493
+ {'id': 487, 'name': 'Rabbit', 'freebase_id': '/m/06mf6'},
494
+ {'id': 488, 'name': 'Dolphin', 'freebase_id': '/m/02hj4'},
495
+ {'id': 489, 'name': 'Dog', 'freebase_id': '/m/0bt9lr'},
496
+ {'id': 490, 'name': 'Jug', 'freebase_id': '/m/08hvt4'},
497
+ {'id': 491, 'name': 'Wok', 'freebase_id': '/m/084rd'},
498
+ {'id': 492, 'name': 'Fire hydrant', 'freebase_id': '/m/01pns0'},
499
+ {'id': 493, 'name': 'Human eye', 'freebase_id': '/m/014sv8'},
500
+ {'id': 494, 'name': 'Skyscraper', 'freebase_id': '/m/079cl'},
501
+ {'id': 495, 'name': 'Backpack', 'freebase_id': '/m/01940j'},
502
+ {'id': 496, 'name': 'Potato', 'freebase_id': '/m/05vtc'},
503
+ {'id': 497, 'name': 'Paper towel', 'freebase_id': '/m/02w3r3'},
504
+ {'id': 498, 'name': 'Lifejacket', 'freebase_id': '/m/054xkw'},
505
+ {'id': 499, 'name': 'Bicycle wheel', 'freebase_id': '/m/01bqk0'},
506
+ {'id': 500, 'name': 'Toilet', 'freebase_id': '/m/09g1w'},
507
+ ]
508
+
509
+
510
+ def _get_builtin_metadata(cats):
511
+ id_to_name = {x['id']: x['name'] for x in cats}
512
+ thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(len(cats))}
513
+ thing_classes = [x['name'] for x in sorted(cats, key=lambda x: x['id'])]
514
+ return {
515
+ "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
516
+ "thing_classes": thing_classes}
517
+
518
+ _PREDEFINED_SPLITS_OID = {
519
+ # cat threshold: 500, 1500: r 170, c 151, f 179
520
+ "oid_train": ("oid/images/", "oid/annotations/oid_challenge_2019_train_bbox.json"),
521
+ # "expanded" duplicates annotations to their parent classes based on the official
522
+ # hierarchy. This is used in the official evaluation protocol.
523
+ # https://storage.googleapis.com/openimages/web/evaluation.html
524
+ "oid_val_expanded": ("oid/images/validation/", "oid/annotations/oid_challenge_2019_val_expanded.json"),
525
+ "oid_val_expanded_rare": ("oid/images/validation/", "oid/annotations/oid_challenge_2019_val_expanded_rare.json"),
526
+ }
527
+
528
+
529
+ for key, (image_root, json_file) in _PREDEFINED_SPLITS_OID.items():
530
+ register_oid_instances(
531
+ key,
532
+ _get_builtin_metadata(categories),
533
+ os.path.join("datasets", json_file) if "://" not in json_file else json_file,
534
+ os.path.join("datasets", image_root),
535
+ )
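The `_get_builtin_metadata(cats)` helper above builds the contiguous mapping as `{i + 1: i}`, which is only valid because the OID ids run exactly from 1 to 500. A standalone sanity-check sketch of that assumption (not part of the repo):

    # Sketch: verify the 1..N id assumption behind the {i + 1: i} mapping above.
    def contiguous_mapping(cats):
        ids = sorted(c["id"] for c in cats)
        assert ids == list(range(1, len(cats) + 1)), "ids must be exactly 1..N"
        return {i + 1: i for i in range(len(cats))}

    # contiguous_mapping(categories)  # holds for the 500-entry list above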
proxydet/data/datasets/register_oid.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Xingyi Zhou from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/coco.py
3
+ import copy
4
+ import io
5
+ import logging
6
+ import contextlib
7
+ import os
8
+ import datetime
9
+ import json
10
+ import numpy as np
11
+
12
+ from PIL import Image
13
+
14
+ from fvcore.common.timer import Timer
15
+ from fvcore.common.file_io import PathManager, file_lock
16
+ from detectron2.structures import BoxMode, PolygonMasks, Boxes
17
+ from detectron2.data import DatasetCatalog, MetadataCatalog
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ """
22
+ This file contains functions to register an Open Images (OID) dataset in COCO json format to the DatasetCatalog.
23
+ """
24
+
25
+ __all__ = ["register_oid_instances", "load_coco_json_mem_efficient"]
26
+
27
+
28
+
29
+ def register_oid_instances(name, metadata, json_file, image_root):
30
+ """
31
+ """
32
+ # 1. register a function which returns dicts
33
+ DatasetCatalog.register(name, lambda: load_coco_json_mem_efficient(
34
+ json_file, image_root, name))
35
+
36
+ # 2. Optionally, add metadata about this dataset,
37
+ # since they might be useful in evaluation, visualization or logging
38
+ MetadataCatalog.get(name).set(
39
+ json_file=json_file, image_root=image_root, evaluator_type="oid", **metadata
40
+ )
41
+
42
+
43
+ def load_coco_json_mem_efficient(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
44
+ """
45
+ Actually not mem efficient
46
+ """
47
+ from pycocotools.coco import COCO
48
+
49
+ timer = Timer()
50
+ json_file = PathManager.get_local_path(json_file)
51
+ with contextlib.redirect_stdout(io.StringIO()):
52
+ coco_api = COCO(json_file)
53
+ if timer.seconds() > 1:
54
+ logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
55
+
56
+ id_map = None
57
+ if dataset_name is not None:
58
+ meta = MetadataCatalog.get(dataset_name)
59
+ cat_ids = sorted(coco_api.getCatIds())
60
+ cats = coco_api.loadCats(cat_ids)
61
+ # The categories in a custom json file may not be sorted.
62
+ thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
63
+ meta.thing_classes = thing_classes
64
+
65
+ if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
66
+ if "coco" not in dataset_name:
67
+ logger.warning(
68
+ """
69
+ Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
70
+ """
71
+ )
72
+ id_map = {v: i for i, v in enumerate(cat_ids)}
73
+ meta.thing_dataset_id_to_contiguous_id = id_map
74
+
75
+ # sort indices for reproducible results
76
+ img_ids = sorted(coco_api.imgs.keys())
77
+ imgs = coco_api.loadImgs(img_ids)
78
+ logger.info("Loaded {} images in COCO format from {}".format(len(imgs), json_file))
79
+
80
+ dataset_dicts = []
81
+
82
+ ann_keys = ["iscrowd", "bbox", "category_id"] + (extra_annotation_keys or [])
83
+
84
+ for img_dict in imgs:
85
+ record = {}
86
+ record["file_name"] = os.path.join(image_root, img_dict["file_name"])
87
+ record["height"] = img_dict["height"]
88
+ record["width"] = img_dict["width"]
89
+ image_id = record["image_id"] = img_dict["id"]
90
+ anno_dict_list = coco_api.imgToAnns[image_id]
91
+ if 'neg_category_ids' in img_dict:
92
+ record['neg_category_ids'] = \
93
+ [id_map[x] for x in img_dict['neg_category_ids']]
94
+
95
+ objs = []
96
+ for anno in anno_dict_list:
97
+ assert anno["image_id"] == image_id
98
+
99
+ assert anno.get("ignore", 0) == 0
100
+
101
+ obj = {key: anno[key] for key in ann_keys if key in anno}
102
+
103
+ segm = anno.get("segmentation", None)
104
+ if segm: # either list[list[float]] or dict(RLE)
105
+ if not isinstance(segm, dict):
106
+ # filter out invalid polygons (< 3 points)
107
+ segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
108
+ if len(segm) == 0:
109
+ num_instances_without_valid_segmentation += 1
110
+ continue # ignore this instance
111
+ obj["segmentation"] = segm
112
+
113
+ obj["bbox_mode"] = BoxMode.XYWH_ABS
114
+
115
+ if id_map:
116
+ obj["category_id"] = id_map[obj["category_id"]]
117
+ objs.append(obj)
118
+ record["annotations"] = objs
119
+ dataset_dicts.append(record)
120
+
121
+ del coco_api
122
+ return dataset_dicts
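For reference, each record produced by `load_coco_json_mem_efficient` follows Detectron2's standard dataset-dict layout; an illustrative example of one record (the values below are placeholders, not real annotations):

    # Illustrative record shape only -- not real data.
    record = {
        "file_name": "datasets/oid/images/validation/example.jpg",
        "height": 683,
        "width": 1024,
        "image_id": 1,
        "neg_category_ids": [12, 87],            # present only if the json provides them
        "annotations": [{
            "iscrowd": 0,
            "bbox": [10.0, 20.0, 200.0, 150.0],  # XYWH_ABS, as set above
            "bbox_mode": 1,                      # BoxMode.XYWH_ABS
            "category_id": 4,                    # contiguous id after id_map is applied
        }],
    }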
proxydet/data/tar_dataset.py ADDED
@@ -0,0 +1,138 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ import os
4
+ import gzip
5
+ import numpy as np
6
+ import io
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset
9
+
10
+ try:
11
+ from PIL import UnidentifiedImageError
12
+
13
+ unidentified_error_available = True
14
+ except ImportError:
15
+ # UnidentifiedImageError isn't available in older versions of PIL
16
+ unidentified_error_available = False
17
+
18
+ class DiskTarDataset(Dataset):
19
+ def __init__(self,
20
+ tarfile_path='dataset/imagenet/ImageNet-21k/metadata/tar_files.npy',
21
+ tar_index_dir='dataset/imagenet/ImageNet-21k/metadata/tarindex_npy',
22
+ preload=False,
23
+ num_synsets="all"):
24
+ """
25
+ - preload (bool): Recommend to set preload to False when using
26
+ - num_synsets (integer or string "all"): set to small number for debugging
27
+ will load subset of dataset
28
+ """
29
+ tar_files = np.load(tarfile_path)
30
+
31
+ chunk_datasets = []
32
+ dataset_lens = []
33
+ if isinstance(num_synsets, int):
34
+ assert num_synsets < len(tar_files)
35
+ tar_files = tar_files[:num_synsets]
36
+ for tar_file in tar_files:
37
+ dataset = _TarDataset(tar_file, tar_index_dir, preload=preload)
38
+ chunk_datasets.append(dataset)
39
+ dataset_lens.append(len(dataset))
40
+
41
+ self.chunk_datasets = chunk_datasets
42
+ self.dataset_lens = np.array(dataset_lens).astype(np.int32)
43
+ self.dataset_cumsums = np.cumsum(self.dataset_lens)
44
+ self.num_samples = sum(self.dataset_lens)
45
+ labels = np.zeros(self.dataset_lens.sum(), dtype=np.int64)
46
+ sI = 0
47
+ for k in range(len(self.dataset_lens)):
48
+ assert (sI+self.dataset_lens[k]) <= len(labels), f"{k} {sI+self.dataset_lens[k]} vs. {len(labels)}"
49
+ labels[sI:(sI+self.dataset_lens[k])] = k
50
+ sI += self.dataset_lens[k]
51
+ self.labels = labels
52
+
53
+ def __len__(self):
54
+ return self.num_samples
55
+
56
+ def __getitem__(self, index):
57
+ assert index >= 0 and index < len(self)
58
+ # find the dataset file we need to go to
59
+ d_index = np.searchsorted(self.dataset_cumsums, index)
60
+
61
+ # edge case, if index is at edge of chunks, move right
62
+ if index in self.dataset_cumsums:
63
+ d_index += 1
64
+
65
+ assert d_index == self.labels[index], f"{d_index} vs. {self.labels[index]} mismatch for {index}"
66
+
67
+ # change index to local dataset index
68
+ if d_index == 0:
69
+ local_index = index
70
+ else:
71
+ local_index = index - self.dataset_cumsums[d_index - 1]
72
+ data_bytes = self.chunk_datasets[d_index][local_index]
73
+ exception_to_catch = UnidentifiedImageError if unidentified_error_available else Exception
74
+ try:
75
+ image = Image.open(data_bytes).convert("RGB")
76
+ except exception_to_catch:
77
+ image = Image.fromarray(np.ones((224,224,3), dtype=np.uint8)*128)
78
+ d_index = -1
79
+
80
+ # label is the dataset (synset) we indexed into
81
+ return image, d_index, index
82
+
83
+ def __repr__(self):
84
+ st = f"DiskTarDataset(subdatasets={len(self.dataset_lens)},samples={self.num_samples})"
85
+ return st
86
+
87
+ class _TarDataset(object):
88
+
89
+ def __init__(self, filename, npy_index_dir, preload=False):
90
+ # translated from
91
+ # fbcode/experimental/deeplearning/matthijs/comp_descs/tardataset.lua
92
+ self.filename = filename
93
+ self.names = []
94
+ self.offsets = []
95
+ self.npy_index_dir = npy_index_dir
96
+ names, offsets = self.load_index()
97
+
98
+ self.num_samples = len(names)
99
+ if preload:
100
+ self.data = np.memmap(filename, mode='r', dtype='uint8')
101
+ self.offsets = offsets
102
+ else:
103
+ self.data = None
104
+
105
+
106
+ def __len__(self):
107
+ return self.num_samples
108
+
109
+ def load_index(self):
110
+ basename = os.path.basename(self.filename)
111
+ basename = os.path.splitext(basename)[0]
112
+ names = np.load(os.path.join(self.npy_index_dir, f"{basename}_names.npy"))
113
+ offsets = np.load(os.path.join(self.npy_index_dir, f"{basename}_offsets.npy"))
114
+ return names, offsets
115
+
116
+ def __getitem__(self, idx):
117
+ if self.data is None:
118
+ self.data = np.memmap(self.filename, mode='r', dtype='uint8')
119
+ _, self.offsets = self.load_index()
120
+
121
+ ofs = self.offsets[idx] * 512
122
+ fsize = 512 * (self.offsets[idx + 1] - self.offsets[idx])
123
+ data = self.data[ofs:ofs + fsize]
124
+
125
+ if data[:13].tobytes() == b'././@LongLink':  # compare bytes to bytes (Python 3)
126
+ data = data[3 * 512:]
127
+ else:
128
+ data = data[512:]
129
+
130
+ # just to make it more fun a few JPEGs are GZIP compressed...
131
+ # catch this case
132
+ if tuple(data[:2]) == (0x1f, 0x8b):
133
+ s = io.BytesIO(data.tobytes())  # tostring() is deprecated in NumPy
134
+ g = gzip.GzipFile(None, 'r', 0, s)
135
+ sdata = g.read()
136
+ else:
137
+ sdata = data.tobytes()
138
+ return io.BytesIO(sdata)
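The chunk lookup in `DiskTarDataset.__getitem__` (cumulative sums plus `searchsorted`, with a correction at chunk boundaries) can be traced with toy numbers; this sketch mirrors that logic without touching any tar files:

    import numpy as np

    # Toy version of the global-index -> (chunk, local-index) mapping used above.
    dataset_lens = np.array([3, 5, 2])       # three synset tar files
    cumsums = np.cumsum(dataset_lens)        # [3, 8, 10]

    def locate(index):
        d_index = np.searchsorted(cumsums, index)
        if index in cumsums:                 # boundary case, move to the next chunk
            d_index += 1
        local = index if d_index == 0 else index - cumsums[d_index - 1]
        return int(d_index), int(local)

    assert locate(0) == (0, 0)
    assert locate(3) == (1, 0)               # first sample of the second tar file
    assert locate(9) == (2, 1)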
proxydet/data/transforms/custom_augmentation_impl.py ADDED
@@ -0,0 +1,60 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ # Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
+ # Modified by Xingyi Zhou
+ # The original code is under Apache-2.0 License
+ import numpy as np
+ import sys
+ from fvcore.transforms.transform import (
+     BlendTransform,
+     CropTransform,
+     HFlipTransform,
+     NoOpTransform,
+     Transform,
+     VFlipTransform,
+ )
+ from PIL import Image
+
+ from detectron2.data.transforms.augmentation import Augmentation
+ from .custom_transform import EfficientDetResizeCropTransform
+
+ __all__ = [
+     "EfficientDetResizeCrop",
+ ]
+
+ class EfficientDetResizeCrop(Augmentation):
+     """
+     Resize the image so it fits inside a square window of side `scale_factor * size`
+     (with `scale_factor` drawn uniformly from `scale`), then take a random crop of
+     at most `size x size` from the resized image.
+     """
+
+     def __init__(
+         self, size, scale, interp=Image.BILINEAR
+     ):
+         """
+         Args:
+             size (int): side length of the square target crop.
+             scale (tuple[float, float]): range of the random scale jitter.
+             interp: PIL interpolation method, defaults to bilinear.
+         """
+         super().__init__()
+         self.target_size = (size, size)
+         self.scale = scale
+         self.interp = interp
+
+     def get_transform(self, img):
+         # Select a random scale factor.
+         scale_factor = np.random.uniform(*self.scale)
+         scaled_target_height = scale_factor * self.target_size[0]
+         scaled_target_width = scale_factor * self.target_size[1]
+         # Recompute the accurate scale_factor using rounded scaled image size.
+         width, height = img.shape[1], img.shape[0]
+         img_scale_y = scaled_target_height / height
+         img_scale_x = scaled_target_width / width
+         img_scale = min(img_scale_y, img_scale_x)
+
+         # Select non-zero random offset (x, y) if scaled image is larger than target size
+         scaled_h = int(height * img_scale)
+         scaled_w = int(width * img_scale)
+         offset_y = scaled_h - self.target_size[0]
+         offset_x = scaled_w - self.target_size[1]
+         offset_y = int(max(0.0, float(offset_y)) * np.random.uniform(0, 1))
+         offset_x = int(max(0.0, float(offset_x)) * np.random.uniform(0, 1))
+         return EfficientDetResizeCropTransform(
+             scaled_h, scaled_w, offset_y, offset_x, img_scale, self.target_size, self.interp)
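A short usage sketch of the augmentation above with detectron2's transform API (synthetic image, illustrative size/scale values rather than the values the ProxyDet configs actually use):

import numpy as np
from PIL import Image

aug = EfficientDetResizeCrop(size=640, scale=(0.1, 2.0), interp=Image.BILINEAR)

img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)   # HWC uint8
tfm = aug.get_transform(img)          # an EfficientDetResizeCropTransform
out = tfm.apply_image(img)            # resized, then cropped to at most 640x640
boxes = np.array([[10., 20., 200., 240.]])
out_boxes = tfm.apply_box(boxes)      # boxes follow the same scale and offsets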
proxydet/data/transforms/custom_transform.py ADDED
@@ -0,0 +1,114 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ # Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
+ # Modified by Xingyi Zhou
+ # The original code is under Apache-2.0 License
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from fvcore.transforms.transform import (
+     CropTransform,
+     HFlipTransform,
+     NoOpTransform,
+     Transform,
+     TransformList,
+ )
+ from PIL import Image
+
+ try:
+     import cv2  # noqa
+ except ImportError:
+     # OpenCV is an optional dependency at the moment
+     pass
+
+ __all__ = [
+     "EfficientDetResizeCropTransform",
+ ]
+
+ class EfficientDetResizeCropTransform(Transform):
+     """
+     Resize an image by `img_scale`, then crop a `target_size` window starting at
+     (`offset_y`, `offset_x`). The `inverse_apply_*` helpers map coordinates from
+     the transformed frame back to the original image.
+     """
+
+     def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale, \
+         target_size, interp=None):
+         """
+         Args:
+             scaled_h, scaled_w (int): image size after resizing by `img_scale`.
+             offset_y, offset_x (int): top-left corner of the crop in the resized image.
+             img_scale (float): resize factor relative to the original image.
+             target_size (tuple[int, int]): (height, width) of the crop window.
+             interp: PIL interpolation method, defaults to bilinear.
+         """
+         # TODO decide on PIL vs opencv
+         super().__init__()
+         if interp is None:
+             interp = Image.BILINEAR
+         self._set_attributes(locals())
+
+     def apply_image(self, img, interp=None):
+         assert len(img.shape) <= 4
+
+         if img.dtype == np.uint8:
+             pil_image = Image.fromarray(img)
+             interp_method = interp if interp is not None else self.interp
+             pil_image = pil_image.resize((self.scaled_w, self.scaled_h), interp_method)
+             ret = np.asarray(pil_image)
+             right = min(self.scaled_w, self.offset_x + self.target_size[1])
+             lower = min(self.scaled_h, self.offset_y + self.target_size[0])
+             if len(ret.shape) <= 3:
+                 ret = ret[self.offset_y: lower, self.offset_x: right]
+             else:
+                 ret = ret[..., self.offset_y: lower, self.offset_x: right, :]
+         else:
+             # PIL only supports uint8
+             img = torch.from_numpy(img)
+             shape = list(img.shape)
+             shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
+             img = img.view(shape_4d).permute(2, 3, 0, 1)  # hw(c) -> nchw
+             _PIL_RESIZE_TO_INTERPOLATE_MODE = {Image.BILINEAR: "bilinear", Image.BICUBIC: "bicubic"}
+             mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[self.interp]
+             img = F.interpolate(img, (self.scaled_h, self.scaled_w), mode=mode, align_corners=False)
+             shape[:2] = (self.scaled_h, self.scaled_w)
+             ret = img.permute(2, 3, 0, 1).view(shape).numpy()  # nchw -> hw(c)
+             right = min(self.scaled_w, self.offset_x + self.target_size[1])
+             lower = min(self.scaled_h, self.offset_y + self.target_size[0])
+             if len(ret.shape) <= 3:
+                 ret = ret[self.offset_y: lower, self.offset_x: right]
+             else:
+                 ret = ret[..., self.offset_y: lower, self.offset_x: right, :]
+         return ret
+
+     def apply_coords(self, coords):
+         coords[:, 0] = coords[:, 0] * self.img_scale
+         coords[:, 1] = coords[:, 1] * self.img_scale
+         coords[:, 0] -= self.offset_x
+         coords[:, 1] -= self.offset_y
+         return coords
+
+     def apply_segmentation(self, segmentation):
+         segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
+         return segmentation
+
+     def inverse(self):
+         raise NotImplementedError
+
+     def inverse_apply_coords(self, coords):
+         coords[:, 0] += self.offset_x
+         coords[:, 1] += self.offset_y
+         coords[:, 0] = coords[:, 0] / self.img_scale
+         coords[:, 1] = coords[:, 1] / self.img_scale
+         return coords
+
+     def inverse_apply_box(self, box: np.ndarray) -> np.ndarray:
+         """
+         Map XYXY boxes from the resized/cropped frame back to the original image.
+         """
+         idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
+         coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2)
+         coords = self.inverse_apply_coords(coords).reshape((-1, 4, 2))
+         minxy = coords.min(axis=1)
+         maxxy = coords.max(axis=1)
+         trans_boxes = np.concatenate((minxy, maxxy), axis=1)
+         return trans_boxes
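The inverse_* helpers are what lets detections from the network's input frame be mapped back onto the original image. A sketch with made-up numbers (all parameter values here are illustrative):

import numpy as np
from PIL import Image

tfm = EfficientDetResizeCropTransform(
    scaled_h=384, scaled_w=512, offset_y=0, offset_x=32,
    img_scale=0.8, target_size=(384, 384), interp=Image.BILINEAR,
)
pred = np.array([[48., 40., 208., 200.]])   # XYXY in the resized/cropped frame
orig = tfm.inverse_apply_box(pred)          # XYXY back in the original image:
                                            # crop offsets are added back, then
                                            # both axes are divided by img_scale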
proxydet/evaluation/custom_coco_eval.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import contextlib
3
+ import copy
4
+ import io
5
+ import itertools
6
+ import json
7
+ import logging
8
+ import numpy as np
9
+ import os
10
+ import pickle
11
+ from collections import OrderedDict
12
+ import pycocotools.mask as mask_util
13
+ import torch
14
+ from pycocotools.coco import COCO
15
+ from pycocotools.cocoeval import COCOeval
16
+ from tabulate import tabulate
17
+
18
+ import detectron2.utils.comm as comm
19
+ from detectron2.config import CfgNode
20
+ from detectron2.data import MetadataCatalog
21
+ from detectron2.data.datasets.coco import convert_to_coco_json
22
+ from detectron2.evaluation.coco_evaluation import COCOEvaluator
23
+ from detectron2.structures import Boxes, BoxMode, pairwise_iou
24
+ from detectron2.utils.file_io import PathManager
25
+ from detectron2.utils.logger import create_small_table
26
+ from ..data.datasets.coco_zeroshot import categories_seen, categories_unseen
27
+
28
+ class CustomCOCOEvaluator(COCOEvaluator):
29
+ def _derive_coco_results(self, coco_eval, iou_type, class_names=None):
30
+ """
31
+ Additionally plot mAP for 'seen classes' and 'unseen classes'
32
+ """
33
+
34
+ metrics = {
35
+ "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
36
+ "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
37
+ "keypoints": ["AP", "AP50", "AP75", "APm", "APl"],
38
+ }[iou_type]
39
+
40
+ if coco_eval is None:
41
+ self._logger.warn("No predictions from the model!")
42
+ return {metric: float("nan") for metric in metrics}
43
+
44
+ # the standard metrics
45
+ results = {
46
+ metric: float(coco_eval.stats[idx] * 100 if coco_eval.stats[idx] >= 0 else "nan")
47
+ for idx, metric in enumerate(metrics)
48
+ }
49
+ self._logger.info(
50
+ "Evaluation results for {}: \n".format(iou_type) + create_small_table(results)
51
+ )
52
+ if not np.isfinite(sum(results.values())):
53
+ self._logger.info("Some metrics cannot be computed and is shown as NaN.")
54
+
55
+ if class_names is None or len(class_names) <= 1:
56
+ return results
57
+ # Compute per-category AP
58
+ # from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 # noqa
59
+ precisions = coco_eval.eval["precision"]
60
+ # precision has dims (iou, recall, cls, area range, max dets)
61
+ assert len(class_names) == precisions.shape[2]
62
+
63
+ seen_names = set([x['name'] for x in categories_seen])
64
+ unseen_names = set([x['name'] for x in categories_unseen])
65
+ results_per_category = []
66
+ results_per_category50 = []
67
+ results_per_category50_seen = []
68
+ results_per_category50_unseen = []
69
+ for idx, name in enumerate(class_names):
70
+ # area range index 0: all area ranges
71
+ # max dets index -1: typically 100 per image
72
+ precision = precisions[:, :, idx, 0, -1]
73
+ precision = precision[precision > -1]
74
+ ap = np.mean(precision) if precision.size else float("nan")
75
+ results_per_category.append(("{}".format(name), float(ap * 100)))
76
+ precision50 = precisions[0, :, idx, 0, -1]
77
+ precision50 = precision50[precision50 > -1]
78
+ ap50 = np.mean(precision50) if precision50.size else float("nan")
79
+ results_per_category50.append(("{}".format(name), float(ap50 * 100)))
80
+ if name in seen_names:
81
+ results_per_category50_seen.append(float(ap50 * 100))
82
+ if name in unseen_names:
83
+ results_per_category50_unseen.append(float(ap50 * 100))
84
+
85
+ # tabulate it
86
+ N_COLS = min(6, len(results_per_category) * 2)
87
+ results_flatten = list(itertools.chain(*results_per_category))
88
+ results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)])
89
+ table = tabulate(
90
+ results_2d,
91
+ tablefmt="pipe",
92
+ floatfmt=".3f",
93
+ headers=["category", "AP"] * (N_COLS // 2),
94
+ numalign="left",
95
+ )
96
+ self._logger.info("Per-category {} AP: \n".format(iou_type) + table)
97
+
98
+
99
+ N_COLS = min(6, len(results_per_category50) * 2)
100
+ results_flatten = list(itertools.chain(*results_per_category50))
101
+ results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)])
102
+ table = tabulate(
103
+ results_2d,
104
+ tablefmt="pipe",
105
+ floatfmt=".3f",
106
+ headers=["category", "AP50"] * (N_COLS // 2),
107
+ numalign="left",
108
+ )
109
+ self._logger.info("Per-category {} AP50: \n".format(iou_type) + table)
110
+ self._logger.info(
111
+ "Seen {} AP50: {}".format(
112
+ iou_type,
113
+ sum(results_per_category50_seen) / len(results_per_category50_seen),
114
+ ))
115
+ self._logger.info(
116
+ "Unseen {} AP50: {}".format(
117
+ iou_type,
118
+ sum(results_per_category50_unseen) / len(results_per_category50_unseen),
119
+ ))
120
+
121
+ results.update({"AP-" + name: ap for name, ap in results_per_category})
122
+ results["AP50-seen"] = sum(results_per_category50_seen) / len(results_per_category50_seen)
123
+ results["AP50-unseen"] = sum(results_per_category50_unseen) / len(results_per_category50_unseen)
124
+ return results
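The seen/unseen numbers above are plain means over per-class AP50 values sliced out of COCOeval's precision array. A toy sketch of the slicing convention (synthetic array, not real evaluator output):

import numpy as np

# COCOeval precision has dims (iou, recall, cls, area range, max dets).
num_classes = 3
precisions = np.random.rand(10, 101, num_classes, 4, 3)

idx = 1                                 # some class index
p50 = precisions[0, :, idx, 0, -1]      # IoU=0.50, all areas, largest max-dets
p50 = p50[p50 > -1]                     # drop entries marked as absent (-1)
ap50 = float(np.mean(p50)) * 100 if p50.size else float("nan")

# Seen/unseen AP50 is then the mean of ap50 over the class names that fall
# into categories_seen / categories_unseen.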
proxydet/evaluation/oideval.py ADDED
@@ -0,0 +1,699 @@
1
+ # Part of the code is from https://github.com/tensorflow/models/blob/master/research/object_detection/metrics/oid_challenge_evaluation.py
2
+ # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3
+ # The original code is under Apache License, Version 2.0 (the "License");
4
+ # Part of the code is from https://github.com/lvis-dataset/lvis-api/blob/master/lvis/eval.py
5
+ # Copyright (c) 2019, Agrim Gupta and Ross Girshick
6
+ # Modified by Xingyi Zhou
7
+ # This script re-implement OpenImages evaluation in detectron2
8
+ # The code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/evaluation/oideval.py
9
+ # The original code is under Apache-2.0 License
10
+ # Copyright (c) Facebook, Inc. and its affiliates.
11
+ import os
12
+ import datetime
13
+ import logging
14
+ import itertools
15
+ from collections import OrderedDict
16
+ from collections import defaultdict
17
+ import copy
18
+ import json
19
+ import numpy as np
20
+ import torch
21
+ from tabulate import tabulate
22
+
23
+ from lvis.lvis import LVIS
24
+ from lvis.results import LVISResults
25
+
26
+ import pycocotools.mask as mask_utils
27
+
28
+ from fvcore.common.file_io import PathManager
29
+ import detectron2.utils.comm as comm
30
+ from detectron2.data import MetadataCatalog
31
+ from detectron2.evaluation.coco_evaluation import instances_to_coco_json
32
+ from detectron2.utils.logger import create_small_table
33
+ from detectron2.evaluation import DatasetEvaluator
34
+
35
+ def compute_average_precision(precision, recall):
36
+ """Compute Average Precision according to the definition in VOCdevkit.
37
+ Precision is modified to ensure that it does not decrease as recall
38
+ decrease.
39
+ Args:
40
+ precision: A float [N, 1] numpy array of precisions
41
+ recall: A float [N, 1] numpy array of recalls
42
+ Raises:
43
+ ValueError: if the input is not of the correct format
44
+ Returns:
45
+ average_precison: The area under the precision recall curve. NaN if
46
+ precision and recall are None.
47
+ """
48
+ if precision is None:
49
+ if recall is not None:
50
+ raise ValueError("If precision is None, recall must also be None")
51
+ return np.NAN
52
+
53
+ if not isinstance(precision, np.ndarray) or not isinstance(
54
+ recall, np.ndarray):
55
+ raise ValueError("precision and recall must be numpy array")
56
+ if precision.dtype != np.float or recall.dtype != np.float:
57
+ raise ValueError("input must be float numpy array.")
58
+ if len(precision) != len(recall):
59
+ raise ValueError("precision and recall must be of the same size.")
60
+ if not precision.size:
61
+ return 0.0
62
+ if np.amin(precision) < 0 or np.amax(precision) > 1:
63
+ raise ValueError("Precision must be in the range of [0, 1].")
64
+ if np.amin(recall) < 0 or np.amax(recall) > 1:
65
+ raise ValueError("recall must be in the range of [0, 1].")
66
+ if not all(recall[i] <= recall[i + 1] for i in range(len(recall) - 1)):
67
+ raise ValueError("recall must be a non-decreasing array")
68
+
69
+ recall = np.concatenate([[0], recall, [1]])
70
+ precision = np.concatenate([[0], precision, [0]])
71
+
72
+ for i in range(len(precision) - 2, -1, -1):
73
+ precision[i] = np.maximum(precision[i], precision[i + 1])
74
+ indices = np.where(recall[1:] != recall[:-1])[0] + 1
75
+ average_precision = np.sum(
76
+ (recall[indices] - recall[indices - 1]) * precision[indices])
77
+ return average_precision
78
+
79
+ class OIDEval:
80
+ def __init__(
81
+ self, lvis_gt, lvis_dt, iou_type="bbox", expand_pred_label=False,
82
+ oid_hierarchy_path='./datasets/oid/annotations/challenge-2019-label500-hierarchy.json'):
83
+ """Constructor for OIDEval.
84
+ Args:
85
+ lvis_gt (LVIS class instance, or str containing path of annotation file)
86
+ lvis_dt (LVISResult class instance, or str containing path of result file,
87
+ or list of dict)
88
+ iou_type (str): segm or bbox evaluation
89
+ """
90
+ self.logger = logging.getLogger(__name__)
91
+
92
+ if iou_type not in ["bbox", "segm"]:
93
+ raise ValueError("iou_type: {} is not supported.".format(iou_type))
94
+
95
+ if isinstance(lvis_gt, LVIS):
96
+ self.lvis_gt = lvis_gt
97
+ elif isinstance(lvis_gt, str):
98
+ self.lvis_gt = LVIS(lvis_gt)
99
+ else:
100
+ raise TypeError("Unsupported type {} of lvis_gt.".format(lvis_gt))
101
+
102
+ if isinstance(lvis_dt, LVISResults):
103
+ self.lvis_dt = lvis_dt
104
+ elif isinstance(lvis_dt, (str, list)):
105
+ # self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt, max_dets=-1)
106
+ self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt)
107
+ else:
108
+ raise TypeError("Unsupported type {} of lvis_dt.".format(lvis_dt))
109
+
110
+ if expand_pred_label:
111
+ oid_hierarchy = json.load(open(oid_hierarchy_path, 'r'))
112
+ cat_info = self.lvis_gt.dataset['categories']
113
+ freebase2id = {x['freebase_id']: x['id'] for x in cat_info}
114
+ id2freebase = {x['id']: x['freebase_id'] for x in cat_info}
115
+ id2name = {x['id']: x['name'] for x in cat_info}
116
+
117
+ fas = defaultdict(set)
118
+ def dfs(hierarchy, cur_id):
119
+ all_childs = set()
120
+ all_keyed_child = {}
121
+ if 'Subcategory' in hierarchy:
122
+ for x in hierarchy['Subcategory']:
123
+ childs = dfs(x, freebase2id[x['LabelName']])
124
+ all_childs.update(childs)
125
+ if cur_id != -1:
126
+ for c in all_childs:
127
+ fas[c].add(cur_id)
128
+ all_childs.add(cur_id)
129
+ return all_childs
130
+ dfs(oid_hierarchy, -1)
131
+
132
+ expanded_pred = []
133
+ id_count = 0
134
+ for d in self.lvis_dt.dataset['annotations']:
135
+ cur_id = d['category_id']
136
+ ids = [cur_id] + [x for x in fas[cur_id]]
137
+ for cat_id in ids:
138
+ new_box = copy.deepcopy(d)
139
+ id_count = id_count + 1
140
+ new_box['id'] = id_count
141
+ new_box['category_id'] = cat_id
142
+ expanded_pred.append(new_box)
143
+
144
+ print('Expanding original {} preds to {} preds'.format(
145
+ len(self.lvis_dt.dataset['annotations']),
146
+ len(expanded_pred)
147
+ ))
148
+ self.lvis_dt.dataset['annotations'] = expanded_pred
149
+ self.lvis_dt._create_index()
150
+
151
+ # per-image per-category evaluation results
152
+ self.eval_imgs = defaultdict(list)
153
+ self.eval = {} # accumulated evaluation results
154
+ self._gts = defaultdict(list) # gt for evaluation
155
+ self._dts = defaultdict(list) # dt for evaluation
156
+ self.params = Params(iou_type=iou_type) # parameters
157
+ self.results = OrderedDict()
158
+ self.ious = {} # ious between all gts and dts
159
+
160
+ self.params.img_ids = sorted(self.lvis_gt.get_img_ids())
161
+ self.params.cat_ids = sorted(self.lvis_gt.get_cat_ids())
162
+
163
+ def _to_mask(self, anns, lvis):
164
+ for ann in anns:
165
+ rle = lvis.ann_to_rle(ann)
166
+ ann["segmentation"] = rle
167
+
168
+ def _prepare(self):
169
+ """Prepare self._gts and self._dts for evaluation based on params."""
170
+
171
+ cat_ids = self.params.cat_ids if self.params.cat_ids else None
172
+
173
+ gts = self.lvis_gt.load_anns(
174
+ self.lvis_gt.get_ann_ids(img_ids=self.params.img_ids, cat_ids=cat_ids)
175
+ )
176
+ dts = self.lvis_dt.load_anns(
177
+ self.lvis_dt.get_ann_ids(img_ids=self.params.img_ids, cat_ids=cat_ids)
178
+ )
179
+ # convert ground truth to mask if iou_type == 'segm'
180
+ if self.params.iou_type == "segm":
181
+ self._to_mask(gts, self.lvis_gt)
182
+ self._to_mask(dts, self.lvis_dt)
183
+
184
+ for gt in gts:
185
+ self._gts[gt["image_id"], gt["category_id"]].append(gt)
186
+
187
+ # For federated dataset evaluation we will filter out all dt for an
188
+ # image which belong to categories not present in gt and not present in
189
+ # the negative list for an image. In other words detector is not penalized
190
+ # for categories about which we don't have gt information about their
191
+ # presence or absence in an image.
192
+ img_data = self.lvis_gt.load_imgs(ids=self.params.img_ids)
193
+ # per image map of categories not present in image
194
+ img_nl = {d["id"]: d["neg_category_ids"] for d in img_data}
195
+ # per image list of categories present in image
196
+ img_pl = {d["id"]: d["pos_category_ids"] for d in img_data}
197
+ # img_pl = defaultdict(set)
198
+ for ann in gts:
199
+ # img_pl[ann["image_id"]].add(ann["category_id"])
200
+ assert ann["category_id"] in img_pl[ann["image_id"]]
201
+ # print('check pos ids OK.')
202
+
203
+ for dt in dts:
204
+ img_id, cat_id = dt["image_id"], dt["category_id"]
205
+ if cat_id not in img_nl[img_id] and cat_id not in img_pl[img_id]:
206
+ continue
207
+ self._dts[img_id, cat_id].append(dt)
208
+
209
+ def evaluate(self):
210
+ """
211
+ Run per image evaluation on given images and store results
212
+ (a list of dict) in self.eval_imgs.
213
+ """
214
+ self.logger.info("Running per image evaluation.")
215
+ self.logger.info("Evaluate annotation type *{}*".format(self.params.iou_type))
216
+
217
+ self.params.img_ids = list(np.unique(self.params.img_ids))
218
+
219
+ if self.params.use_cats:
220
+ cat_ids = self.params.cat_ids
221
+ else:
222
+ cat_ids = [-1]
223
+
224
+ self._prepare()
225
+
226
+ self.ious = {
227
+ (img_id, cat_id): self.compute_iou(img_id, cat_id)
228
+ for img_id in self.params.img_ids
229
+ for cat_id in cat_ids
230
+ }
231
+
232
+ # loop through images, area range, max detection number
233
+ print('Evaluating ...')
234
+ self.eval_imgs = [
235
+ self.evaluate_img_google(img_id, cat_id, area_rng)
236
+ for cat_id in cat_ids
237
+ for area_rng in self.params.area_rng
238
+ for img_id in self.params.img_ids
239
+ ]
240
+
241
+ def _get_gt_dt(self, img_id, cat_id):
242
+ """Create gt, dt which are list of anns/dets. If use_cats is true
243
+ only anns/dets corresponding to tuple (img_id, cat_id) will be
244
+ used. Else, all anns/dets in image are used and cat_id is not used.
245
+ """
246
+ if self.params.use_cats:
247
+ gt = self._gts[img_id, cat_id]
248
+ dt = self._dts[img_id, cat_id]
249
+ else:
250
+ gt = [
251
+ _ann
252
+ for _cat_id in self.params.cat_ids
253
+ for _ann in self._gts[img_id, cat_id]
254
+ ]
255
+ dt = [
256
+ _ann
257
+ for _cat_id in self.params.cat_ids
258
+ for _ann in self._dts[img_id, cat_id]
259
+ ]
260
+ return gt, dt
261
+
262
+ def compute_iou(self, img_id, cat_id):
263
+ gt, dt = self._get_gt_dt(img_id, cat_id)
264
+
265
+ if len(gt) == 0 and len(dt) == 0:
266
+ return []
267
+
268
+ # Sort detections in decreasing order of score.
269
+ idx = np.argsort([-d["score"] for d in dt], kind="mergesort")
270
+ dt = [dt[i] for i in idx]
271
+
272
+ # iscrowd = [int(False)] * len(gt)
273
+ iscrowd = [int('iscrowd' in g and g['iscrowd'] > 0) for g in gt]
274
+
275
+ if self.params.iou_type == "segm":
276
+ ann_type = "segmentation"
277
+ elif self.params.iou_type == "bbox":
278
+ ann_type = "bbox"
279
+ else:
280
+ raise ValueError("Unknown iou_type for iou computation.")
281
+ gt = [g[ann_type] for g in gt]
282
+ dt = [d[ann_type] for d in dt]
283
+
284
+ # compute iou between each dt and gt region
285
+ # will return array of shape len(dt), len(gt)
286
+ ious = mask_utils.iou(dt, gt, iscrowd)
287
+ return ious
288
+
289
+ def evaluate_img_google(self, img_id, cat_id, area_rng):
290
+ gt, dt = self._get_gt_dt(img_id, cat_id)
291
+ if len(gt) == 0 and len(dt) == 0:
292
+ return None
293
+
294
+ if len(dt) == 0:
295
+ return {
296
+ "image_id": img_id,
297
+ "category_id": cat_id,
298
+ "area_rng": area_rng,
299
+ "dt_ids": [],
300
+ "dt_matches": np.array([], dtype=np.int32).reshape(1, -1),
301
+ "dt_scores": [],
302
+ "dt_ignore": np.array([], dtype=np.int32).reshape(1, -1),
303
+ 'num_gt': len(gt)
304
+ }
305
+
306
+ no_crowd_inds = [i for i, g in enumerate(gt) \
307
+ if ('iscrowd' not in g) or g['iscrowd'] == 0]
308
+ crowd_inds = [i for i, g in enumerate(gt) \
309
+ if 'iscrowd' in g and g['iscrowd'] == 1]
310
+ dt_idx = np.argsort([-d["score"] for d in dt], kind="mergesort")
311
+
312
+ if len(self.ious[img_id, cat_id]) > 0:
313
+ ious = self.ious[img_id, cat_id]
314
+ iou = ious[:, no_crowd_inds]
315
+ iou = iou[dt_idx]
316
+ ioa = ious[:, crowd_inds]
317
+ ioa = ioa[dt_idx]
318
+ else:
319
+ iou = np.zeros((len(dt_idx), 0))
320
+ ioa = np.zeros((len(dt_idx), 0))
321
+ scores = np.array([dt[i]['score'] for i in dt_idx])
322
+
323
+ num_detected_boxes = len(dt)
324
+ tp_fp_labels = np.zeros(num_detected_boxes, dtype=bool)
325
+ is_matched_to_group_of = np.zeros(num_detected_boxes, dtype=bool)
326
+
327
+ def compute_match_iou(iou):
328
+ max_overlap_gt_ids = np.argmax(iou, axis=1)
329
+ is_gt_detected = np.zeros(iou.shape[1], dtype=bool)
330
+ for i in range(num_detected_boxes):
331
+ gt_id = max_overlap_gt_ids[i]
332
+ is_evaluatable = (not tp_fp_labels[i] and
333
+ iou[i, gt_id] >= 0.5 and
334
+ not is_matched_to_group_of[i])
335
+ if is_evaluatable:
336
+ if not is_gt_detected[gt_id]:
337
+ tp_fp_labels[i] = True
338
+ is_gt_detected[gt_id] = True
339
+
340
+ def compute_match_ioa(ioa):
341
+ scores_group_of = np.zeros(ioa.shape[1], dtype=float)
342
+ tp_fp_labels_group_of = np.ones(
343
+ ioa.shape[1], dtype=float)
344
+ max_overlap_group_of_gt_ids = np.argmax(ioa, axis=1)
345
+ for i in range(num_detected_boxes):
346
+ gt_id = max_overlap_group_of_gt_ids[i]
347
+ is_evaluatable = (not tp_fp_labels[i] and
348
+ ioa[i, gt_id] >= 0.5 and
349
+ not is_matched_to_group_of[i])
350
+ if is_evaluatable:
351
+ is_matched_to_group_of[i] = True
352
+ scores_group_of[gt_id] = max(scores_group_of[gt_id], scores[i])
353
+ selector = np.where((scores_group_of > 0) & (tp_fp_labels_group_of > 0))
354
+ scores_group_of = scores_group_of[selector]
355
+ tp_fp_labels_group_of = tp_fp_labels_group_of[selector]
356
+
357
+ return scores_group_of, tp_fp_labels_group_of
358
+
359
+ if iou.shape[1] > 0:
360
+ compute_match_iou(iou)
361
+
362
+ scores_box_group_of = np.ndarray([0], dtype=float)
363
+ tp_fp_labels_box_group_of = np.ndarray([0], dtype=float)
364
+
365
+ if ioa.shape[1] > 0:
366
+ scores_box_group_of, tp_fp_labels_box_group_of = compute_match_ioa(ioa)
367
+
368
+ valid_entries = (~is_matched_to_group_of)
369
+
370
+ scores = np.concatenate(
371
+ (scores[valid_entries], scores_box_group_of))
372
+ tp_fps = np.concatenate(
373
+ (tp_fp_labels[valid_entries].astype(float),
374
+ tp_fp_labels_box_group_of))
375
+
376
+ return {
377
+ "image_id": img_id,
378
+ "category_id": cat_id,
379
+ "area_rng": area_rng,
380
+ "dt_matches": np.array([1 if x > 0 else 0 for x in tp_fps], dtype=np.int32).reshape(1, -1),
381
+ "dt_scores": [x for x in scores],
382
+ "dt_ignore": np.array([0 for x in scores], dtype=np.int32).reshape(1, -1),
383
+ 'num_gt': len(gt)
384
+ }
385
+
386
+ def accumulate(self):
387
+ """Accumulate per image evaluation results and store the result in
388
+ self.eval.
389
+ """
390
+ self.logger.info("Accumulating evaluation results.")
391
+
392
+ if not self.eval_imgs:
393
+ self.logger.warn("Please run evaluate first.")
394
+
395
+ if self.params.use_cats:
396
+ cat_ids = self.params.cat_ids
397
+ else:
398
+ cat_ids = [-1]
399
+
400
+ num_thrs = 1
401
+ num_recalls = 1
402
+
403
+ num_cats = len(cat_ids)
404
+ num_area_rngs = 1
405
+ num_imgs = len(self.params.img_ids)
406
+
407
+ # -1 for absent categories
408
+ precision = -np.ones(
409
+ (num_thrs, num_recalls, num_cats, num_area_rngs)
410
+ )
411
+ recall = -np.ones((num_thrs, num_cats, num_area_rngs))
412
+
413
+ # Initialize dt_pointers
414
+ dt_pointers = {}
415
+ for cat_idx in range(num_cats):
416
+ dt_pointers[cat_idx] = {}
417
+ for area_idx in range(num_area_rngs):
418
+ dt_pointers[cat_idx][area_idx] = {}
419
+
420
+ # Per category evaluation
421
+ for cat_idx in range(num_cats):
422
+ Nk = cat_idx * num_area_rngs * num_imgs
423
+ for area_idx in range(num_area_rngs):
424
+ Na = area_idx * num_imgs
425
+ E = [
426
+ self.eval_imgs[Nk + Na + img_idx]
427
+ for img_idx in range(num_imgs)
428
+ ]
429
+ # Remove elements which are None
430
+ E = [e for e in E if not e is None]
431
+ if len(E) == 0:
432
+ continue
433
+
434
+ dt_scores = np.concatenate([e["dt_scores"] for e in E], axis=0)
435
+ dt_idx = np.argsort(-dt_scores, kind="mergesort")
436
+ dt_scores = dt_scores[dt_idx]
437
+ dt_m = np.concatenate([e["dt_matches"] for e in E], axis=1)[:, dt_idx]
438
+ dt_ig = np.concatenate([e["dt_ignore"] for e in E], axis=1)[:, dt_idx]
439
+
440
+ num_gt = sum([e['num_gt'] for e in E])
441
+ if num_gt == 0:
442
+ continue
443
+
444
+ tps = np.logical_and(dt_m, np.logical_not(dt_ig))
445
+ fps = np.logical_and(np.logical_not(dt_m), np.logical_not(dt_ig))
446
+ tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float)
447
+ fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float)
448
+
449
+ dt_pointers[cat_idx][area_idx] = {
450
+ "tps": tps,
451
+ "fps": fps,
452
+ }
453
+
454
+ for iou_thr_idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
455
+ tp = np.array(tp)
456
+ fp = np.array(fp)
457
+ num_tp = len(tp)
458
+ rc = tp / num_gt
459
+
460
+ if num_tp:
461
+ recall[iou_thr_idx, cat_idx, area_idx] = rc[
462
+ -1
463
+ ]
464
+ else:
465
+ recall[iou_thr_idx, cat_idx, area_idx] = 0
466
+
467
+ # np.spacing(1) ~= eps
468
+ pr = tp / (fp + tp + np.spacing(1))
469
+ pr = pr.tolist()
470
+
471
+ for i in range(num_tp - 1, 0, -1):
472
+ if pr[i] > pr[i - 1]:
473
+ pr[i - 1] = pr[i]
474
+
475
+ mAP = compute_average_precision(
476
+ np.array(pr, np.float).reshape(-1),
477
+ np.array(rc, np.float).reshape(-1))
478
+ precision[iou_thr_idx, :, cat_idx, area_idx] = mAP
479
+
480
+ self.eval = {
481
+ "params": self.params,
482
+ "counts": [num_thrs, num_recalls, num_cats, num_area_rngs],
483
+ "date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
484
+ "precision": precision,
485
+ "recall": recall,
486
+ "dt_pointers": dt_pointers,
487
+ }
488
+
489
+ def _summarize(self, summary_type):
490
+ s = self.eval["precision"]
491
+ if len(s[s > -1]) == 0:
492
+ mean_s = -1
493
+ else:
494
+ mean_s = np.mean(s[s > -1])
495
+ # print(s.reshape(1, 1, -1, 1))
496
+ return mean_s
497
+
498
+ def summarize(self):
499
+ """Compute and display summary metrics for evaluation results."""
500
+ if not self.eval:
501
+ raise RuntimeError("Please run accumulate() first.")
502
+
503
+ max_dets = self.params.max_dets
504
+ self.results["AP50"] = self._summarize('ap')
505
+
506
+ def run(self):
507
+ """Wrapper function which calculates the results."""
508
+ self.evaluate()
509
+ self.accumulate()
510
+ self.summarize()
511
+
512
+ def print_results(self):
513
+ template = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} catIds={:>3s}] = {:0.3f}"
514
+
515
+ for key, value in self.results.items():
516
+ max_dets = self.params.max_dets
517
+ if "AP" in key:
518
+ title = "Average Precision"
519
+ _type = "(AP)"
520
+ else:
521
+ title = "Average Recall"
522
+ _type = "(AR)"
523
+
524
+ if len(key) > 2 and key[2].isdigit():
525
+ iou_thr = (float(key[2:]) / 100)
526
+ iou = "{:0.2f}".format(iou_thr)
527
+ else:
528
+ iou = "{:0.2f}:{:0.2f}".format(
529
+ self.params.iou_thrs[0], self.params.iou_thrs[-1]
530
+ )
531
+
532
+ cat_group_name = "all"
533
+ area_rng = "all"
534
+
535
+ print(template.format(title, _type, iou, area_rng, max_dets, cat_group_name, value))
536
+
537
+ def get_results(self):
538
+ if not self.results:
539
+ self.logger.warn("results is empty. Call run().")
540
+ return self.results
541
+
542
+
543
+ class Params:
544
+ def __init__(self, iou_type):
545
+ self.img_ids = []
546
+ self.cat_ids = []
547
+ # np.arange causes trouble. the data point on arange is slightly
548
+ # larger than the true value
549
+ self.iou_thrs = np.linspace(
550
+ 0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True
551
+ )
552
+ self.google_style = True
553
+ # print('Using google style PR curve')
554
+ self.iou_thrs = self.iou_thrs[:1]
555
+ self.max_dets = 1000
556
+
557
+ self.area_rng = [
558
+ [0 ** 2, 1e5 ** 2],
559
+ ]
560
+ self.area_rng_lbl = ["all"]
561
+ self.use_cats = 1
562
+ self.iou_type = iou_type
563
+
564
+
565
+ class OIDEvaluator(DatasetEvaluator):
566
+ def __init__(self, dataset_name, cfg, distributed, output_dir=None):
567
+ self._distributed = distributed
568
+ self._output_dir = output_dir
569
+
570
+ self._cpu_device = torch.device("cpu")
571
+ self._logger = logging.getLogger(__name__)
572
+
573
+ self._metadata = MetadataCatalog.get(dataset_name)
574
+ json_file = PathManager.get_local_path(self._metadata.json_file)
575
+ self._oid_api = LVIS(json_file)
576
+ # Test set json files do not contain annotations (evaluation must be
577
+ # performed using the LVIS evaluation server).
578
+ self._do_evaluation = len(self._oid_api.get_ann_ids()) > 0
579
+ self._mask_on = cfg.MODEL.MASK_ON
580
+
581
+ def reset(self):
582
+ self._predictions = []
583
+ self._oid_results = []
584
+
585
+ def process(self, inputs, outputs):
586
+ for input, output in zip(inputs, outputs):
587
+ prediction = {"image_id": input["image_id"]}
588
+ instances = output["instances"].to(self._cpu_device)
589
+ prediction["instances"] = instances_to_coco_json(
590
+ instances, input["image_id"])
591
+ self._predictions.append(prediction)
592
+
593
+ def evaluate(self):
594
+ if self._distributed:
595
+ comm.synchronize()
596
+ self._predictions = comm.gather(self._predictions, dst=0)
597
+ self._predictions = list(itertools.chain(*self._predictions))
598
+
599
+ if not comm.is_main_process():
600
+ return
601
+
602
+ if len(self._predictions) == 0:
603
+ self._logger.warning("[LVISEvaluator] Did not receive valid predictions.")
604
+ return {}
605
+
606
+ self._logger.info("Preparing results in the OID format ...")
607
+ self._oid_results = list(
608
+ itertools.chain(*[x["instances"] for x in self._predictions]))
609
+
610
+ # unmap the category ids for LVIS (from 0-indexed to 1-indexed)
611
+ for result in self._oid_results:
612
+ result["category_id"] += 1
613
+
614
+ PathManager.mkdirs(self._output_dir)
615
+ file_path = os.path.join(
616
+ self._output_dir, "oid_instances_results.json")
617
+ self._logger.info("Saving results to {}".format(file_path))
618
+ with PathManager.open(file_path, "w") as f:
619
+ f.write(json.dumps(self._oid_results))
620
+ f.flush()
621
+
622
+ if not self._do_evaluation:
623
+ self._logger.info("Annotations are not available for evaluation.")
624
+ return
625
+
626
+ self._logger.info("Evaluating predictions ...")
627
+ self._results = OrderedDict()
628
+ res, mAP = _evaluate_predictions_on_oid(
629
+ self._oid_api,
630
+ file_path,
631
+ eval_seg=self._mask_on,
632
+ class_names=self._metadata.get("thing_classes"),
633
+ )
634
+ self._results['bbox'] = res
635
+ mAP_out_path = os.path.join(self._output_dir, "oid_mAP.npy")
636
+ self._logger.info('Saving mAP to' + mAP_out_path)
637
+ np.save(mAP_out_path, mAP)
638
+ return copy.deepcopy(self._results)
639
+
640
+ def _evaluate_predictions_on_oid(
641
+ oid_gt, oid_results_path, eval_seg=False,
642
+ class_names=None):
643
+ logger = logging.getLogger(__name__)
644
+ metrics = ["AP50", "AP50_expand"]
645
+
646
+ results = {}
647
+ oid_eval = OIDEval(oid_gt, oid_results_path, 'bbox', expand_pred_label=False)
648
+ oid_eval.run()
649
+ oid_eval.print_results()
650
+ results["AP50"] = oid_eval.get_results()["AP50"]
651
+
652
+ if eval_seg:
653
+ oid_eval = OIDEval(oid_gt, oid_results_path, 'segm', expand_pred_label=False)
654
+ oid_eval.run()
655
+ oid_eval.print_results()
656
+ results["AP50_segm"] = oid_eval.get_results()["AP50"]
657
+ else:
658
+ oid_eval = OIDEval(oid_gt, oid_results_path, 'bbox', expand_pred_label=True)
659
+ oid_eval.run()
660
+ oid_eval.print_results()
661
+ results["AP50_expand"] = oid_eval.get_results()["AP50"]
662
+
663
+ mAP = np.zeros(len(class_names)) - 1
664
+ precisions = oid_eval.eval['precision']
665
+ assert len(class_names) == precisions.shape[2]
666
+ results_per_category = []
667
+ id2apiid = sorted(oid_gt.get_cat_ids())
668
+ inst_aware_ap, inst_count = 0, 0
669
+ for idx, name in enumerate(class_names):
670
+ precision = precisions[:, :, idx, 0]
671
+ precision = precision[precision > -1]
672
+ ap = np.mean(precision) if precision.size else float("nan")
673
+ inst_num = len(oid_gt.get_ann_ids(cat_ids=[id2apiid[idx]]))
674
+ if inst_num > 0:
675
+ results_per_category.append(("{} {}".format(
676
+ name.replace(' ', '_'),
677
+ inst_num if inst_num < 1000 else '{:.1f}k'.format(inst_num / 1000)),
678
+ float(ap * 100)))
679
+ inst_aware_ap += inst_num * ap
680
+ inst_count += inst_num
681
+ mAP[idx] = ap
682
+ # logger.info("{} {} {:.2f}".format(name, inst_num, ap * 100))
683
+ inst_aware_ap = inst_aware_ap * 100 / inst_count
684
+ N_COLS = min(6, len(results_per_category) * 2)
685
+ results_flatten = list(itertools.chain(*results_per_category))
686
+ results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)])
687
+ table = tabulate(
688
+ results_2d,
689
+ tablefmt="pipe",
690
+ floatfmt=".3f",
691
+ headers=["category", "AP"] * (N_COLS // 2),
692
+ numalign="left",
693
+ )
694
+ logger.info("Per-category {} AP: \n".format('bbox') + table)
695
+ logger.info("Instance-aware {} AP: {:.4f}".format('bbox', inst_aware_ap))
696
+
697
+ logger.info("Evaluation results for bbox: \n" + \
698
+ create_small_table(results))
699
+ return results, mAP
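compute_average_precision at the top of this file is the VOC-style area under the precision-recall curve: precision is first made non-increasing from right to left, then integrated over the recall steps. A tiny worked example with made-up PR points (run it against the numpy version this repo targets, since the helper still uses np.float):

import numpy as np
from proxydet.evaluation.oideval import compute_average_precision

precision = np.array([1.0, 0.5, 0.67, 0.5], dtype=float)
recall    = np.array([0.25, 0.25, 0.5, 0.5], dtype=float)

# After the right-to-left max, precision becomes [1.0, 0.67, 0.67, 0.5];
# recall steps at 0.25 and 0.5 give 0.25 * 1.0 + 0.25 * 0.67 = 0.4175
# (the padded final step to recall 1.0 contributes 0).
print(compute_average_precision(precision, recall))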
proxydet/modeling/backbone/swintransformer.py ADDED
@@ -0,0 +1,750 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ # Modified by Xingyi Zhou from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
10
+
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint as checkpoint
16
+ import numpy as np
17
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
18
+
19
+ from detectron2.layers import ShapeSpec
20
+ from detectron2.modeling.backbone.backbone import Backbone
21
+ from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
22
+ from detectron2.modeling.backbone.fpn import FPN
23
+
24
+ from centernet.modeling.backbone.fpn_p5 import LastLevelP6P7_P5
25
+ from centernet.modeling.backbone.bifpn import BiFPN
26
+ # from .checkpoint import load_checkpoint
27
+
28
+ class Mlp(nn.Module):
29
+ """ Multilayer perceptron."""
30
+
31
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
32
+ super().__init__()
33
+ out_features = out_features or in_features
34
+ hidden_features = hidden_features or in_features
35
+ self.fc1 = nn.Linear(in_features, hidden_features)
36
+ self.act = act_layer()
37
+ self.fc2 = nn.Linear(hidden_features, out_features)
38
+ self.drop = nn.Dropout(drop)
39
+
40
+ def forward(self, x):
41
+ x = self.fc1(x)
42
+ x = self.act(x)
43
+ x = self.drop(x)
44
+ x = self.fc2(x)
45
+ x = self.drop(x)
46
+ return x
47
+
48
+
49
+ def window_partition(x, window_size):
50
+ """
51
+ Args:
52
+ x: (B, H, W, C)
53
+ window_size (int): window size
54
+ Returns:
55
+ windows: (num_windows*B, window_size, window_size, C)
56
+ """
57
+ B, H, W, C = x.shape
58
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
59
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
60
+ return windows
61
+
62
+
63
+ def window_reverse(windows, window_size, H, W):
64
+ """
65
+ Args:
66
+ windows: (num_windows*B, window_size, window_size, C)
67
+ window_size (int): Window size
68
+ H (int): Height of image
69
+ W (int): Width of image
70
+ Returns:
71
+ x: (B, H, W, C)
72
+ """
73
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
74
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
75
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
76
+ return x
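window_partition and window_reverse are exact inverses whenever H and W are multiples of the window size, which is what the padding in SwinTransformerBlock guarantees before they are called. A quick sanity-check sketch (shapes chosen purely for illustration):

import torch

x = torch.randn(2, 14, 14, 96)                 # (B, H, W, C), both 14s divisible by 7
windows = window_partition(x, window_size=7)   # (2 * 2 * 2, 7, 7, 96) == (8, 7, 7, 96)
y = window_reverse(windows, 7, 14, 14)         # back to (2, 14, 14, 96)
assert torch.equal(x, y)                       # pure reshape/permute, bit-exact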
77
+
78
+
79
+ class WindowAttention(nn.Module):
80
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
81
+ It supports both of shifted and non-shifted window.
82
+ Args:
83
+ dim (int): Number of input channels.
84
+ window_size (tuple[int]): The height and width of the window.
85
+ num_heads (int): Number of attention heads.
86
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
87
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
88
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
89
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
90
+ """
91
+
92
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
93
+
94
+ super().__init__()
95
+ self.dim = dim
96
+ self.window_size = window_size # Wh, Ww
97
+ self.num_heads = num_heads
98
+ head_dim = dim // num_heads
99
+ self.scale = qk_scale or head_dim ** -0.5
100
+
101
+ # define a parameter table of relative position bias
102
+ self.relative_position_bias_table = nn.Parameter(
103
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
104
+
105
+ # get pair-wise relative position index for each token inside the window
106
+ coords_h = torch.arange(self.window_size[0])
107
+ coords_w = torch.arange(self.window_size[1])
108
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
109
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
110
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
111
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
112
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
113
+ relative_coords[:, :, 1] += self.window_size[1] - 1
114
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
115
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
116
+ self.register_buffer("relative_position_index", relative_position_index)
117
+
118
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
119
+ self.attn_drop = nn.Dropout(attn_drop)
120
+ self.proj = nn.Linear(dim, dim)
121
+ self.proj_drop = nn.Dropout(proj_drop)
122
+
123
+ trunc_normal_(self.relative_position_bias_table, std=.02)
124
+ self.softmax = nn.Softmax(dim=-1)
125
+
126
+ def forward(self, x, mask=None):
127
+ """ Forward function.
128
+ Args:
129
+ x: input features with shape of (num_windows*B, N, C)
130
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
131
+ """
132
+ B_, N, C = x.shape
133
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
134
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
135
+
136
+ q = q * self.scale
137
+ attn = (q @ k.transpose(-2, -1))
138
+
139
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
140
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
141
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
142
+ attn = attn + relative_position_bias.unsqueeze(0)
143
+
144
+ if mask is not None:
145
+ nW = mask.shape[0]
146
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
147
+ attn = attn.view(-1, self.num_heads, N, N)
148
+ attn = self.softmax(attn)
149
+ else:
150
+ attn = self.softmax(attn)
151
+
152
+ attn = self.attn_drop(attn)
153
+
154
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
155
+ x = self.proj(x)
156
+ x = self.proj_drop(x)
157
+ return x
158
+
159
+
160
+ class SwinTransformerBlock(nn.Module):
161
+ """ Swin Transformer Block.
162
+ Args:
163
+ dim (int): Number of input channels.
164
+ num_heads (int): Number of attention heads.
165
+ window_size (int): Window size.
166
+ shift_size (int): Shift size for SW-MSA.
167
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
168
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
169
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
170
+ drop (float, optional): Dropout rate. Default: 0.0
171
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
172
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
173
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
174
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
175
+ """
176
+
177
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
178
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
179
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
180
+ super().__init__()
181
+ self.dim = dim
182
+ self.num_heads = num_heads
183
+ self.window_size = window_size
184
+ self.shift_size = shift_size
185
+ self.mlp_ratio = mlp_ratio
186
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
187
+
188
+ self.norm1 = norm_layer(dim)
189
+ self.attn = WindowAttention(
190
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
191
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
192
+
193
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
194
+ self.norm2 = norm_layer(dim)
195
+ mlp_hidden_dim = int(dim * mlp_ratio)
196
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
197
+
198
+ self.H = None
199
+ self.W = None
200
+
201
+ def forward(self, x, mask_matrix):
202
+ """ Forward function.
203
+ Args:
204
+ x: Input feature, tensor size (B, H*W, C).
205
+ H, W: Spatial resolution of the input feature.
206
+ mask_matrix: Attention mask for cyclic shift.
207
+ """
208
+ B, L, C = x.shape
209
+ H, W = self.H, self.W
210
+ assert L == H * W, "input feature has wrong size"
211
+
212
+ shortcut = x
213
+ x = self.norm1(x)
214
+ x = x.view(B, H, W, C)
215
+
216
+ # pad feature maps to multiples of window size
217
+ pad_l = pad_t = 0
218
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
219
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
220
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
221
+ _, Hp, Wp, _ = x.shape
222
+
223
+ # cyclic shift
224
+ if self.shift_size > 0:
225
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
226
+ attn_mask = mask_matrix
227
+ else:
228
+ shifted_x = x
229
+ attn_mask = None
230
+
231
+ # partition windows
232
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
233
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
234
+
235
+ # W-MSA/SW-MSA
236
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
237
+
238
+ # merge windows
239
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
240
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
241
+
242
+ # reverse cyclic shift
243
+ if self.shift_size > 0:
244
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
245
+ else:
246
+ x = shifted_x
247
+
248
+ if pad_r > 0 or pad_b > 0:
249
+ x = x[:, :H, :W, :].contiguous()
250
+
251
+ x = x.view(B, H * W, C)
252
+
253
+ # FFN
254
+ x = shortcut + self.drop_path(x)
255
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
256
+
257
+ return x
258
+
259
+
260
+ class PatchMerging(nn.Module):
261
+ """ Patch Merging Layer
262
+ Args:
263
+ dim (int): Number of input channels.
264
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
265
+ """
266
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
267
+ super().__init__()
268
+ self.dim = dim
269
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
270
+ self.norm = norm_layer(4 * dim)
271
+
272
+ def forward(self, x, H, W):
273
+ """ Forward function.
274
+ Args:
275
+ x: Input feature, tensor size (B, H*W, C).
276
+ H, W: Spatial resolution of the input feature.
277
+ """
278
+ B, L, C = x.shape
279
+ assert L == H * W, "input feature has wrong size"
280
+
281
+ x = x.view(B, H, W, C)
282
+
283
+ # padding
284
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
285
+ if pad_input:
286
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
287
+
288
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
289
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
290
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
291
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
292
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
293
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
294
+
295
+ x = self.norm(x)
296
+ x = self.reduction(x)
297
+
298
+ return x
299
+
300
+
301
+ class BasicLayer(nn.Module):
302
+ """ A basic Swin Transformer layer for one stage.
303
+ Args:
304
+ dim (int): Number of feature channels
305
+ depth (int): Depths of this stage.
306
+ num_heads (int): Number of attention head.
307
+ window_size (int): Local window size. Default: 7.
308
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
309
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
310
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
311
+ drop (float, optional): Dropout rate. Default: 0.0
312
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
313
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
314
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
315
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
316
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
317
+ """
318
+
319
+ def __init__(self,
320
+ dim,
321
+ depth,
322
+ num_heads,
323
+ window_size=7,
324
+ mlp_ratio=4.,
325
+ qkv_bias=True,
326
+ qk_scale=None,
327
+ drop=0.,
328
+ attn_drop=0.,
329
+ drop_path=0.,
330
+ norm_layer=nn.LayerNorm,
331
+ downsample=None,
332
+ use_checkpoint=False):
333
+ super().__init__()
334
+ self.window_size = window_size
335
+ self.shift_size = window_size // 2
336
+ self.depth = depth
337
+ self.use_checkpoint = use_checkpoint
338
+
339
+ # build blocks
340
+ self.blocks = nn.ModuleList([
341
+ SwinTransformerBlock(
342
+ dim=dim,
343
+ num_heads=num_heads,
344
+ window_size=window_size,
345
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
346
+ mlp_ratio=mlp_ratio,
347
+ qkv_bias=qkv_bias,
348
+ qk_scale=qk_scale,
349
+ drop=drop,
350
+ attn_drop=attn_drop,
351
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
352
+ norm_layer=norm_layer)
353
+ for i in range(depth)])
354
+
355
+ # patch merging layer
356
+ if downsample is not None:
357
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
358
+ else:
359
+ self.downsample = None
360
+
361
+ def forward(self, x, H, W):
362
+ """ Forward function.
363
+ Args:
364
+ x: Input feature, tensor size (B, H*W, C).
365
+ H, W: Spatial resolution of the input feature.
366
+ """
367
+
368
+ # calculate attention mask for SW-MSA
369
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
370
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
371
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
372
+ h_slices = (slice(0, -self.window_size),
373
+ slice(-self.window_size, -self.shift_size),
374
+ slice(-self.shift_size, None))
375
+ w_slices = (slice(0, -self.window_size),
376
+ slice(-self.window_size, -self.shift_size),
377
+ slice(-self.shift_size, None))
378
+ cnt = 0
379
+ for h in h_slices:
380
+ for w in w_slices:
381
+ img_mask[:, h, w, :] = cnt
382
+ cnt += 1
383
+
384
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
385
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
386
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
387
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
388
+
389
+ for blk in self.blocks:
390
+ blk.H, blk.W = H, W
391
+ if self.use_checkpoint:
392
+ x = checkpoint.checkpoint(blk, x, attn_mask)
393
+ else:
394
+ x = blk(x, attn_mask)
395
+ if self.downsample is not None:
396
+ x_down = self.downsample(x, H, W)
397
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
398
+ return x, H, W, x_down, Wh, Ww
399
+ else:
400
+ return x, H, W, x, H, W
401
+
402
+
403
+ class PatchEmbed(nn.Module):
404
+ """ Image to Patch Embedding
405
+ Args:
406
+ patch_size (int): Patch token size. Default: 4.
407
+ in_chans (int): Number of input image channels. Default: 3.
408
+ embed_dim (int): Number of linear projection output channels. Default: 96.
409
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
410
+ """
411
+
412
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
413
+ super().__init__()
414
+ patch_size = to_2tuple(patch_size)
415
+ self.patch_size = patch_size
416
+
417
+ self.in_chans = in_chans
418
+ self.embed_dim = embed_dim
419
+
420
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
421
+ if norm_layer is not None:
422
+ self.norm = norm_layer(embed_dim)
423
+ else:
424
+ self.norm = None
425
+
426
+ def forward(self, x):
427
+ """Forward function."""
428
+ # padding
429
+ _, _, H, W = x.size()
430
+ if W % self.patch_size[1] != 0:
431
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
432
+ if H % self.patch_size[0] != 0:
433
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
434
+
435
+ x = self.proj(x) # B C Wh Ww
436
+ if self.norm is not None:
437
+ Wh, Ww = x.size(2), x.size(3)
438
+ x = x.flatten(2).transpose(1, 2)
439
+ x = self.norm(x)
440
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
441
+
442
+ return x
443
+
444
+
445
+ class SwinTransformer(Backbone):
446
+ """ Swin Transformer backbone.
447
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
448
+ https://arxiv.org/pdf/2103.14030
449
+ Args:
450
+ pretrain_img_size (int): Input image size for training the pretrained model,
451
+ used in absolute postion embedding. Default 224.
452
+ patch_size (int | tuple(int)): Patch size. Default: 4.
453
+ in_chans (int): Number of input image channels. Default: 3.
454
+ embed_dim (int): Number of linear projection output channels. Default: 96.
455
+ depths (tuple[int]): Depths of each Swin Transformer stage.
456
+ num_heads (tuple[int]): Number of attention head of each stage.
457
+ window_size (int): Window size. Default: 7.
458
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
459
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
460
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
461
+ drop_rate (float): Dropout rate.
462
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
463
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
464
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
465
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
466
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
467
+ out_indices (Sequence[int]): Output from which stages.
468
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
469
+ -1 means not freezing any parameters.
470
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
471
+ """
472
+
473
+ def __init__(self,
474
+ pretrain_img_size=224,
475
+ patch_size=4,
476
+ in_chans=3,
477
+ embed_dim=96,
478
+ depths=[2, 2, 6, 2],
479
+ num_heads=[3, 6, 12, 24],
480
+ window_size=7,
481
+ mlp_ratio=4.,
482
+ qkv_bias=True,
483
+ qk_scale=None,
484
+ drop_rate=0.,
485
+ attn_drop_rate=0.,
486
+ drop_path_rate=0.2,
487
+ norm_layer=nn.LayerNorm,
488
+ ape=False,
489
+ patch_norm=True,
490
+ out_indices=(0, 1, 2, 3),
491
+ frozen_stages=-1,
492
+ use_checkpoint=False):
493
+ super().__init__()
494
+
495
+ self.pretrain_img_size = pretrain_img_size
496
+ self.num_layers = len(depths)
497
+ self.embed_dim = embed_dim
498
+ self.ape = ape
499
+ self.patch_norm = patch_norm
500
+ self.out_indices = out_indices
501
+ self.frozen_stages = frozen_stages
502
+
503
+ # split image into non-overlapping patches
504
+ self.patch_embed = PatchEmbed(
505
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
506
+ norm_layer=norm_layer if self.patch_norm else None)
507
+
508
+ # absolute position embedding
509
+ if self.ape:
510
+ pretrain_img_size = to_2tuple(pretrain_img_size)
511
+ patch_size = to_2tuple(patch_size)
512
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
513
+
514
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
515
+ trunc_normal_(self.absolute_pos_embed, std=.02)
516
+
517
+ self.pos_drop = nn.Dropout(p=drop_rate)
518
+
519
+ # stochastic depth
520
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
521
+
522
+ # build layers
523
+ self.layers = nn.ModuleList()
524
+ for i_layer in range(self.num_layers):
525
+ layer = BasicLayer(
526
+ dim=int(embed_dim * 2 ** i_layer),
527
+ depth=depths[i_layer],
528
+ num_heads=num_heads[i_layer],
529
+ window_size=window_size,
530
+ mlp_ratio=mlp_ratio,
531
+ qkv_bias=qkv_bias,
532
+ qk_scale=qk_scale,
533
+ drop=drop_rate,
534
+ attn_drop=attn_drop_rate,
535
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
536
+ norm_layer=norm_layer,
537
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
538
+ use_checkpoint=use_checkpoint)
539
+ self.layers.append(layer)
540
+
541
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
542
+ self.num_features = num_features
543
+
544
+ # add a norm layer for each output
545
+ for i_layer in out_indices:
546
+ layer = norm_layer(num_features[i_layer])
547
+ layer_name = f'norm{i_layer}'
548
+ self.add_module(layer_name, layer)
549
+
550
+ self._freeze_stages()
551
+ self._out_features = ['swin{}'.format(i) for i in self.out_indices]
552
+ self._out_feature_channels = {
553
+ 'swin{}'.format(i): self.embed_dim * 2 ** i for i in self.out_indices
554
+ }
555
+ self._out_feature_strides = {
556
+ 'swin{}'.format(i): 2 ** (i + 2) for i in self.out_indices
557
+ }
558
+ self._size_divisibility = 32
559
+
560
+
561
+ def _freeze_stages(self):
562
+ if self.frozen_stages >= 0:
563
+ self.patch_embed.eval()
564
+ for param in self.patch_embed.parameters():
565
+ param.requires_grad = False
566
+
567
+ if self.frozen_stages >= 1 and self.ape:
568
+ self.absolute_pos_embed.requires_grad = False
569
+
570
+ if self.frozen_stages >= 2:
571
+ self.pos_drop.eval()
572
+ for i in range(0, self.frozen_stages - 1):
573
+ m = self.layers[i]
574
+ m.eval()
575
+ for param in m.parameters():
576
+ param.requires_grad = False
577
+
578
+ def init_weights(self, pretrained=None):
579
+ """Initialize the weights in backbone.
580
+ Args:
581
+ pretrained (str, optional): Path to pre-trained weights.
582
+ Defaults to None.
583
+ """
584
+
585
+ def _init_weights(m):
586
+ if isinstance(m, nn.Linear):
587
+ trunc_normal_(m.weight, std=.02)
588
+ if isinstance(m, nn.Linear) and m.bias is not None:
589
+ nn.init.constant_(m.bias, 0)
590
+ elif isinstance(m, nn.LayerNorm):
591
+ nn.init.constant_(m.bias, 0)
592
+ nn.init.constant_(m.weight, 1.0)
593
+
594
+ if isinstance(pretrained, str):
595
+ self.apply(_init_weights)
596
+ # load_checkpoint(self, pretrained, strict=False)
597
+ elif pretrained is None:
598
+ self.apply(_init_weights)
599
+ else:
600
+ raise TypeError('pretrained must be a str or None')
601
+
602
+ def forward(self, x):
603
+ """Forward function."""
604
+ x = self.patch_embed(x)
605
+
606
+ Wh, Ww = x.size(2), x.size(3)
607
+ if self.ape:
608
+ # interpolate the position embedding to the corresponding size
609
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
610
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
611
+ else:
612
+ x = x.flatten(2).transpose(1, 2)
613
+ x = self.pos_drop(x)
614
+
615
+ # outs = []
616
+ outs = {}
617
+ for i in range(self.num_layers):
618
+ layer = self.layers[i]
619
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
620
+
621
+ if i in self.out_indices:
622
+ norm_layer = getattr(self, f'norm{i}')
623
+ x_out = norm_layer(x_out)
624
+
625
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
626
+ # outs.append(out)
627
+ outs['swin{}'.format(i)] = out
628
+
629
+ return outs
630
+
631
+ def train(self, mode=True):
632
+ """Convert the model into training mode while keeping the frozen stages frozen."""
633
+ super(SwinTransformer, self).train(mode)
634
+ self._freeze_stages()
635
+
636
+ size2config = {
637
+ 'T': {
638
+ 'window_size': 7,
639
+ 'embed_dim': 96,
640
+ 'depth': [2, 2, 6, 2],
641
+ 'num_heads': [3, 6, 12, 24],
642
+ 'drop_path_rate': 0.2,
643
+ 'pretrained': 'models/swin_tiny_patch4_window7_224.pth'
644
+ },
645
+ 'S': {
646
+ 'window_size': 7,
647
+ 'embed_dim': 96,
648
+ 'depth': [2, 2, 18, 2],
649
+ 'num_heads': [3, 6, 12, 24],
650
+ 'drop_path_rate': 0.2,
651
+ 'pretrained': 'models/swin_small_patch4_window7_224.pth'
652
+ },
653
+ 'B': {
654
+ 'window_size': 7,
655
+ 'embed_dim': 128,
656
+ 'depth': [2, 2, 18, 2],
657
+ 'num_heads': [4, 8, 16, 32],
658
+ 'drop_path_rate': 0.3,
659
+ 'pretrained': 'models/swin_base_patch4_window7_224.pth'
660
+ },
661
+ 'B-22k': {
662
+ 'window_size': 7,
663
+ 'embed_dim': 128,
664
+ 'depth': [2, 2, 18, 2],
665
+ 'num_heads': [4, 8, 16, 32],
666
+ 'drop_path_rate': 0.3,
667
+ 'pretrained': 'models/swin_base_patch4_window7_224_22k.pth'
668
+ },
669
+ 'B-22k-384': {
670
+ 'window_size': 12,
671
+ 'embed_dim': 128,
672
+ 'depth': [2, 2, 18, 2],
673
+ 'num_heads': [4, 8, 16, 32],
674
+ 'drop_path_rate': 0.3,
675
+ 'pretrained': 'models/swin_base_patch4_window12_384_22k.pth'
676
+ },
677
+ 'L-22k': {
678
+ 'window_size': 7,
679
+ 'embed_dim': 192,
680
+ 'depth': [2, 2, 18, 2],
681
+ 'num_heads': [6, 12, 24, 48],
682
+ 'drop_path_rate': 0.3, # TODO (xingyi): this is unclear
683
+ 'pretrained': 'models/swin_large_patch4_window7_224_22k.pth'
684
+ },
685
+ 'L-22k-384': {
686
+ 'window_size': 12,
687
+ 'embed_dim': 192,
688
+ 'depth': [2, 2, 18, 2],
689
+ 'num_heads': [6, 12, 24, 48],
690
+ 'drop_path_rate': 0.3, # TODO (xingyi): this is unclear
691
+ 'pretrained': 'models/swin_large_patch4_window12_384_22k.pth'
692
+ }
693
+ }
694
+
695
+ @BACKBONE_REGISTRY.register()
696
+ def build_swintransformer_backbone(cfg, input_shape):
697
+ """
698
+ Build a plain SwinTransformer backbone from the preset selected by cfg.MODEL.SWIN.SIZE.
+ """
699
+ config = size2config[cfg.MODEL.SWIN.SIZE]
700
+ out_indices = cfg.MODEL.SWIN.OUT_FEATURES
701
+ model = SwinTransformer(
702
+ embed_dim=config['embed_dim'],
703
+ window_size=config['window_size'],
704
+ depths=config['depth'],
705
+ num_heads=config['num_heads'],
706
+ drop_path_rate=config['drop_path_rate'],
707
+ out_indices=out_indices,
708
+ frozen_stages=-1,
709
+ use_checkpoint=cfg.MODEL.SWIN.USE_CHECKPOINT
710
+ )
711
+ # print('Initializing', config['pretrained'])
712
+ model.init_weights(config['pretrained'])
713
+ return model
714
+
715
+
716
+ @BACKBONE_REGISTRY.register()
717
+ def build_swintransformer_fpn_backbone(cfg, input_shape: ShapeSpec):
718
+ """
719
+ Wrap the SwinTransformer backbone in an FPN with extra P6/P7 levels generated from P5.
+ """
720
+ bottom_up = build_swintransformer_backbone(cfg, input_shape)
721
+ in_features = cfg.MODEL.FPN.IN_FEATURES
722
+ out_channels = cfg.MODEL.FPN.OUT_CHANNELS
723
+ backbone = FPN(
724
+ bottom_up=bottom_up,
725
+ in_features=in_features,
726
+ out_channels=out_channels,
727
+ norm=cfg.MODEL.FPN.NORM,
728
+ top_block=LastLevelP6P7_P5(out_channels, out_channels),
729
+ fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
730
+ )
731
+ return backbone
732
+
733
+
734
+ @BACKBONE_REGISTRY.register()
735
+ def build_swintransformer_bifpn_backbone(cfg, input_shape: ShapeSpec):
736
+ """
737
+ Wrap the SwinTransformer backbone in a BiFPN, configured via cfg.MODEL.BIFPN.
+ """
738
+ bottom_up = build_swintransformer_backbone(cfg, input_shape)
739
+ in_features = cfg.MODEL.FPN.IN_FEATURES
740
+ backbone = BiFPN(
741
+ cfg=cfg,
742
+ bottom_up=bottom_up,
743
+ in_features=in_features,
744
+ out_channels=cfg.MODEL.BIFPN.OUT_CHANNELS,
745
+ norm=cfg.MODEL.BIFPN.NORM,
746
+ num_levels=cfg.MODEL.BIFPN.NUM_LEVELS,
747
+ num_bifpn=cfg.MODEL.BIFPN.NUM_BIFPN,
748
+ separable_conv=cfg.MODEL.BIFPN.SEPARABLE_CONV,
749
+ )
750
+ return backbone
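A minimal usage sketch for the backbone defined in this file (editorial example, not part of the commit; it assumes torch, timm and detectron2 are installed and that the module path matches where this file is added). The forward pass returns a dict keyed 'swin0'..'swin3' whose channels double per stage and whose strides are 4/8/16/32, which is exactly what the FPN/BiFPN builders above consume via _out_feature_channels and _out_feature_strides.

import torch
from proxydet.modeling.backbone.swintransformer import SwinTransformer  # assumed module path

model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2],
                        num_heads=[3, 6, 12, 24], window_size=7,
                        out_indices=(0, 1, 2, 3))
model.init_weights(None)                     # random init; no checkpoint needed for the sketch
feats = model(torch.randn(1, 3, 224, 224))
for name, f in feats.items():
    print(name, tuple(f.shape))              # swin0 (1, 96, 56, 56) ... swin3 (1, 768, 7, 7)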
proxydet/modeling/backbone/timm.py ADDED
@@ -0,0 +1,221 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ import math
5
+ from os.path import join
6
+ import numpy as np
7
+ import copy
8
+ from functools import partial
9
+
10
+ import torch
11
+ from torch import nn
12
+ import torch.utils.model_zoo as model_zoo
13
+ import torch.nn.functional as F
14
+ import fvcore.nn.weight_init as weight_init
15
+
16
+ from detectron2.modeling.backbone import FPN
17
+ from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
18
+ from detectron2.layers.batch_norm import get_norm, FrozenBatchNorm2d
19
+ from detectron2.modeling.backbone import Backbone
20
+
21
+ from timm import create_model
22
+ from timm.models.helpers import build_model_with_cfg
23
+ from timm.models.registry import register_model
24
+ from timm.models.resnet import ResNet, Bottleneck
25
+ from timm.models.resnet import default_cfgs as default_cfgs_resnet
26
+ from timm.models.convnext import ConvNeXt, default_cfgs, checkpoint_filter_fn
27
+
28
+
29
+ @register_model
30
+ def convnext_tiny_21k(pretrained=False, **kwargs):
31
+ model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
32
+ cfg = default_cfgs['convnext_tiny']
33
+ cfg['url'] = 'https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth'
34
+ model = build_model_with_cfg(
35
+ ConvNeXt, 'convnext_tiny', pretrained,
36
+ default_cfg=cfg,
37
+ pretrained_filter_fn=checkpoint_filter_fn,
38
+ feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
39
+ **model_args)
40
+ return model
41
+
42
+ class CustomResNet(ResNet):
43
+ def __init__(self, **kwargs):
44
+ self.out_indices = kwargs.pop('out_indices')
45
+ super().__init__(**kwargs)
46
+
47
+
48
+ def forward(self, x):
49
+ x = self.conv1(x)
50
+ x = self.bn1(x)
51
+ x = self.act1(x)
52
+ x = self.maxpool(x)
53
+ ret = [x]
54
+ x = self.layer1(x)
55
+ ret.append(x)
56
+ x = self.layer2(x)
57
+ ret.append(x)
58
+ x = self.layer3(x)
59
+ ret.append(x)
60
+ x = self.layer4(x)
61
+ ret.append(x)
62
+ return [ret[i] for i in self.out_indices]
63
+
64
+
65
+ def load_pretrained(self, cached_file):
66
+ data = torch.load(cached_file, map_location='cpu')
67
+ if 'state_dict' in data:
68
+ self.load_state_dict(data['state_dict'])
69
+ else:
70
+ self.load_state_dict(data)
71
+
72
+
73
+ model_params = {
74
+ 'resnet50_in21k': dict(block=Bottleneck, layers=[3, 4, 6, 3]),
75
+ }
76
+
77
+
78
+ def create_timm_resnet(variant, out_indices, pretrained=False, **kwargs):
79
+ params = model_params[variant]
80
+ default_cfgs_resnet['resnet50_in21k'] = \
81
+ copy.deepcopy(default_cfgs_resnet['resnet50'])
82
+ default_cfgs_resnet['resnet50_in21k']['url'] = \
83
+ 'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth'
84
+ default_cfgs_resnet['resnet50_in21k']['num_classes'] = 11221
85
+
86
+ return build_model_with_cfg(
87
+ CustomResNet, variant, pretrained,
88
+ default_cfg=default_cfgs_resnet[variant],
89
+ out_indices=out_indices,
90
+ pretrained_custom_load=True,
91
+ **params,
92
+ **kwargs)
93
+
94
+
95
+ class LastLevelP6P7_P5(nn.Module):
96
+ """
97
+ Generate the extra FPN levels P6 and P7 from the P5 feature map (used as the FPN top block).
+ """
98
+ def __init__(self, in_channels, out_channels):
99
+ super().__init__()
100
+ self.num_levels = 2
101
+ self.in_feature = "p5"
102
+ self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
103
+ self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
104
+ for module in [self.p6, self.p7]:
105
+ weight_init.c2_xavier_fill(module)
106
+
107
+ def forward(self, c5):
108
+ p6 = self.p6(c5)
109
+ p7 = self.p7(F.relu(p6))
110
+ return [p6, p7]
111
+
112
+
113
+ def freeze_module(x):
114
+ """
115
+ Freeze all parameters of a module and convert its BatchNorm layers to FrozenBatchNorm2d.
+ """
116
+ for p in x.parameters():
117
+ p.requires_grad = False
118
+ FrozenBatchNorm2d.convert_frozen_batchnorm(x)
119
+ return x
120
+
121
+
122
+ class TIMM(Backbone):
123
+ def __init__(self, base_name, out_levels, freeze_at=0, norm='FrozenBN', pretrained=False):
124
+ super().__init__()
125
+ out_indices = [x - 1 for x in out_levels]
126
+ if base_name in model_params:
127
+ self.base = create_timm_resnet(
128
+ base_name, out_indices=out_indices,
129
+ pretrained=False)
130
+ elif 'eff' in base_name or 'resnet' in base_name or 'regnet' in base_name:
131
+ self.base = create_model(
132
+ base_name, features_only=True,
133
+ out_indices=out_indices, pretrained=pretrained)
134
+ elif 'convnext' in base_name:
135
+ drop_path_rate = 0.2 \
136
+ if ('tiny' in base_name or 'small' in base_name) else 0.3
137
+ self.base = create_model(
138
+ base_name, features_only=True,
139
+ out_indices=out_indices, pretrained=pretrained,
140
+ drop_path_rate=drop_path_rate)
141
+ else:
142
+ assert 0, base_name
143
+ feature_info = [dict(num_chs=f['num_chs'], reduction=f['reduction']) \
144
+ for i, f in enumerate(self.base.feature_info)]
145
+ self._out_features = ['layer{}'.format(x) for x in out_levels]
146
+ self._out_feature_channels = {
147
+ 'layer{}'.format(l): feature_info[l - 1]['num_chs'] for l in out_levels}
148
+ self._out_feature_strides = {
149
+ 'layer{}'.format(l): feature_info[l - 1]['reduction'] for l in out_levels}
150
+ self._size_divisibility = max(self._out_feature_strides.values())
151
+ if 'resnet' in base_name:
152
+ self.freeze(freeze_at)
153
+ if norm == 'FrozenBN':
154
+ self = FrozenBatchNorm2d.convert_frozen_batchnorm(self)
155
+
156
+ def freeze(self, freeze_at=0):
157
+ """
158
+ Freeze the backbone stem (freeze_at >= 1) and the first residual stage (freeze_at >= 2).
+ """
159
+ if freeze_at >= 1:
160
+ print('Freezing', self.base.conv1)
161
+ self.base.conv1 = freeze_module(self.base.conv1)
162
+ if freeze_at >= 2:
163
+ print('Freezing', self.base.layer1)
164
+ self.base.layer1 = freeze_module(self.base.layer1)
165
+
166
+ def forward(self, x):
167
+ features = self.base(x)
168
+ ret = {k: v for k, v in zip(self._out_features, features)}
169
+ return ret
170
+
171
+ @property
172
+ def size_divisibility(self):
173
+ return self._size_divisibility
174
+
175
+
176
+ @BACKBONE_REGISTRY.register()
177
+ def build_timm_backbone(cfg, input_shape):
178
+ model = TIMM(
179
+ cfg.MODEL.TIMM.BASE_NAME,
180
+ cfg.MODEL.TIMM.OUT_LEVELS,
181
+ freeze_at=cfg.MODEL.TIMM.FREEZE_AT,
182
+ norm=cfg.MODEL.TIMM.NORM,
183
+ pretrained=cfg.MODEL.TIMM.PRETRAINED,
184
+ )
185
+ return model
186
+
187
+
188
+ @BACKBONE_REGISTRY.register()
189
+ def build_p67_timm_fpn_backbone(cfg, input_shape):
190
+ """
191
+ Build a timm backbone wrapped in an FPN with extra P6/P7 levels generated from P5.
+ """
192
+ bottom_up = build_timm_backbone(cfg, input_shape)
193
+ in_features = cfg.MODEL.FPN.IN_FEATURES
194
+ out_channels = cfg.MODEL.FPN.OUT_CHANNELS
195
+ backbone = FPN(
196
+ bottom_up=bottom_up,
197
+ in_features=in_features,
198
+ out_channels=out_channels,
199
+ norm=cfg.MODEL.FPN.NORM,
200
+ top_block=LastLevelP6P7_P5(out_channels, out_channels),
201
+ fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
202
+ )
203
+ return backbone
204
+
205
+ @BACKBONE_REGISTRY.register()
206
+ def build_p35_timm_fpn_backbone(cfg, input_shape):
207
+ """
208
+ Build a timm backbone wrapped in an FPN without extra top levels (no P6/P7).
+ """
209
+ bottom_up = build_timm_backbone(cfg, input_shape)
210
+
211
+ in_features = cfg.MODEL.FPN.IN_FEATURES
212
+ out_channels = cfg.MODEL.FPN.OUT_CHANNELS
213
+ backbone = FPN(
214
+ bottom_up=bottom_up,
215
+ in_features=in_features,
216
+ out_channels=out_channels,
217
+ norm=cfg.MODEL.FPN.NORM,
218
+ top_block=None,
219
+ fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
220
+ )
221
+ return backbone
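For context on the TIMM wrapper above, this is the timm features_only API it builds on (a minimal sketch assuming only the timm and torch packages; editorial, not part of the commit). The wrapper copies each selected stage's num_chs and reduction from feature_info into detectron2's _out_feature_channels and _out_feature_strides, and renames the outputs to 'layer{level}'.

import torch
import timm

base = timm.create_model('resnet50', features_only=True,
                         out_indices=(2, 3, 4), pretrained=False)
feats = base(torch.randn(1, 3, 256, 256))
print(base.feature_info.channels())     # [512, 1024, 2048]
print(base.feature_info.reduction())    # [8, 16, 32]
print([tuple(f.shape) for f in feats])  # [(1, 512, 32, 32), (1, 1024, 16, 16), (1, 2048, 8, 8)]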
proxydet/modeling/debug.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import os
7
+
8
+ COLORS = ((np.random.rand(1300, 3) * 0.4 + 0.6) * 255).astype(
9
+ np.uint8).reshape(1300, 1, 1, 3)
10
+
11
+ def _get_color_image(heatmap):
12
+ heatmap = heatmap.reshape(
13
+ heatmap.shape[0], heatmap.shape[1], heatmap.shape[2], 1)
14
+ if heatmap.shape[0] == 1:
15
+ color_map = (heatmap * np.ones((1, 1, 1, 3), np.uint8) * 255).max(
16
+ axis=0).astype(np.uint8) # H, W, 3
17
+ else:
18
+ color_map = (heatmap * COLORS[:heatmap.shape[0]]).max(axis=0).astype(np.uint8) # H, W, 3
19
+
20
+ return color_map
21
+
22
+ def _blend_image(image, color_map, a=0.7):
23
+ color_map = cv2.resize(color_map, (image.shape[1], image.shape[0]))
24
+ ret = np.clip(image * (1 - a) + color_map * a, 0, 255).astype(np.uint8)
25
+ return ret
26
+
27
+ def _blend_image_heatmaps(image, color_maps, a=0.7):
28
+ merges = np.zeros((image.shape[0], image.shape[1], 3), np.float32)
29
+ for color_map in color_maps:
30
+ color_map = cv2.resize(color_map, (image.shape[1], image.shape[0]))
31
+ merges = np.maximum(merges, color_map)
32
+ ret = np.clip(image * (1 - a) + merges * a, 0, 255).astype(np.uint8)
33
+ return ret
34
+
35
+ def _decompose_level(x, shapes_per_level, N):
36
+ '''
37
+ x: LNHiWi x C
38
+ '''
39
+ x = x.view(x.shape[0], -1)
40
+ ret = []
41
+ st = 0
42
+ for l in range(len(shapes_per_level)):
43
+ ret.append([])
44
+ h = shapes_per_level[l][0].int().item()
45
+ w = shapes_per_level[l][1].int().item()
46
+ for i in range(N):
47
+ ret[l].append(x[st + h * w * i:st + h * w * (i + 1)].view(
48
+ h, w, -1).permute(2, 0, 1))
49
+ st += h * w * N
50
+ return ret
51
+
52
+ def _imagelist_to_tensor(images):
53
+ images = [x for x in images]
54
+ image_sizes = [x.shape[-2:] for x in images]
55
+ h = max([size[0] for size in image_sizes])
56
+ w = max([size[1] for size in image_sizes])
57
+ S = 32
58
+ h, w = ((h - 1) // S + 1) * S, ((w - 1) // S + 1) * S
59
+ images = [F.pad(x, (0, w - x.shape[2], 0, h - x.shape[1], 0, 0)) \
60
+ for x in images]
61
+ images = torch.stack(images)
62
+ return images
63
+
64
+
65
+ def _ind2il(ind, shapes_per_level, N):
66
+ r = ind
67
+ l = 0
68
+ S = 0
69
+ while r - S >= N * shapes_per_level[l][0] * shapes_per_level[l][1]:
70
+ S += N * shapes_per_level[l][0] * shapes_per_level[l][1]
71
+ l += 1
72
+ i = (r - S) // (shapes_per_level[l][0] * shapes_per_level[l][1])
73
+ return i, l
74
+
75
+ def debug_train(
76
+ images, gt_instances, flattened_hms, reg_targets, labels, pos_inds,
77
+ shapes_per_level, locations, strides):
78
+ '''
79
+ images: N x 3 x H x W
80
+ flattened_hms: LNHiWi x C
81
+ shapes_per_level: L x 2 [(H_i, W_i)]
82
+ locations: LNHiWi x 2
83
+ '''
84
+ reg_inds = torch.nonzero(
85
+ reg_targets.max(dim=1)[0] > 0).squeeze(1)
86
+ N = len(images)
87
+ images = _imagelist_to_tensor(images)
88
+ repeated_locations = [torch.cat([loc] * N, dim=0) \
89
+ for loc in locations]
90
+ locations = torch.cat(repeated_locations, dim=0)
91
+ gt_hms = _decompose_level(flattened_hms, shapes_per_level, N)
92
+ masks = flattened_hms.new_zeros((flattened_hms.shape[0], 1))
93
+ masks[pos_inds] = 1
94
+ masks = _decompose_level(masks, shapes_per_level, N)
95
+ for i in range(len(images)):
96
+ image = images[i].detach().cpu().numpy().transpose(1, 2, 0)
97
+ color_maps = []
98
+ for l in range(len(gt_hms)):
99
+ color_map = _get_color_image(
100
+ gt_hms[l][i].detach().cpu().numpy())
101
+ color_maps.append(color_map)
102
+ cv2.imshow('gthm_{}'.format(l), color_map)
103
+ blend = _blend_image_heatmaps(image.copy(), color_maps)
104
+ if gt_instances is not None:
105
+ bboxes = gt_instances[i].gt_boxes.tensor
106
+ for j in range(len(bboxes)):
107
+ bbox = bboxes[j]
108
+ cv2.rectangle(
109
+ blend,
110
+ (int(bbox[0]), int(bbox[1])),
111
+ (int(bbox[2]), int(bbox[3])),
112
+ (0, 0, 255), 3, cv2.LINE_AA)
113
+
114
+ for j in range(len(pos_inds)):
115
+ image_id, l = _ind2il(pos_inds[j], shapes_per_level, N)
116
+ if image_id != i:
117
+ continue
118
+ loc = locations[pos_inds[j]]
119
+ cv2.drawMarker(
120
+ blend, (int(loc[0]), int(loc[1])), (0, 255, 255),
121
+ markerSize=(l + 1) * 16)
122
+
123
+ for j in range(len(reg_inds)):
124
+ image_id, l = _ind2il(reg_inds[j], shapes_per_level, N)
125
+ if image_id != i:
126
+ continue
127
+ ltrb = reg_targets[reg_inds[j]]
128
+ ltrb *= strides[l]
129
+ loc = locations[reg_inds[j]]
130
+ bbox = [(loc[0] - ltrb[0]), (loc[1] - ltrb[1]),
131
+ (loc[0] + ltrb[2]), (loc[1] + ltrb[3])]
132
+ cv2.rectangle(
133
+ blend,
134
+ (int(bbox[0]), int(bbox[1])),
135
+ (int(bbox[2]), int(bbox[3])),
136
+ (255, 0, 0), 1, cv2.LINE_AA)
137
+ cv2.circle(blend, (int(loc[0]), int(loc[1])), 2, (255, 0, 0), -1)
138
+
139
+ cv2.imshow('blend', blend)
140
+ cv2.waitKey()
141
+
142
+
143
+ def debug_test(
144
+ images, logits_pred, reg_pred, agn_hm_pred=[], preds=[],
145
+ vis_thresh=0.3, debug_show_name=False, mult_agn=False):
146
+ '''
147
+ images: N x 3 x H x W
148
+ class_target: LNHiWi x C
149
+ cat_agn_heatmap: LNHiWi
150
+ shapes_per_level: L x 2 [(H_i, W_i)]
151
+ '''
152
+ N = len(images)
153
+ for i in range(len(images)):
154
+ image = images[i].detach().cpu().numpy().transpose(1, 2, 0)
155
+ result = image.copy().astype(np.uint8)
156
+ pred_image = image.copy().astype(np.uint8)
157
+ color_maps = []
158
+ L = len(logits_pred)
159
+ for l in range(L):
160
+ if logits_pred[0] is not None:
161
+ stride = min(image.shape[0], image.shape[1]) / min(
162
+ logits_pred[l][i].shape[1], logits_pred[l][i].shape[2])
163
+ else:
164
+ stride = min(image.shape[0], image.shape[1]) / min(
165
+ agn_hm_pred[l][i].shape[1], agn_hm_pred[l][i].shape[2])
166
+ stride = stride if stride < 60 else 64 if stride < 100 else 128
167
+ if logits_pred[0] is not None:
168
+ if mult_agn:
169
+ logits_pred[l][i] = logits_pred[l][i] * agn_hm_pred[l][i]
170
+ color_map = _get_color_image(
171
+ logits_pred[l][i].detach().cpu().numpy())
172
+ color_maps.append(color_map)
173
+ cv2.imshow('predhm_{}'.format(l), color_map)
174
+
175
+ if debug_show_name:
176
+ from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES
177
+ cat2name = [x['name'] for x in LVIS_CATEGORIES]
178
+ for j in range(len(preds[i].scores) if preds is not None else 0):
179
+ if preds[i].scores[j] > vis_thresh:
180
+ bbox = preds[i].proposal_boxes[j] \
181
+ if preds[i].has('proposal_boxes') else \
182
+ preds[i].pred_boxes[j]
183
+ bbox = bbox.tensor[0].detach().cpu().numpy().astype(np.int32)
184
+ cat = int(preds[i].pred_classes[j]) \
185
+ if preds[i].has('pred_classes') else 0
186
+ cl = COLORS[cat, 0, 0]
187
+ cv2.rectangle(
188
+ pred_image, (int(bbox[0]), int(bbox[1])),
189
+ (int(bbox[2]), int(bbox[3])),
190
+ (int(cl[0]), int(cl[1]), int(cl[2])), 2, cv2.LINE_AA)
191
+ if debug_show_name:
192
+ txt = '{}{:.1f}'.format(
193
+ cat2name[cat] if cat > 0 else '',
194
+ preds[i].scores[j])
195
+ font = cv2.FONT_HERSHEY_SIMPLEX
196
+ cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
197
+ cv2.rectangle(
198
+ pred_image,
199
+ (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)),
200
+ (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)),
201
+ (int(cl[0]), int(cl[1]), int(cl[2])), -1)
202
+ cv2.putText(
203
+ pred_image, txt, (int(bbox[0]), int(bbox[1] - 2)),
204
+ font, 0.5, (0, 0, 0), thickness=1, lineType=cv2.LINE_AA)
205
+
206
+
207
+ if agn_hm_pred[l] is not None:
208
+ agn_hm_ = agn_hm_pred[l][i, 0, :, :, None].detach().cpu().numpy()
209
+ agn_hm_ = (agn_hm_ * np.array([255, 255, 255]).reshape(
210
+ 1, 1, 3)).astype(np.uint8)
211
+ cv2.imshow('agn_hm_{}'.format(l), agn_hm_)
212
+ blend = _blend_image_heatmaps(image.copy(), color_maps)
213
+ cv2.imshow('blend', blend)
214
+ cv2.imshow('preds', pred_image)
215
+ cv2.waitKey()
216
+
217
+ global cnt
218
+ cnt = 0
219
+
220
+ def debug_second_stage(images, instances, proposals=None, vis_thresh=0.3,
221
+ save_debug=False, debug_show_name=False, image_labels=[],
222
+ save_debug_path='output/save_debug/',
223
+ bgr=False):
224
+ images = _imagelist_to_tensor(images)
225
+ if 'COCO' in save_debug_path:
226
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
227
+ cat2name = [x['name'] for x in COCO_CATEGORIES]
228
+ else:
229
+ from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES
230
+ cat2name = ['({}){}'.format(x['frequency'], x['name']) \
231
+ for x in LVIS_CATEGORIES]
232
+ for i in range(len(images)):
233
+ image = images[i].detach().cpu().numpy().transpose(1, 2, 0).astype(np.uint8).copy()
234
+ if bgr:
235
+ image = image[:, :, ::-1].copy()
236
+ if instances[i].has('gt_boxes'):
237
+ bboxes = instances[i].gt_boxes.tensor.cpu().numpy()
238
+ scores = np.ones(bboxes.shape[0])
239
+ cats = instances[i].gt_classes.cpu().numpy()
240
+ else:
241
+ bboxes = instances[i].pred_boxes.tensor.cpu().numpy()
242
+ scores = instances[i].scores.cpu().numpy()
243
+ cats = instances[i].pred_classes.cpu().numpy()
244
+ for j in range(len(bboxes)):
245
+ if scores[j] > vis_thresh:
246
+ bbox = bboxes[j]
247
+ cl = COLORS[cats[j], 0, 0]
248
+ cl = (int(cl[0]), int(cl[1]), int(cl[2]))
249
+ cv2.rectangle(
250
+ image,
251
+ (int(bbox[0]), int(bbox[1])),
252
+ (int(bbox[2]), int(bbox[3])),
253
+ cl, 2, cv2.LINE_AA)
254
+ if debug_show_name:
255
+ cat = cats[j]
256
+ txt = '{}{:.1f}'.format(
257
+ cat2name[cat] if cat > 0 else '',
258
+ scores[j])
259
+ font = cv2.FONT_HERSHEY_SIMPLEX
260
+ cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
261
+ cv2.rectangle(
262
+ image,
263
+ (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)),
264
+ (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)),
265
+ (int(cl[0]), int(cl[1]), int(cl[2])), -1)
266
+ cv2.putText(
267
+ image, txt, (int(bbox[0]), int(bbox[1] - 2)),
268
+ font, 0.5, (0, 0, 0), thickness=1, lineType=cv2.LINE_AA)
269
+ if proposals is not None:
270
+ proposal_image = images[i].detach().cpu().numpy().transpose(1, 2, 0).astype(np.uint8).copy()
271
+ if bgr:
272
+ proposal_image = proposal_image.copy()
273
+ else:
274
+ proposal_image = proposal_image[:, :, ::-1].copy()
275
+ bboxes = proposals[i].proposal_boxes.tensor.cpu().numpy()
276
+ if proposals[i].has('scores'):
277
+ scores = proposals[i].scores.detach().cpu().numpy()
278
+ else:
279
+ scores = proposals[i].objectness_logits.detach().cpu().numpy()
280
+ # selected = -1
281
+ # if proposals[i].has('image_loss'):
282
+ # selected = proposals[i].image_loss.argmin()
283
+ if proposals[i].has('selected'):
284
+ selected = proposals[i].selected
285
+ else:
286
+ selected = [-1 for _ in range(len(bboxes))]
287
+ for j in range(len(bboxes)):
288
+ if scores[j] > vis_thresh or selected[j] >= 0:
289
+ bbox = bboxes[j]
290
+ cl = (209, 159, 83)
291
+ th = 2
292
+ if selected[j] >= 0:
293
+ cl = (0, 0, 0xa4)
294
+ th = 4
295
+ cv2.rectangle(
296
+ proposal_image,
297
+ (int(bbox[0]), int(bbox[1])),
298
+ (int(bbox[2]), int(bbox[3])),
299
+ cl, th, cv2.LINE_AA)
300
+ if selected[j] >= 0 and debug_show_name:
301
+ cat = selected[j].item()
302
+ txt = '{}'.format(cat2name[cat])
303
+ font = cv2.FONT_HERSHEY_SIMPLEX
304
+ cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
305
+ cv2.rectangle(
306
+ proposal_image,
307
+ (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)),
308
+ (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)),
309
+ (int(cl[0]), int(cl[1]), int(cl[2])), -1)
310
+ cv2.putText(
311
+ proposal_image, txt,
312
+ (int(bbox[0]), int(bbox[1] - 2)),
313
+ font, 0.5, (0, 0, 0), thickness=1,
314
+ lineType=cv2.LINE_AA)
315
+
316
+ if save_debug:
317
+ global cnt
318
+ cnt = (cnt + 1) % 5000
319
+ if not os.path.exists(save_debug_path):
320
+ os.mkdir(save_debug_path)
321
+ save_name = '{}/{:05d}.jpg'.format(save_debug_path, cnt)
322
+ if i < len(image_labels):
323
+ image_label = image_labels[i]
324
+ save_name = '{}/{:05d}'.format(save_debug_path, cnt)
325
+ for x in image_label:
326
+ class_name = cat2name[x]
327
+ save_name = save_name + '|{}'.format(class_name)
328
+ save_name = save_name + '.jpg'
329
+ cv2.imwrite(save_name, proposal_image)
330
+ else:
331
+ cv2.imshow('image', image)
332
+ if proposals is not None:
333
+ cv2.imshow('proposals', proposal_image)
334
+ cv2.waitKey()
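The helpers in this file index into tensors that are flattened level-by-level and, within each level, image-by-image. A small self-contained sketch (editorial, not part of the commit) of the decoding performed by _ind2il above:

shapes_per_level = [(4, 4), (2, 2)]     # (H_l, W_l) for two feature levels
N = 3                                   # images in the batch

def ind2il(ind):
    s, l = 0, 0
    while ind - s >= N * shapes_per_level[l][0] * shapes_per_level[l][1]:
        s += N * shapes_per_level[l][0] * shapes_per_level[l][1]
        l += 1
    i = (ind - s) // (shapes_per_level[l][0] * shapes_per_level[l][1])
    return i, l

print(ind2il(0))     # (0, 0): first cell of image 0 at level 0
print(ind2il(20))    # (1, 0): 16 cells per image at level 0, so index 20 falls in image 1
print(ind2il(50))    # (0, 1): past 3 * 16 = 48 cells, so level 1, image 0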
proxydet/modeling/meta_arch/custom_rcnn.py ADDED
@@ -0,0 +1,232 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import copy
3
+ import logging
4
+ import numpy as np
5
+ from typing import Dict, List, Optional, Tuple
6
+ import torch
7
+ from torch import nn
8
+ import json
9
+ from detectron2.utils.events import get_event_storage
10
+ from detectron2.config import configurable
11
+ from detectron2.structures import ImageList, Instances, Boxes
12
+ import detectron2.utils.comm as comm
13
+
14
+ from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
15
+ from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
16
+ from detectron2.modeling.postprocessing import detector_postprocess
17
+ from detectron2.utils.visualizer import Visualizer, _create_text_labels
18
+ from detectron2.data.detection_utils import convert_image_to_rgb
19
+
20
+ from torch.cuda.amp import autocast
21
+ from ..text.text_encoder import build_text_encoder
22
+ from ..utils import load_class_freq, get_fed_loss_inds
23
+
24
+ @META_ARCH_REGISTRY.register()
25
+ class CustomRCNN(GeneralizedRCNN):
26
+ '''
27
+ Add image labels
28
+ '''
29
+ @configurable
30
+ def __init__(
31
+ self,
32
+ with_image_labels = False,
33
+ dataset_loss_weight = [],
34
+ fp16 = False,
35
+ sync_caption_batch = False,
36
+ roi_head_name = '',
37
+ cap_batch_ratio = 4,
38
+ with_caption = False,
39
+ dynamic_classifier = False,
40
+ **kwargs):
41
+ """
42
+ """
43
+ self.with_image_labels = with_image_labels
44
+ self.dataset_loss_weight = dataset_loss_weight
45
+ self.fp16 = fp16
46
+ self.with_caption = with_caption
47
+ self.sync_caption_batch = sync_caption_batch
48
+ self.roi_head_name = roi_head_name
49
+ self.cap_batch_ratio = cap_batch_ratio
50
+ self.dynamic_classifier = dynamic_classifier
51
+ self.return_proposal = False
52
+ if self.dynamic_classifier:
53
+ self.freq_weight = kwargs.pop('freq_weight')
54
+ self.num_classes = kwargs.pop('num_classes')
55
+ self.num_sample_cats = kwargs.pop('num_sample_cats')
56
+ super().__init__(**kwargs)
57
+ assert self.proposal_generator is not None
58
+ if self.with_caption:
59
+ assert not self.dynamic_classifier
60
+ self.text_encoder = build_text_encoder(pretrain=True)
61
+ for v in self.text_encoder.parameters():
62
+ v.requires_grad = False
63
+
64
+
65
+ @classmethod
66
+ def from_config(cls, cfg):
67
+ ret = super().from_config(cfg)
68
+ ret.update({
69
+ 'with_image_labels': cfg.WITH_IMAGE_LABELS,
70
+ 'dataset_loss_weight': cfg.MODEL.DATASET_LOSS_WEIGHT,
71
+ 'fp16': cfg.FP16,
72
+ 'with_caption': cfg.MODEL.WITH_CAPTION,
73
+ 'sync_caption_batch': cfg.MODEL.SYNC_CAPTION_BATCH,
74
+ 'dynamic_classifier': cfg.MODEL.DYNAMIC_CLASSIFIER,
75
+ 'roi_head_name': cfg.MODEL.ROI_HEADS.NAME,
76
+ 'cap_batch_ratio': cfg.MODEL.CAP_BATCH_RATIO,
77
+ })
78
+ if ret['dynamic_classifier']:
79
+ ret['freq_weight'] = load_class_freq(
80
+ cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH,
81
+ cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT)
82
+ ret['num_classes'] = cfg.MODEL.ROI_HEADS.NUM_CLASSES
83
+ ret['num_sample_cats'] = cfg.MODEL.NUM_SAMPLE_CATS
84
+ return ret
85
+
86
+
87
+ def inference(
88
+ self,
89
+ batched_inputs: Tuple[Dict[str, torch.Tensor]],
90
+ detected_instances: Optional[List[Instances]] = None,
91
+ do_postprocess: bool = True,
92
+ ):
93
+ assert not self.training
94
+ assert detected_instances is None
95
+
96
+ images = self.preprocess_image(batched_inputs)
97
+ features = self.backbone(images.tensor)
98
+ proposals, _ = self.proposal_generator(images, features, None)
99
+ results, _ = self.roi_heads(images, features, proposals)
100
+ if do_postprocess:
101
+ assert not torch.jit.is_scripting(), \
102
+ "Scripting is not supported for postprocess."
103
+ return CustomRCNN._postprocess(
104
+ results, batched_inputs, images.image_sizes)
105
+ else:
106
+ return results
107
+
108
+
109
+ def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
110
+ """
111
+ Add ann_type
112
+ Ignore proposal loss when training with image labels
113
+ """
114
+ if not self.training:
115
+ return self.inference(batched_inputs)
116
+
117
+ images = self.preprocess_image(batched_inputs)
118
+
119
+ ann_type = 'box'
120
+ gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
121
+ if self.with_image_labels:
122
+ for inst, x in zip(gt_instances, batched_inputs):
123
+ inst._ann_type = x['ann_type']
124
+ inst._pos_category_ids = x['pos_category_ids']
125
+ ann_types = [x['ann_type'] for x in batched_inputs]
126
+ assert len(set(ann_types)) == 1
127
+ ann_type = ann_types[0]
128
+ if ann_type in ['prop', 'proptag']:
129
+ for t in gt_instances:
130
+ t.gt_classes *= 0
131
+
132
+ if self.fp16: # TODO (zhouxy): improve
133
+ with autocast():
134
+ features = self.backbone(images.tensor.half())
135
+ features = {k: v.float() for k, v in features.items()}
136
+ else:
137
+ features = self.backbone(images.tensor)
138
+
139
+ cls_features, cls_inds, caption_features = None, None, None
140
+
141
+ if self.with_caption and 'caption' in ann_type:
142
+ inds = [torch.randint(len(x['captions']), (1,))[0].item() \
143
+ for x in batched_inputs]
144
+ caps = [x['captions'][ind] for ind, x in zip(inds, batched_inputs)]
145
+ caption_features = self.text_encoder(caps).float()
146
+ if self.sync_caption_batch:
147
+ caption_features = self._sync_caption_features(
148
+ caption_features, ann_type, len(batched_inputs))
149
+
150
+ if self.dynamic_classifier and ann_type != 'caption':
151
+ cls_inds = self._sample_cls_inds(gt_instances, ann_type) # inds, inv_inds
152
+ ind_with_bg = cls_inds[0].tolist() + [-1]
153
+ cls_features = self.roi_heads.box_predictor[
154
+ 0].cls_score.zs_weight[:, ind_with_bg].permute(1, 0).contiguous()
155
+
156
+ classifier_info = cls_features, cls_inds, caption_features
157
+ proposals, proposal_losses = self.proposal_generator(
158
+ images, features, gt_instances)
159
+
160
+ if self.roi_head_name in ['StandardROIHeads', 'CascadeROIHeads']:
161
+ proposals, detector_losses = self.roi_heads(
162
+ images, features, proposals, gt_instances)
163
+ else:
164
+ proposals, detector_losses = self.roi_heads(
165
+ images, features, proposals, gt_instances,
166
+ ann_type=ann_type, classifier_info=classifier_info)
167
+
168
+ if self.vis_period > 0:
169
+ storage = get_event_storage()
170
+ if storage.iter % self.vis_period == 0:
171
+ self.visualize_training(batched_inputs, proposals)
172
+
173
+ losses = {}
174
+ losses.update(detector_losses)
175
+ if self.with_image_labels:
176
+ if ann_type in ['box', 'prop', 'proptag']:
177
+ losses.update(proposal_losses)
178
+ else: # ignore proposal loss for non-bbox data
179
+ losses.update({k: v * 0 for k, v in proposal_losses.items()})
180
+ else:
181
+ losses.update(proposal_losses)
182
+ if len(self.dataset_loss_weight) > 0:
183
+ dataset_sources = [x['dataset_source'] for x in batched_inputs]
184
+ assert len(set(dataset_sources)) == 1
185
+ dataset_source = dataset_sources[0]
186
+ for k in losses:
187
+ losses[k] *= self.dataset_loss_weight[dataset_source]
188
+
189
+ if self.return_proposal:
190
+ return proposals, losses
191
+ else:
192
+ return losses
193
+
194
+
195
+ def _sync_caption_features(self, caption_features, ann_type, BS):
196
+ has_caption_feature = (caption_features is not None)
197
+ BS = (BS * self.cap_batch_ratio) if (ann_type == 'box') else BS
198
+ rank = torch.full(
199
+ (BS, 1), comm.get_rank(), dtype=torch.float32,
200
+ device=self.device)
201
+ if not has_caption_feature:
202
+ caption_features = rank.new_zeros((BS, 512))
203
+ caption_features = torch.cat([caption_features, rank], dim=1)
204
+ global_caption_features = comm.all_gather(caption_features)
205
+ caption_features = torch.cat(
206
+ [x.to(self.device) for x in global_caption_features], dim=0) \
207
+ if has_caption_feature else None # (NB) x (D + 1)
208
+ return caption_features
209
+
210
+
211
+ def _sample_cls_inds(self, gt_instances, ann_type='box'):
212
+ if ann_type == 'box':
213
+ gt_classes = torch.cat(
214
+ [x.gt_classes for x in gt_instances])
215
+ C = len(self.freq_weight)
216
+ freq_weight = self.freq_weight
217
+ else:
218
+ gt_classes = torch.cat(
219
+ [torch.tensor(
220
+ x._pos_category_ids,
221
+ dtype=torch.long, device=x.gt_classes.device) \
222
+ for x in gt_instances])
223
+ C = self.num_classes
224
+ freq_weight = None
225
+ assert gt_classes.max() < C, '{} {}'.format(gt_classes.max(), C)
226
+ inds = get_fed_loss_inds(
227
+ gt_classes, self.num_sample_cats, C,
228
+ weight=freq_weight)
229
+ cls_id_map = gt_classes.new_full(
230
+ (self.num_classes + 1,), len(inds))
231
+ cls_id_map[inds] = torch.arange(len(inds), device=cls_id_map.device)
232
+ return inds, cls_id_map
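A hedged sketch (editorial, not part of the commit) of how a CustomRCNN instance is driven at inference time. Building the cfg and loading a checkpoint are handled elsewhere in this commit's demo code, so the model itself is taken as a parameter here; the batched_inputs format follows detectron2's GeneralizedRCNN convention, which is what inference() and _postprocess above expect.

import cv2
import torch

def run_inference(model, image_path):
    # model: a detectron2 GeneralizedRCNN-style model such as CustomRCNN, already loaded
    img = cv2.imread(image_path)                                   # BGR, HWC, uint8
    inputs = [{
        "image": torch.as_tensor(img.transpose(2, 0, 1).copy()),   # CHW tensor
        "height": img.shape[0],                                    # size used to rescale boxes
        "width": img.shape[1],
    }]
    model.eval()
    with torch.no_grad():
        return model(inputs)[0]["instances"]                       # Boxes, scores, pred_classes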
proxydet/modeling/meta_arch/d2_deformable_detr.py ADDED
@@ -0,0 +1,308 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ import math
6
+
7
+ from detectron2.modeling import META_ARCH_REGISTRY, build_backbone
8
+ from detectron2.structures import Boxes, Instances
9
+ from ..utils import load_class_freq, get_fed_loss_inds
10
+
11
+ from models.backbone import Joiner
12
+ from models.deformable_detr import DeformableDETR, SetCriterion, MLP
13
+ from models.deformable_detr import _get_clones
14
+ from models.matcher import HungarianMatcher
15
+ from models.position_encoding import PositionEmbeddingSine
16
+ from models.deformable_transformer import DeformableTransformer
17
+ from models.segmentation import sigmoid_focal_loss
18
+ from util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
19
+ from util.misc import NestedTensor, accuracy
20
+
21
+
22
+ __all__ = ["DeformableDetr"]
23
+
24
+ class CustomSetCriterion(SetCriterion):
25
+ def __init__(self, num_classes, matcher, weight_dict, losses, \
26
+ focal_alpha=0.25, use_fed_loss=False):
27
+ super().__init__(num_classes, matcher, weight_dict, losses, focal_alpha)
28
+ self.use_fed_loss = use_fed_loss
29
+ if self.use_fed_loss:
30
+ self.register_buffer(
31
+ 'fed_loss_weight', load_class_freq(freq_weight=0.5))
32
+
33
+ def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
34
+ """Classification loss (NLL)
35
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
36
+ """
37
+ assert 'pred_logits' in outputs
38
+ src_logits = outputs['pred_logits']
39
+
40
+ idx = self._get_src_permutation_idx(indices)
41
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
42
+ target_classes = torch.full(src_logits.shape[:2], self.num_classes,
43
+ dtype=torch.int64, device=src_logits.device)
44
+ target_classes[idx] = target_classes_o
45
+
46
+ target_classes_onehot = torch.zeros(
47
+ [src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
48
+ dtype=src_logits.dtype, layout=src_logits.layout,
49
+ device=src_logits.device)
50
+ target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
51
+
52
+ target_classes_onehot = target_classes_onehot[:,:,:-1] # B x N x C
53
+ if self.use_fed_loss:
54
+ inds = get_fed_loss_inds(
55
+ gt_classes=target_classes_o,
56
+ num_sample_cats=50,
57
+ weight=self.fed_loss_weight,
58
+ C=target_classes_onehot.shape[2])
59
+ loss_ce = sigmoid_focal_loss(
60
+ src_logits[:, :, inds],
61
+ target_classes_onehot[:, :, inds],
62
+ num_boxes,
63
+ alpha=self.focal_alpha,
64
+ gamma=2) * src_logits.shape[1]
65
+ else:
66
+ loss_ce = sigmoid_focal_loss(
67
+ src_logits, target_classes_onehot, num_boxes,
68
+ alpha=self.focal_alpha,
69
+ gamma=2) * src_logits.shape[1]
70
+ losses = {'loss_ce': loss_ce}
71
+
72
+ if log:
73
+ # TODO this should probably be a separate loss, not hacked in this one here
74
+ losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
75
+ return losses
76
+
77
+
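The use_fed_loss branch above restricts the sigmoid focal loss to a sampled subset of categories. The actual sampler, get_fed_loss_inds, lives in the utils module imported at the top of this file and is not shown in this diff; the following is only a conceptual, self-contained sketch of the idea (keep every ground-truth class, pad with frequency-weighted random negatives), not the repo's implementation.

import torch

def sample_fed_loss_classes(gt_classes, num_sample_cats, num_classes, weight=None):
    # keep all positive classes, then sample extra negative classes by frequency weight
    unique_gt = torch.unique(gt_classes)
    if len(unique_gt) >= num_sample_cats:
        return unique_gt
    prob = torch.ones(num_classes) if weight is None else weight.clone().float()
    prob[unique_gt] = 0                               # never resample a positive class
    extra = torch.multinomial(prob, num_sample_cats - len(unique_gt), replacement=False)
    return torch.cat([unique_gt, extra])

inds = sample_fed_loss_classes(torch.tensor([3, 3, 17]), num_sample_cats=8, num_classes=100)
print(sorted(inds.tolist()))    # always contains 3 and 17, plus six sampled negatives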
78
+ class MaskedBackbone(nn.Module):
79
+ """ This is a thin wrapper around D2's backbone to provide padding masking"""
80
+
81
+ def __init__(self, cfg):
82
+ super().__init__()
83
+ self.backbone = build_backbone(cfg)
84
+ backbone_shape = self.backbone.output_shape()
85
+ self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()]
86
+ self.strides = [backbone_shape[f].stride for f in backbone_shape.keys()]
87
+ self.num_channels = [backbone_shape[x].channels for x in backbone_shape.keys()]
88
+
89
+ def forward(self, tensor_list: NestedTensor):
90
+ xs = self.backbone(tensor_list.tensors)
91
+ out = {}
92
+ for name, x in xs.items():
93
+ m = tensor_list.mask
94
+ assert m is not None
95
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
96
+ out[name] = NestedTensor(x, mask)
97
+ return out
98
+
99
+ @META_ARCH_REGISTRY.register()
100
+ class DeformableDetr(nn.Module):
101
+ """
102
+ Implement Deformable Detr
103
+ """
104
+
105
+ def __init__(self, cfg):
106
+ super().__init__()
107
+ self.with_image_labels = cfg.WITH_IMAGE_LABELS
108
+ self.weak_weight = cfg.MODEL.DETR.WEAK_WEIGHT
109
+
110
+ self.device = torch.device(cfg.MODEL.DEVICE)
111
+ self.test_topk = cfg.TEST.DETECTIONS_PER_IMAGE
112
+ self.num_classes = cfg.MODEL.DETR.NUM_CLASSES
113
+ self.mask_on = cfg.MODEL.MASK_ON
114
+ hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM
115
+ num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES
116
+
117
+ # Transformer parameters:
118
+ nheads = cfg.MODEL.DETR.NHEADS
119
+ dropout = cfg.MODEL.DETR.DROPOUT
120
+ dim_feedforward = cfg.MODEL.DETR.DIM_FEEDFORWARD
121
+ enc_layers = cfg.MODEL.DETR.ENC_LAYERS
122
+ dec_layers = cfg.MODEL.DETR.DEC_LAYERS
123
+ num_feature_levels = cfg.MODEL.DETR.NUM_FEATURE_LEVELS
124
+ two_stage = cfg.MODEL.DETR.TWO_STAGE
125
+ with_box_refine = cfg.MODEL.DETR.WITH_BOX_REFINE
126
+
127
+ # Loss parameters:
128
+ giou_weight = cfg.MODEL.DETR.GIOU_WEIGHT
129
+ l1_weight = cfg.MODEL.DETR.L1_WEIGHT
130
+ deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION
131
+ cls_weight = cfg.MODEL.DETR.CLS_WEIGHT
132
+ focal_alpha = cfg.MODEL.DETR.FOCAL_ALPHA
133
+
134
+ N_steps = hidden_dim // 2
135
+ d2_backbone = MaskedBackbone(cfg)
136
+ backbone = Joiner(d2_backbone, PositionEmbeddingSine(N_steps, normalize=True))
137
+
138
+ transformer = DeformableTransformer(
139
+ d_model=hidden_dim,
140
+ nhead=nheads,
141
+ num_encoder_layers=enc_layers,
142
+ num_decoder_layers=dec_layers,
143
+ dim_feedforward=dim_feedforward,
144
+ dropout=dropout,
145
+ activation="relu",
146
+ return_intermediate_dec=True,
147
+ num_feature_levels=num_feature_levels,
148
+ dec_n_points=4,
149
+ enc_n_points=4,
150
+ two_stage=two_stage,
151
+ two_stage_num_proposals=num_queries)
152
+
153
+ self.detr = DeformableDETR(
154
+ backbone, transformer, num_classes=self.num_classes,
155
+ num_queries=num_queries,
156
+ num_feature_levels=num_feature_levels,
157
+ aux_loss=deep_supervision,
158
+ with_box_refine=with_box_refine,
159
+ two_stage=two_stage,
160
+ )
161
+
162
+ if self.mask_on:
163
+ assert 0, 'Mask is not supported yet :('
164
+
165
+ matcher = HungarianMatcher(
166
+ cost_class=cls_weight, cost_bbox=l1_weight, cost_giou=giou_weight)
167
+ weight_dict = {"loss_ce": cls_weight, "loss_bbox": l1_weight}
168
+ weight_dict["loss_giou"] = giou_weight
169
+ if deep_supervision:
170
+ aux_weight_dict = {}
171
+ for i in range(dec_layers - 1):
172
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
173
+ weight_dict.update(aux_weight_dict)
174
+ print('weight_dict', weight_dict)
175
+ losses = ["labels", "boxes", "cardinality"]
176
+ if self.mask_on:
177
+ losses += ["masks"]
178
+ self.criterion = CustomSetCriterion(
179
+ self.num_classes, matcher=matcher, weight_dict=weight_dict,
180
+ focal_alpha=focal_alpha,
181
+ losses=losses,
182
+ use_fed_loss=cfg.MODEL.DETR.USE_FED_LOSS
183
+ )
184
+ pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
185
+ pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
186
+ self.normalizer = lambda x: (x - pixel_mean) / pixel_std
187
+
188
+
189
+ def forward(self, batched_inputs):
190
+ """
191
+ Args:
+ batched_inputs: a list of dicts in detectron2's standard format
+ (keys "image", "height", "width", plus "instances" during training).
191
+ Returns:
193
+ dict[str: Tensor]:
194
+ mapping from a named loss to a tensor storing the loss. Used during training only.
195
+ """
196
+ images = self.preprocess_image(batched_inputs)
197
+ output = self.detr(images)
198
+ if self.training:
199
+ gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
200
+ targets = self.prepare_targets(gt_instances)
201
+ loss_dict = self.criterion(output, targets)
202
+ weight_dict = self.criterion.weight_dict
203
+ for k in loss_dict.keys():
204
+ if k in weight_dict:
205
+ loss_dict[k] *= weight_dict[k]
206
+ if self.with_image_labels:
207
+ if batched_inputs[0]['ann_type'] in ['image', 'captiontag']:
208
+ loss_dict['loss_image'] = self.weak_weight * self._weak_loss(
209
+ output, batched_inputs)
210
+ else:
211
+ loss_dict['loss_image'] = images[0].new_zeros(
212
+ [1], dtype=torch.float32)[0]
213
+ # import pdb; pdb.set_trace()
214
+ return loss_dict
215
+ else:
216
+ image_sizes = output["pred_boxes"].new_tensor(
217
+ [(t["height"], t["width"]) for t in batched_inputs])
218
+ results = self.post_process(output, image_sizes)
219
+ return results
220
+
221
+
222
+ def prepare_targets(self, targets):
223
+ new_targets = []
224
+ for targets_per_image in targets:
225
+ h, w = targets_per_image.image_size
226
+ image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
227
+ gt_classes = targets_per_image.gt_classes
228
+ gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
229
+ gt_boxes = box_xyxy_to_cxcywh(gt_boxes)
230
+ new_targets.append({"labels": gt_classes, "boxes": gt_boxes})
231
+ if self.mask_on and hasattr(targets_per_image, 'gt_masks'):
232
+ assert 0, 'Mask is not supported yet :('
233
+ gt_masks = targets_per_image.gt_masks
234
+ gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
235
+ new_targets[-1].update({'masks': gt_masks})
236
+ return new_targets
237
+
238
+
239
+ def post_process(self, outputs, target_sizes):
240
+ """
241
+ Convert raw DETR outputs into detectron2 Instances: keep the top-k (class, box) pairs
+ and rescale the normalized cxcywh boxes to absolute xyxy pixel coordinates.
+ """
242
+ out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
243
+ assert len(out_logits) == len(target_sizes)
244
+ assert target_sizes.shape[1] == 2
245
+
246
+ prob = out_logits.sigmoid()
247
+ topk_values, topk_indexes = torch.topk(
248
+ prob.view(out_logits.shape[0], -1), self.test_topk, dim=1)
249
+ scores = topk_values
250
+ topk_boxes = topk_indexes // out_logits.shape[2]
251
+ labels = topk_indexes % out_logits.shape[2]
252
+ boxes = box_cxcywh_to_xyxy(out_bbox)
253
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
254
+
255
+ # and from relative [0, 1] to absolute [0, height] coordinates
256
+ img_h, img_w = target_sizes.unbind(1)
257
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
258
+ boxes = boxes * scale_fct[:, None, :]
259
+
260
+ results = []
261
+ for s, l, b, size in zip(scores, labels, boxes, target_sizes):
262
+ r = Instances((size[0], size[1]))
263
+ r.pred_boxes = Boxes(b)
264
+ r.scores = s
265
+ r.pred_classes = l
266
+ results.append({'instances': r})
267
+ return results
268
+
269
+
270
+ def preprocess_image(self, batched_inputs):
271
+ """
272
+ Normalize, pad and batch the input images.
273
+ """
274
+ images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs]
275
+ return images
276
+
277
+
278
+ def _weak_loss(self, outputs, batched_inputs):
279
+ loss = 0
280
+ for b, x in enumerate(batched_inputs):
281
+ labels = x['pos_category_ids']
282
+ pred_logits = [outputs['pred_logits'][b]]
283
+ pred_boxes = [outputs['pred_boxes'][b]]
284
+ for xx in outputs['aux_outputs']:
285
+ pred_logits.append(xx['pred_logits'][b])
286
+ pred_boxes.append(xx['pred_boxes'][b])
287
+ pred_logits = torch.stack(pred_logits, dim=0) # L x N x C
288
+ pred_boxes = torch.stack(pred_boxes, dim=0) # L x N x 4
289
+ for label in labels:
290
+ loss += self._max_size_loss(
291
+ pred_logits, pred_boxes, label) / len(labels)
292
+ loss = loss / len(batched_inputs)
293
+ return loss
294
+
295
+
296
+ def _max_size_loss(self, logits, boxes, label):
297
+ '''
298
+ Inputs:
299
+ logits: L x N x C
300
+ boxes: L x N x 4
301
+ '''
302
+ target = logits.new_zeros((logits.shape[0], logits.shape[2]))
303
+ target[:, label] = 1.
304
+ sizes = boxes[..., 2] * boxes[..., 3] # L x N
305
+ ind = sizes.argmax(dim=1) # L
306
+ loss = F.binary_cross_entropy_with_logits(
307
+ logits[range(len(ind)), ind], target, reduction='sum')
308
+ return loss
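A self-contained sketch (editorial, not part of the commit) of the coordinate handling in post_process above: predicted boxes arrive as normalized (cx, cy, w, h) and are turned into absolute xyxy pixels by the cxcywh-to-xyxy conversion (mirroring the util.box_ops helper imported at the top of this file) followed by a per-image (w, h, w, h) scale.

import torch

def box_cxcywh_to_xyxy(b):
    cx, cy, w, h = b.unbind(-1)
    return torch.stack([cx - 0.5 * w, cy - 0.5 * h,
                        cx + 0.5 * w, cy + 0.5 * h], dim=-1)

pred = torch.tensor([[0.5, 0.5, 0.2, 0.4]])   # one box, normalized to [0, 1]
img_h, img_w = 480, 640
scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
print(box_cxcywh_to_xyxy(pred) * scale)       # tensor([[256., 144., 384., 336.]])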
proxydet/modeling/roi_heads/proxydet_fast_rcnn.py ADDED
@@ -0,0 +1,618 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ '''
3
+ Modifications Copyright (c) 2024-present NAVER Corp, Apache License v2.0
4
+ original source: https://github.com/facebookresearch/Detic/blob/main/detic/modeling/roi_heads/detic_fast_rcnn.py
5
+ '''
6
+ import logging
7
+ import math
8
+ import json
9
+ import numpy as np
10
+ from typing import Dict, Union
11
+ import torch
12
+ from fvcore.nn import giou_loss, smooth_l1_loss
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+ import fvcore.nn.weight_init as weight_init
16
+ import detectron2.utils.comm as comm
17
+ from detectron2.config import configurable
18
+ from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple
19
+ from detectron2.structures import Boxes, Instances
20
+ from detectron2.utils.events import get_event_storage
21
+ from detectron2.modeling.box_regression import Box2BoxTransform
22
+ from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
23
+ from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
24
+ from detectron2.modeling.roi_heads.fast_rcnn import _log_classification_stats
25
+
26
+ from torch.cuda.amp import autocast
27
+ from ..utils import load_class_freq, get_fed_loss_inds
28
+ from .zero_shot_classifier import ZeroShotClassifier
29
+
30
+ __all__ = ["ProxydetFastRCNNOutputLayers"]
31
+
32
+
33
+ class ProxydetFastRCNNOutputLayers(FastRCNNOutputLayers):
34
+ @configurable
35
+ def __init__(
36
+ self,
37
+ input_shape: ShapeSpec,
38
+ *,
39
+ mult_proposal_score=False,
40
+ cls_score=None,
41
+ sync_caption_batch = False,
42
+ use_sigmoid_ce = False,
43
+ use_fed_loss = False,
44
+ ignore_zero_cats = False,
45
+ fed_loss_num_cat = 50,
46
+ dynamic_classifier = False,
47
+ image_label_loss = '',
48
+ use_zeroshot_cls = False,
49
+ image_loss_weight = 0.1,
50
+ with_softmax_prop = False,
51
+ caption_weight = 1.0,
52
+ neg_cap_weight = 1.0,
53
+ add_image_box = False,
54
+ debug = False,
55
+ prior_prob = 0.01,
56
+ cat_freq_path = '',
57
+ fed_loss_freq_weight = 0.5,
58
+ softmax_weak_loss = False,
59
+ use_regional_embedding=False,
60
+ base_cat_mask: str = None,
61
+ **kwargs,
62
+ ):
63
+ super().__init__(
64
+ input_shape=input_shape,
65
+ **kwargs,
66
+ )
67
+ self.mult_proposal_score = mult_proposal_score
68
+ self.sync_caption_batch = sync_caption_batch
69
+ self.use_sigmoid_ce = use_sigmoid_ce
70
+ self.use_fed_loss = use_fed_loss
71
+ self.ignore_zero_cats = ignore_zero_cats
72
+ self.fed_loss_num_cat = fed_loss_num_cat
73
+ self.dynamic_classifier = dynamic_classifier
74
+ self.image_label_loss = image_label_loss
75
+ self.use_zeroshot_cls = use_zeroshot_cls
76
+ self.image_loss_weight = image_loss_weight
77
+ self.with_softmax_prop = with_softmax_prop
78
+ self.caption_weight = caption_weight
79
+ self.neg_cap_weight = neg_cap_weight
80
+ self.add_image_box = add_image_box
81
+ self.softmax_weak_loss = softmax_weak_loss
82
+ self.debug = debug
83
+ self.use_regional_embedding = use_regional_embedding
84
+ self.base_cat_mask = torch.tensor(np.load(base_cat_mask).nonzero()[0])
85
+
86
+ if softmax_weak_loss:
87
+ assert image_label_loss in ['max_size']
88
+
89
+ if self.use_sigmoid_ce:
90
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
91
+ nn.init.constant_(self.cls_score.bias, bias_value)
92
+
93
+ if self.use_fed_loss or self.ignore_zero_cats:
94
+ freq_weight = load_class_freq(cat_freq_path, fed_loss_freq_weight)
95
+ self.register_buffer('freq_weight', freq_weight)
96
+ else:
97
+ self.freq_weight = None
98
+
99
+ if self.use_fed_loss and len(self.freq_weight) < self.num_classes:
100
+ # assert self.num_classes == 11493
101
+ print('Extending federated loss weight')
102
+ self.freq_weight = torch.cat(
103
+ [self.freq_weight,
104
+ self.freq_weight.new_zeros(
105
+ self.num_classes - len(self.freq_weight))]
106
+ )
107
+
108
+ assert (not self.dynamic_classifier) or (not self.use_fed_loss)
109
+ input_size = input_shape.channels * \
110
+ (input_shape.width or 1) * (input_shape.height or 1)
111
+
112
+ if self.use_zeroshot_cls:
113
+ del self.cls_score
114
+ del self.bbox_pred
115
+ assert cls_score is not None
116
+ self.cls_score = cls_score
117
+ self.bbox_pred = nn.Sequential(
118
+ nn.Linear(input_size, input_size),
119
+ nn.ReLU(inplace=True),
120
+ nn.Linear(input_size, 4)
121
+ )
122
+ weight_init.c2_xavier_fill(self.bbox_pred[0])
123
+ nn.init.normal_(self.bbox_pred[-1].weight, std=0.001)
124
+ nn.init.constant_(self.bbox_pred[-1].bias, 0)
125
+
126
+ if self.with_softmax_prop:
127
+ self.prop_score = nn.Sequential(
128
+ nn.Linear(input_size, input_size),
129
+ nn.ReLU(inplace=True),
130
+ nn.Linear(input_size, self.num_classes + 1),
131
+ )
132
+ weight_init.c2_xavier_fill(self.prop_score[0])
133
+ nn.init.normal_(self.prop_score[-1].weight, mean=0, std=0.001)
134
+ nn.init.constant_(self.prop_score[-1].bias, 0)
135
+
136
+
137
+ @classmethod
138
+ def from_config(cls, cfg, input_shape):
139
+ ret = super().from_config(cfg, input_shape)
140
+ ret.update({
141
+ 'mult_proposal_score': cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE,
142
+ 'sync_caption_batch': cfg.MODEL.SYNC_CAPTION_BATCH,
143
+ 'use_sigmoid_ce': cfg.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE,
144
+ 'use_fed_loss': cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS,
145
+ 'ignore_zero_cats': cfg.MODEL.ROI_BOX_HEAD.IGNORE_ZERO_CATS,
146
+ 'fed_loss_num_cat': cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT,
147
+ 'dynamic_classifier': cfg.MODEL.DYNAMIC_CLASSIFIER,
148
+ 'image_label_loss': cfg.MODEL.ROI_BOX_HEAD.IMAGE_LABEL_LOSS,
149
+ 'use_zeroshot_cls': cfg.MODEL.ROI_BOX_HEAD.USE_ZEROSHOT_CLS,
150
+ 'image_loss_weight': cfg.MODEL.ROI_BOX_HEAD.IMAGE_LOSS_WEIGHT,
151
+ 'with_softmax_prop': cfg.MODEL.ROI_BOX_HEAD.WITH_SOFTMAX_PROP,
152
+ 'caption_weight': cfg.MODEL.ROI_BOX_HEAD.CAPTION_WEIGHT,
153
+ 'neg_cap_weight': cfg.MODEL.ROI_BOX_HEAD.NEG_CAP_WEIGHT,
154
+ 'add_image_box': cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX,
155
+ 'debug': cfg.DEBUG or cfg.SAVE_DEBUG or cfg.IS_DEBUG,
156
+ 'prior_prob': cfg.MODEL.ROI_BOX_HEAD.PRIOR_PROB,
157
+ 'cat_freq_path': cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH,
158
+ 'fed_loss_freq_weight': cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT,
159
+ 'softmax_weak_loss': cfg.MODEL.ROI_BOX_HEAD.SOFTMAX_WEAK_LOSS,
160
+ "use_regional_embedding": cfg.MODEL.ROI_BOX_HEAD.USE_REGIONAL_EMBEDDING,
161
+ "base_cat_mask": cfg.MODEL.ROI_HEADS.BASE_CAT_MASK,
162
+ })
163
+ if ret['use_zeroshot_cls']:
164
+ ret['cls_score'] = ZeroShotClassifier(cfg, input_shape)
165
+ return ret
166
+
167
+ def losses(self, predictions, proposals, \
168
+ use_advanced_loss=True,
169
+ classifier_info=(None,None,None)):
170
+ """
171
+ enable advanced loss
172
+ """
173
+ scores, proposal_deltas = predictions
174
+ gt_classes = (
175
+ cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0)
176
+ )
177
+ num_classes = self.num_classes
178
+ if self.dynamic_classifier:
179
+ _, cls_id_map = classifier_info[1]
180
+ gt_classes = cls_id_map[gt_classes]
181
+ num_classes = scores.shape[1] - 1
182
+ assert cls_id_map[self.num_classes] == num_classes
183
+ _log_classification_stats(scores, gt_classes)
184
+
185
+ if len(proposals):
186
+ proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0) # Nx4
187
+ assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"
188
+ gt_boxes = cat(
189
+ [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals],
190
+ dim=0,
191
+ )
192
+ else:
193
+ proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)
194
+
195
+ if self.use_sigmoid_ce:
196
+ loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes)
197
+ else:
198
+ loss_cls = self.softmax_cross_entropy_loss(scores, gt_classes)
199
+ return {
200
+ "loss_cls": loss_cls,
201
+ "loss_box_reg": self.box_reg_loss(
202
+ proposal_boxes, gt_boxes, proposal_deltas, gt_classes,
203
+ num_classes=num_classes)
204
+ }
205
+
206
+
207
+ def sigmoid_cross_entropy_loss(self, pred_class_logits, gt_classes):
208
+ if pred_class_logits.numel() == 0:
209
+ return pred_class_logits.new_zeros([1])[0] # This is more robust than .sum() * 0.
210
+
211
+ B = pred_class_logits.shape[0]
212
+ C = pred_class_logits.shape[1] - 1
213
+
214
+ target = pred_class_logits.new_zeros(B, C + 1)
215
+ target[range(len(gt_classes)), gt_classes] = 1 # B x (C + 1)
216
+ target = target[:, :C] # B x C
217
+
218
+ weight = 1
219
+
220
+ if self.use_fed_loss and (self.freq_weight is not None): # fedloss
221
+ appeared = get_fed_loss_inds(
222
+ gt_classes,
223
+ num_sample_cats=self.fed_loss_num_cat,
224
+ C=C,
225
+ weight=self.freq_weight)
226
+ appeared_mask = appeared.new_zeros(C + 1)
227
+ appeared_mask[appeared] = 1 # C + 1
228
+ appeared_mask = appeared_mask[:C]
229
+ fed_w = appeared_mask.view(1, C).expand(B, C)
230
+ weight = weight * fed_w.float()
231
+ if self.ignore_zero_cats and (self.freq_weight is not None):
232
+ w = (self.freq_weight.view(-1) > 1e-4).float()
233
+ weight = weight * w.view(1, C).expand(B, C)
234
+ # import pdb; pdb.set_trace()
235
+
236
+ cls_loss = F.binary_cross_entropy_with_logits(
237
+ pred_class_logits[:, :-1], target, reduction="none"
238
+ ) # B x C
239
+ loss = torch.sum(cls_loss * weight) / B
240
+ return loss
241
+
242
+
243
+ def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes):
244
+ """
245
+ change _no_instance handling
246
+ """
247
+ if pred_class_logits.numel() == 0:
248
+ return pred_class_logits.new_zeros([1])[0]
249
+
250
+ if self.ignore_zero_cats and (self.freq_weight is not None):
251
+ zero_weight = torch.cat([
252
+ (self.freq_weight.view(-1) > 1e-4).float(),
253
+ self.freq_weight.new_ones(1)]) # C + 1
254
+ loss = F.cross_entropy(
255
+ pred_class_logits, gt_classes,
256
+ weight=zero_weight, reduction="mean")
257
+ elif self.use_fed_loss and (self.freq_weight is not None): # fedloss
258
+ C = pred_class_logits.shape[1] - 1
259
+ appeared = get_fed_loss_inds(
260
+ gt_classes,
261
+ num_sample_cats=self.fed_loss_num_cat,
262
+ C=C,
263
+ weight=self.freq_weight)
264
+ appeared_mask = appeared.new_zeros(C + 1).float()
265
+ appeared_mask[appeared] = 1. # C + 1
266
+ appeared_mask[C] = 1.
267
+ loss = F.cross_entropy(
268
+ pred_class_logits, gt_classes,
269
+ weight=appeared_mask, reduction="mean")
270
+ else:
271
+ loss = F.cross_entropy(
272
+ pred_class_logits, gt_classes, reduction="mean")
273
+ return loss
274
+
275
+
276
+ def box_reg_loss(
277
+ self, proposal_boxes, gt_boxes, pred_deltas, gt_classes,
278
+ num_classes=-1):
279
+ """
280
+ Allow custom background index
281
+ """
282
+ num_classes = num_classes if num_classes > 0 else self.num_classes
283
+ box_dim = proposal_boxes.shape[1] # 4 or 5
284
+ fg_inds = nonzero_tuple((gt_classes >= 0) & (gt_classes < num_classes))[0]
285
+ if pred_deltas.shape[1] == box_dim: # cls-agnostic regression
286
+ fg_pred_deltas = pred_deltas[fg_inds]
287
+ else:
288
+ fg_pred_deltas = pred_deltas.view(-1, self.num_classes, box_dim)[
289
+ fg_inds, gt_classes[fg_inds]
290
+ ]
291
+
292
+ if self.box_reg_loss_type == "smooth_l1":
293
+ gt_pred_deltas = self.box2box_transform.get_deltas(
294
+ proposal_boxes[fg_inds],
295
+ gt_boxes[fg_inds],
296
+ )
297
+ loss_box_reg = smooth_l1_loss(
298
+ fg_pred_deltas, gt_pred_deltas, self.smooth_l1_beta, reduction="sum"
299
+ )
300
+ elif self.box_reg_loss_type == "giou":
301
+ fg_pred_boxes = self.box2box_transform.apply_deltas(
302
+ fg_pred_deltas, proposal_boxes[fg_inds]
303
+ )
304
+ loss_box_reg = giou_loss(fg_pred_boxes, gt_boxes[fg_inds], reduction="sum")
305
+ else:
306
+ raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")
307
+ return loss_box_reg / max(gt_classes.numel(), 1.0)
308
+
309
+ def inference(self, predictions, proposals):
310
+ """
311
+ enable use proposal boxes
312
+ """
313
+ predictions = (predictions[0], predictions[1])
314
+ boxes = self.predict_boxes(predictions, proposals)
315
+ scores = self.predict_probs(predictions, proposals)
316
+ if self.mult_proposal_score:
317
+ proposal_scores = [p.get('objectness_logits') for p in proposals]
318
+ scores = [(s * ps[:, None]) ** 0.5 \
319
+ for s, ps in zip(scores, proposal_scores)]
320
+ image_shapes = [x.image_size for x in proposals]
321
+ return fast_rcnn_inference(
322
+ boxes,
323
+ scores,
324
+ image_shapes,
325
+ self.test_score_thresh,
326
+ self.test_nms_thresh,
327
+ self.test_topk_per_image,
328
+ )
329
+
330
+
331
+ def predict_probs(self, predictions, proposals):
332
+ """
333
+ support sigmoid
334
+ """
335
+ # scores, _ = predictions
336
+ scores = predictions[0]
337
+ num_inst_per_image = [len(p) for p in proposals]
338
+ if self.use_sigmoid_ce:
339
+ probs = scores.sigmoid()
340
+ else:
341
+ probs = F.softmax(scores, dim=-1)
342
+ return probs.split(num_inst_per_image, dim=0)
343
+
344
+
345
+ def image_label_losses(self, predictions, proposals, image_labels, \
346
+ classifier_info=(None,None,None), ann_type='image'):
347
+ '''
348
+ Inputs:
349
+ scores: N x (C + 1)
350
+ image_labels B x 1
351
+ '''
352
+ num_inst_per_image = [len(p) for p in proposals]
353
+ scores = predictions[0]
354
+ scores = scores.split(num_inst_per_image, dim=0) # B x n x (C + 1)
355
+ if self.with_softmax_prop:
356
+ prop_scores = predictions[2].split(num_inst_per_image, dim=0)
357
+ else:
358
+ prop_scores = [None for _ in num_inst_per_image]
359
+ B = len(scores)
360
+ img_box_count = 0
361
+ select_size_count = 0
362
+ select_x_count = 0
363
+ select_y_count = 0
364
+ max_score_count = 0
365
+ storage = get_event_storage()
366
+ loss = scores[0].new_zeros([1])[0]
367
+ caption_loss = scores[0].new_zeros([1])[0]
368
+ for idx, (score, labels, prop_score, p) in enumerate(zip(
369
+ scores, image_labels, prop_scores, proposals)):
370
+ if score.shape[0] == 0:
371
+ loss += score.new_zeros([1])[0]
372
+ continue
373
+ if 'caption' in ann_type:
374
+ score, caption_loss_img = self._caption_loss(
375
+ score, classifier_info, idx, B)
376
+ caption_loss += self.caption_weight * caption_loss_img
377
+ if ann_type == 'caption':
378
+ continue
379
+
380
+ if self.debug:
381
+ p.selected = score.new_zeros(
382
+ (len(p),), dtype=torch.long) - 1
383
+ for i_l, label in enumerate(labels):
384
+ if self.dynamic_classifier:
385
+ if idx == 0 and i_l == 0 and comm.is_main_process():
386
+ storage.put_scalar('stats_label', label)
387
+ label = classifier_info[1][1][label]
388
+ assert label < score.shape[1]
389
+ if self.image_label_loss in ['wsod', 'wsddn']:
390
+ loss_i, ind = self._wsddn_loss(score, prop_score, label)
391
+ elif self.image_label_loss == 'max_score':
392
+ loss_i, ind = self._max_score_loss(score, label)
393
+ elif self.image_label_loss == 'max_size':
394
+ loss_i, ind = self._max_size_loss(score, label, p)
395
+ elif self.image_label_loss == 'first':
396
+ loss_i, ind = self._first_loss(score, label)
397
+ elif self.image_label_loss == 'image':
398
+ loss_i, ind = self._image_loss(score, label)
399
+ elif self.image_label_loss == 'min_loss':
400
+ loss_i, ind = self._min_loss_loss(score, label)
401
+ else:
402
+ assert 0
403
+ loss += loss_i / len(labels)
404
+ if type(ind) == type([]):
405
+ img_box_count = sum(ind) / len(ind)
406
+ if self.debug:
407
+ for ind_i in ind:
408
+ p.selected[ind_i] = label
409
+ else:
410
+ img_box_count = ind
411
+ select_size_count = p[ind].proposal_boxes.area() / \
412
+ (p.image_size[0] * p.image_size[1])
413
+ max_score_count = score[ind, label].sigmoid()
414
+ select_x_count = (p.proposal_boxes.tensor[ind, 0] + \
415
+ p.proposal_boxes.tensor[ind, 2]) / 2 / p.image_size[1]
416
+ select_y_count = (p.proposal_boxes.tensor[ind, 1] + \
417
+ p.proposal_boxes.tensor[ind, 3]) / 2 / p.image_size[0]
418
+ if self.debug:
419
+ p.selected[ind] = label
420
+
421
+ loss = loss / B
422
+ storage.put_scalar('stats_l_image', loss.item())
423
+ if 'caption' in ann_type:
424
+ caption_loss = caption_loss / B
425
+ loss = loss + caption_loss
426
+ storage.put_scalar('stats_l_caption', caption_loss.item())
427
+ if comm.is_main_process():
428
+ storage.put_scalar('pool_stats', img_box_count)
429
+ storage.put_scalar('stats_select_size', select_size_count)
430
+ storage.put_scalar('stats_select_x', select_x_count)
431
+ storage.put_scalar('stats_select_y', select_y_count)
432
+ storage.put_scalar('stats_max_label_score', max_score_count)
433
+
434
+ return {
435
+ 'image_loss': loss * self.image_loss_weight,
436
+ 'loss_cls': score.new_zeros([1])[0],
437
+ 'loss_box_reg': score.new_zeros([1])[0]}
438
+
439
+
440
+ def forward(self, x, classifier_info=(None,None,None)):
441
+ """
442
+ enable classifier_info
443
+ """
444
+ if x.dim() > 2:
445
+ x = torch.flatten(x, start_dim=1)
446
+ scores = []
447
+
448
+ if classifier_info[0] is not None:
449
+ classifier_out = self.cls_score(x, classifier=classifier_info[0])
450
+ else:
451
+ classifier_out = self.cls_score(x)
452
+ if self.use_regional_embedding:
453
+ cls_scores, regional_embeddings = classifier_out
454
+ else:
455
+ cls_scores = classifier_out
456
+ scores.append(cls_scores)
457
+
458
+ if classifier_info[2] is not None:
459
+ classifier_out = classifier_info[2]
460
+ if self.use_regional_embedding:
461
+ cap_cls, regional_embeddings = classifier_out
462
+ else:
463
+ cap_cls = classifier_out
464
+ if self.sync_caption_batch:
465
+ caption_scores = self.cls_score(x, classifier=cap_cls[:, :-1])
466
+ else:
467
+ caption_scores = self.cls_score(x, classifier=cap_cls)
468
+ scores.append(caption_scores)
469
+ scores = torch.cat(scores, dim=1) # B x C' or B x N or B x (C'+N)
470
+
471
+ proposal_deltas = self.bbox_pred(x)
472
+ if self.with_softmax_prop:
473
+ prop_score = self.prop_score(x)
474
+ return scores, proposal_deltas, prop_score
475
+ elif self.use_regional_embedding:
476
+ # NOTE: scores: [B * # proposals for each image, 1204 (1203 + 1 bg)]
477
+ # NOTE: proposal_deltas: [B * # proposals for each image, 4]
478
+ # NOTE: regional_embeddings: [B * # proposals for each image, 512]
479
+ return scores, proposal_deltas, regional_embeddings
480
+ else:
481
+ return scores, proposal_deltas
482
+
483
+
484
+ def _caption_loss(self, score, classifier_info, idx, B):
485
+ assert (classifier_info[2] is not None)
486
+ assert self.add_image_box
487
+ cls_and_cap_num = score.shape[1]
488
+ cap_num = classifier_info[2].shape[0]
489
+ score, caption_score = score.split(
490
+ [cls_and_cap_num - cap_num, cap_num], dim=1)
491
+ # n x (C + 1), n x B
492
+ caption_score = caption_score[-1:] # 1 x B # -1: image level box
493
+ caption_target = caption_score.new_zeros(
494
+ caption_score.shape) # 1 x B or 1 x MB, M: num machines
495
+ if self.sync_caption_batch:
496
+ # caption_target: 1 x MB
497
+ rank = comm.get_rank()
498
+ global_idx = B * rank + idx
499
+ assert (classifier_info[2][
500
+ global_idx, -1] - rank) ** 2 < 1e-8, \
501
+ '{} {} {} {} {}'.format(
502
+ rank, global_idx,
503
+ classifier_info[2][global_idx, -1],
504
+ classifier_info[2].shape,
505
+ classifier_info[2][:, -1])
506
+ caption_target[:, global_idx] = 1.
507
+ else:
508
+ assert caption_score.shape[1] == B
509
+ caption_target[:, idx] = 1.
510
+ caption_loss_img = F.binary_cross_entropy_with_logits(
511
+ caption_score, caption_target, reduction='none')
512
+ if self.sync_caption_batch:
513
+ fg_mask = (caption_target > 0.5).float()
514
+ assert (fg_mask.sum().item() - 1.) ** 2 < 1e-8, '{} {}'.format(
515
+ fg_mask.shape, fg_mask)
516
+ pos_loss = (caption_loss_img * fg_mask).sum()
517
+ neg_loss = (caption_loss_img * (1. - fg_mask)).sum()
518
+ caption_loss_img = pos_loss + self.neg_cap_weight * neg_loss
519
+ else:
520
+ caption_loss_img = caption_loss_img.sum()
521
+ return score, caption_loss_img
522
+
523
+
524
+ def _wsddn_loss(self, score, prop_score, label):
525
+ assert prop_score is not None
526
+ loss = 0
527
+ final_score = score.sigmoid() * \
528
+ F.softmax(prop_score, dim=0) # B x (C + 1)
529
+ img_score = torch.clamp(
530
+ torch.sum(final_score, dim=0),
531
+ min=1e-10, max=1-1e-10) # (C + 1)
532
+ target = img_score.new_zeros(img_score.shape) # (C + 1)
533
+ target[label] = 1.
534
+ loss += F.binary_cross_entropy(img_score, target)
535
+ ind = final_score[:, label].argmax()
536
+ return loss, ind
537
+
538
+
539
+ def _max_score_loss(self, score, label):
540
+ loss = 0
541
+ target = score.new_zeros(score.shape[1])
542
+ target[label] = 1.
543
+ ind = score[:, label].argmax().item()
544
+ loss += F.binary_cross_entropy_with_logits(
545
+ score[ind], target, reduction='sum')
546
+ return loss, ind
547
+
548
+
549
+ def _min_loss_loss(self, score, label):
550
+ loss = 0
551
+ target = score.new_zeros(score.shape)
552
+ target[:, label] = 1.
553
+ with torch.no_grad():
554
+ x = F.binary_cross_entropy_with_logits(
555
+ score, target, reduction='none').sum(dim=1) # n
556
+ ind = x.argmin().item()
557
+ loss += F.binary_cross_entropy_with_logits(
558
+ score[ind], target[0], reduction='sum')
559
+ return loss, ind
560
+
561
+
562
+ def _first_loss(self, score, label):
563
+ loss = 0
564
+ target = score.new_zeros(score.shape[1])
565
+ target[label] = 1.
566
+ ind = 0
567
+ loss += F.binary_cross_entropy_with_logits(
568
+ score[ind], target, reduction='sum')
569
+ return loss, ind
570
+
571
+
572
+ def _image_loss(self, score, label):
573
+ assert self.add_image_box
574
+ target = score.new_zeros(score.shape[1])
575
+ target[label] = 1.
576
+ ind = score.shape[0] - 1
577
+ loss = F.binary_cross_entropy_with_logits(
578
+ score[ind], target, reduction='sum')
579
+ return loss, ind
580
+
581
+
582
+ def _max_size_loss(self, score, label, p):
583
+ loss = 0
584
+ target = score.new_zeros(score.shape[1])
585
+ target[label] = 1.
586
+ sizes = p.proposal_boxes.area()
587
+ ind = sizes[:-1].argmax().item() if len(sizes) > 1 else 0
588
+ if self.softmax_weak_loss:
589
+ loss += F.cross_entropy(
590
+ score[ind:ind+1],
591
+ score.new_tensor(label, dtype=torch.long).view(1),
592
+ reduction='sum')
593
+ else:
594
+ loss += F.binary_cross_entropy_with_logits(
595
+ score[ind], target, reduction='sum')
596
+ return loss, ind
597
+
598
+
599
+
600
+ def put_label_distribution(storage, hist_name, hist_counts, num_classes):
601
+ """
602
+ Log a per-class label-count histogram to the event storage.
+ """
603
+ ht_min, ht_max = 0, num_classes
604
+ hist_edges = torch.linspace(
605
+ start=ht_min, end=ht_max, steps=num_classes + 1, dtype=torch.float32)
606
+
607
+ hist_params = dict(
608
+ tag=hist_name,
609
+ min=ht_min,
610
+ max=ht_max,
611
+ num=float(hist_counts.sum()),
612
+ sum=float((hist_counts * torch.arange(len(hist_counts))).sum()),
613
+ sum_squares=float(((hist_counts * torch.arange(len(hist_counts))) ** 2).sum()),
614
+ bucket_limits=hist_edges[1:].tolist(),
615
+ bucket_counts=hist_counts.tolist(),
616
+ global_step=storage._iter,
617
+ )
618
+ storage._histograms.append(hist_params)
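The federated sigmoid cross-entropy above only back-propagates through a sampled subset of categories: classes that appear in the batch are always kept, and the remaining budget is filled by sampling classes in proportion to their frequency weights. The following is a minimal, self-contained sketch of that weighting with toy shapes; `sample_fed_loss_classes` is a simplified stand-in for `get_fed_loss_inds`, the `ignore_zero_cats` mask is omitted, and all sizes and frequency weights are illustrative only.

import torch
import torch.nn.functional as F

def sample_fed_loss_classes(gt_classes, num_sample_cats, weight):
    # Classes present in the batch are always kept; the rest of the budget is
    # drawn without replacement, with probability proportional to `weight`.
    appeared = torch.unique(gt_classes)
    prob = weight.float().clone()
    prob[appeared] = 0
    if len(appeared) < num_sample_cats:
        extra = torch.multinomial(prob, num_sample_cats - len(appeared), replacement=False)
        appeared = torch.cat([appeared, extra])
    return appeared

B, C = 8, 20                                    # toy proposal / class counts
logits = torch.randn(B, C + 1)                  # last column is the background logit
gt = torch.randint(0, C + 1, (B,))              # label C means background
freq = torch.rand(C)                            # toy per-class frequency weights

target = logits.new_zeros(B, C + 1)
target[torch.arange(B), gt] = 1
target = target[:, :C]                          # drop the background column

appeared = sample_fed_loss_classes(gt[gt < C], num_sample_cats=5, weight=freq)
class_mask = logits.new_zeros(C)
class_mask[appeared] = 1                        # only sampled classes contribute to the loss
loss = (F.binary_cross_entropy_with_logits(logits[:, :-1], target, reduction="none")
        * class_mask.view(1, C)).sum() / B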
proxydet/modeling/roi_heads/proxydet_roi_heads.py ADDED
@@ -0,0 +1,556 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ '''
3
+ Modifications Copyright (c) 2024-present NAVER Corp, Apache License v2.0
4
+ original source: https://github.com/facebookresearch/Detic/blob/main/detic/modeling/roi_heads/detic_roi_heads.py
5
+ '''
6
+ import copy
7
+ import numpy as np
8
+ import json
9
+ import math
10
+ import torch
11
+ from torch import nn
12
+ from torch.autograd.function import Function
13
+ from typing import Dict, List, Optional, Tuple, Union
14
+ from torch.nn import functional as F
15
+
16
+ from fvcore.nn import giou_loss
17
+
18
+ from detectron2.config import configurable
19
+ from detectron2.layers import ShapeSpec
20
+ from detectron2.layers import batched_nms, cat
21
+ from detectron2.structures import Boxes, Instances, pairwise_iou
22
+ from detectron2.utils.events import get_event_storage
23
+
24
+ from detectron2.modeling.box_regression import Box2BoxTransform
25
+ from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
26
+ from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads
27
+ from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient
28
+ from detectron2.modeling.roi_heads.box_head import build_box_head
29
+ from .proxydet_fast_rcnn import ProxydetFastRCNNOutputLayers
30
+ from ..debug import debug_second_stage
31
+
32
+ from torch.cuda.amp import autocast
33
+ from copy import deepcopy
34
+
35
+ @ROI_HEADS_REGISTRY.register()
36
+ class ProxydetCascadeROIHeads(CascadeROIHeads):
37
+ @configurable
38
+ def __init__(
39
+ self,
40
+ *,
41
+ mult_proposal_score: bool = False,
42
+ with_image_labels: bool = False,
43
+ add_image_box: bool = False,
44
+ image_box_size: float = 1.0,
45
+ ws_num_props: int = 512,
46
+ add_feature_to_prop: bool = False,
47
+ mask_weight: float = 1.0,
48
+ one_class_per_proposal: bool = False,
49
+ use_regional_embedding: bool = False,
50
+ base_cat_mask: str = None,
51
+ cmm_stage: list = [],
52
+ cmm_stage_test: list = None,
53
+ cmm_beta: float = 1.0,
54
+ cmm_loss: str = "l1",
55
+ cmm_loss_weight: float = 1.0,
56
+ cmm_separated_branch: bool = False,
57
+ cmm_base_alpha: float = 0.5,
58
+ cmm_novel_beta: float = 0.5,
59
+ cmm_use_inl: bool = False,
60
+ cmm_prototype: str = "center",
61
+ cmm_prototype_temp: float = 1.0,
62
+ cmm_classifier_temp: float = None,
63
+ cmm_use_sigmoid_ce: bool = True,
64
+ **kwargs,
65
+ ):
66
+ super().__init__(**kwargs)
67
+ self.mult_proposal_score = mult_proposal_score
68
+ self.with_image_labels = with_image_labels
69
+ self.add_image_box = add_image_box
70
+ self.image_box_size = image_box_size
71
+ self.ws_num_props = ws_num_props
72
+ self.add_feature_to_prop = add_feature_to_prop
73
+ self.mask_weight = mask_weight
74
+ self.one_class_per_proposal = one_class_per_proposal
75
+ self.use_regional_embedding = use_regional_embedding
76
+ self.base_cat_mask = torch.tensor(np.load(base_cat_mask)).bool()
77
+ self.cmm_stage = cmm_stage
78
+ self.cmm_stage_test = cmm_stage_test
79
+ self.cmm_beta = cmm_beta
80
+ self.cmm_loss = cmm_loss
81
+ self.cmm_loss_weight = cmm_loss_weight
82
+ self.cmm_separated_branch = cmm_separated_branch
83
+ self.cmm_base_alpha = cmm_base_alpha
84
+ self.cmm_novel_beta = cmm_novel_beta
85
+ self.cmm_use_inl = cmm_use_inl
86
+ self.cmm_prototype = cmm_prototype
87
+ self.cmm_prototype_temp = cmm_prototype_temp
88
+ self.cmm_classifier_temp = cmm_classifier_temp
89
+ self.cmm_use_sigmoid_ce = cmm_use_sigmoid_ce
90
+
91
+ if self.cmm_separated_branch:
92
+ self.box_head_cmm = deepcopy(self.box_head)
93
+ self.box_predictor_cmm = deepcopy(self.box_predictor)
94
+ if self.cmm_classifier_temp is not None:
95
+ for k in range(self.num_cascade_stages):
96
+ self.box_predictor_cmm[k].cls_score.norm_temperature = self.cmm_classifier_temp
97
+ if not self.cmm_use_sigmoid_ce:
98
+ for k in range(self.num_cascade_stages):
99
+ self.box_predictor_cmm[k].use_sigmoid_ce = self.cmm_use_sigmoid_ce # using bce or ce
100
+
101
+ @classmethod
102
+ def from_config(cls, cfg, input_shape):
103
+ ret = super().from_config(cfg, input_shape)
104
+ ret.update({
105
+ 'mult_proposal_score': cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE,
106
+ 'with_image_labels': cfg.WITH_IMAGE_LABELS,
107
+ 'add_image_box': cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX,
108
+ 'image_box_size': cfg.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE,
109
+ 'ws_num_props': cfg.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS,
110
+ 'add_feature_to_prop': cfg.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP,
111
+ 'mask_weight': cfg.MODEL.ROI_HEADS.MASK_WEIGHT,
112
+ 'one_class_per_proposal': cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL,
113
+ "use_regional_embedding": cfg.MODEL.ROI_BOX_HEAD.USE_REGIONAL_EMBEDDING,
114
+ "base_cat_mask": cfg.MODEL.ROI_HEADS.BASE_CAT_MASK,
115
+ "cmm_stage": cfg.MODEL.ROI_HEADS.CMM.MIXUP_STAGE,
116
+ "cmm_stage_test": cfg.MODEL.ROI_HEADS.CMM.MIXUP_STAGE_TEST,
117
+ "cmm_beta": cfg.MODEL.ROI_HEADS.CMM.MIXUP_BETA,
118
+ "cmm_loss": cfg.MODEL.ROI_HEADS.CMM.LOSS,
119
+ "cmm_loss_weight": cfg.MODEL.ROI_HEADS.CMM.LOSS_WEIGHT,
120
+ "cmm_separated_branch": cfg.MODEL.ROI_HEADS.CMM.SEPARATED_BRANCH,
121
+ "cmm_base_alpha": cfg.MODEL.ROI_HEADS.CMM.BASE_ALPHA,
122
+ "cmm_novel_beta": cfg.MODEL.ROI_HEADS.CMM.NOVEL_BETA,
123
+ "cmm_use_inl": cfg.MODEL.ROI_HEADS.CMM.USE_INL,
124
+ "cmm_prototype": cfg.MODEL.ROI_HEADS.CMM.PROTOTYPE,
125
+ "cmm_prototype_temp": cfg.MODEL.ROI_HEADS.CMM.PROTOTYPE_TEMP,
126
+ "cmm_classifier_temp": cfg.MODEL.ROI_HEADS.CMM.CLASSIFIER_TEMP,
127
+ "cmm_use_sigmoid_ce": cfg.MODEL.ROI_HEADS.CMM.USE_SIGMOID_CE,
128
+ })
129
+ return ret
130
+
131
+
132
+ @classmethod
133
+ def _init_box_head(self, cfg, input_shape):
134
+ ret = super()._init_box_head(cfg, input_shape)
135
+ del ret['box_predictors']
136
+ cascade_bbox_reg_weights = cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS
137
+ box_predictors = []
138
+ for box_head, bbox_reg_weights in zip(ret['box_heads'], \
139
+ cascade_bbox_reg_weights):
140
+ box_predictors.append(
141
+ ProxydetFastRCNNOutputLayers(
142
+ cfg, box_head.output_shape,
143
+ box2box_transform=Box2BoxTransform(weights=bbox_reg_weights)
144
+ ))
145
+ ret['box_predictors'] = box_predictors
146
+ return ret
147
+
148
+ def _embed_interp(self, re_i, re_j, te_i, te_j, lam):
149
+ # mix image, text embedding
150
+ _mixed_re = lam * re_i + (1 - lam) * re_j
151
+ _mixed_te = lam * te_i + (1 - lam) * te_j
152
+ return _mixed_re, _mixed_te
153
+
154
+ def _get_head_outputs(self, features, proposals, image_sizes, _run_stage, box_predictor, targets=None, ann_type='box', classifier_info=(None,None,None)):
155
+ head_outputs = [] # (predictor, predictions, proposals)
156
+ for k in range(self.num_cascade_stages):
157
+ if k > 0:
158
+ proposals = self._create_proposals_from_boxes(
159
+ prev_pred_boxes, image_sizes,
160
+ logits=[p.objectness_logits for p in proposals])
161
+ if self.training and ann_type in ['box']:
162
+ proposals = self._match_and_label_boxes(
163
+ proposals, k, targets)
164
+ predictions = _run_stage(features, proposals, k,
165
+ classifier_info=classifier_info)
166
+ prev_pred_boxes = box_predictor[k].predict_boxes(
167
+ (predictions[0], predictions[1]), proposals)
168
+ head_outputs.append((box_predictor[k], predictions, proposals))
169
+ return head_outputs, proposals
170
+
171
+ def loss_mixup(self, mixed_re, mixed_te, text_embeddings):
172
+ # loss
173
+ if self.cmm_loss in ["l1", "l2"]:
174
+ if self.cmm_loss == "l1":
175
+ loss_type = F.l1_loss
176
+ elif self.cmm_loss == "l2":
177
+ loss_type = F.mse_loss
178
+ cmm_loss = loss_type(mixed_re, mixed_te)
179
+ else:
180
+ raise ValueError("No such loss is supported : ", self.cmm_loss)
181
+ return cmm_loss
182
+
183
+ def mixup(self, stage, regional_embeddings, text_embeddings, gt_classes, proto_weights=None):
184
+ # class-wise multi-modal mixup
185
+ try:
186
+ neg_class = text_embeddings.shape[0] - 1
187
+
188
+ # select positive text embeddings
189
+ all_classes = torch.unique(gt_classes)
190
+ pos_classes = all_classes[all_classes != neg_class]
191
+
192
+ # select class-wise regional embeddings & text embeddings
193
+ clswise_re = []
194
+ clswise_te = []
195
+ for p_c in pos_classes:
196
+ mask = (gt_classes == p_c)
197
+ if self.cmm_prototype == "center":
198
+ _clswise_re = torch.mean(regional_embeddings[mask], axis=0, keepdim=True)
199
+ elif self.cmm_prototype in ["obj_score", "iou"]:
200
+ soft_proto_weights = F.softmax(proto_weights[mask] / self.cmm_prototype_temp, dim=0)
201
+ _clswise_re = torch.sum(regional_embeddings[mask] * soft_proto_weights.unsqueeze(-1), 0, keepdim=True)
202
+ _clswise_te = text_embeddings[int(p_c.item())].unsqueeze(0)
203
+ clswise_re.append(_clswise_re)
204
+ clswise_te.append(_clswise_te)
205
+
206
+ if len(clswise_re) == 0:
207
+ raise ValueError("no positive base classes found for mixup.")
208
+
209
+ clswise_re = torch.cat(clswise_re, dim=0)
210
+ clswise_re = F.normalize(clswise_re, p=2, dim=1) # re-normalize
211
+ clswise_te = torch.cat(clswise_te, dim=0)
212
+
213
+ if self.cmm_beta == 0:
214
+ lam = float(np.random.randint(2))
215
+ else:
216
+ lam = np.random.beta(self.cmm_beta, self.cmm_beta)
217
+
218
+ # random shuffle for mixup pair
219
+ rand_index = torch.randperm(clswise_re.size()[0]).to(clswise_re.device)
220
+
221
+ # mixup
222
+ sf_clswise_re = clswise_re[rand_index]
223
+ sf_clswise_te = clswise_te[rand_index]
224
+ mixed_re, mixed_te = self._embed_interp(clswise_re, sf_clswise_re, clswise_te, sf_clswise_te, lam)
225
+ mixed_re = F.normalize(mixed_re, p=2, dim=1)
226
+ mixed_te = F.normalize(mixed_te, p=2, dim=1)
227
+ cmm_loss = self.loss_mixup(mixed_re, mixed_te, text_embeddings)
228
+
229
+ except Exception as e:
230
+ print(f"Caught error in mixup: {e!r}; skipping mixup for this batch.")
231
+ cmm_loss = text_embeddings[0].new_zeros([1])[0]
232
+
233
+ return cmm_loss
234
+
235
+ def _forward_box(self, features, proposals, targets=None,
236
+ ann_type='box', classifier_info=(None,None,None)):
237
+ """
238
+ Add mult proposal scores at testing
239
+ Add ann_type
240
+ """
241
+ if (not self.training) and self.mult_proposal_score:
242
+ if len(proposals) > 0 and proposals[0].has('scores'):
243
+ proposal_scores = [p.get('scores') for p in proposals]
244
+ else:
245
+ proposal_scores = [p.get('objectness_logits') for p in proposals]
246
+
247
+ features = [features[f] for f in self.box_in_features]
248
+ # head_outputs = [] # (predictor, predictions, proposals)
249
+ prev_pred_boxes = None
250
+ image_sizes = [x.image_size for x in proposals]
251
+
252
+ head_outputs, proposals = self._get_head_outputs(features, proposals, image_sizes, self._run_stage, self.box_predictor, targets, ann_type, classifier_info)
253
+ if self.cmm_separated_branch:
254
+ # separated forward
255
+ head_outputs_cmm, proposals_cmm = self._get_head_outputs(features, proposals, image_sizes, self._run_stage_cmm, self.box_predictor_cmm, targets, ann_type, classifier_info)
256
+
257
+ if self.training:
258
+ losses = {}
259
+ storage = get_event_storage()
260
+ for stage, (predictor, predictions, proposals) in enumerate(head_outputs):
261
+ with storage.name_scope("stage{}".format(stage)):
262
+ if ann_type != 'box':
263
+ stage_losses = {}
264
+ if ann_type in ['image', 'caption', 'captiontag']:
265
+ image_labels = [x._pos_category_ids for x in targets]
266
+ weak_losses = predictor.image_label_losses(
267
+ predictions, proposals, image_labels,
268
+ classifier_info=classifier_info,
269
+ ann_type=ann_type)
270
+ stage_losses.update(weak_losses)
271
+
272
+ if self.cmm_use_inl and len(self.cmm_stage) > 0 and stage in self.cmm_stage:
273
+ if self.cmm_separated_branch:
274
+ # get regional embeddings (l2 normalized) from separated branch
275
+ regional_embeddings = head_outputs_cmm[stage][1][2]
276
+ else:
277
+ # get regional embeddings (l2 normalized)
278
+ regional_embeddings = predictions[2]
279
+
280
+ # get text embeddings (L2 normalized), [C (1203 + 1), embedding dim]
281
+ text_embeddings = predictor.cls_score.zs_weight.t()
282
+
283
+ # get max-size proposal's regional embedding, per image
284
+ num_inst_per_image = [len(p) for p in proposals_cmm]
285
+ re_per_image = regional_embeddings.split(num_inst_per_image, dim=0)
286
+
287
+ maxsize_re_per_image = []
288
+ for p, re in zip(proposals_cmm, re_per_image):
289
+ sizes = p.proposal_boxes.area()
290
+ ind = sizes[:-1].argmax().item() if len(sizes) > 1 else 0
291
+ maxsize_re_per_image.append(re[ind].unsqueeze(0))
292
+ maxsize_re_per_image = torch.cat(maxsize_re_per_image, dim=0)
293
+ maxsize_re_per_image = maxsize_re_per_image.to(regional_embeddings.device)
294
+
295
+ # get gt classes per max-size proposal
296
+ # TODO: add best-label per image by cls loss from weak_losses (image-label loss)
297
+ # NOTE: image_labels are not multi-labels.
298
+ gt_classes = (
299
+ regional_embeddings.new_tensor(
300
+ [np.random.choice(labels, 1, replace=False)[0] for labels in image_labels],
301
+ dtype=torch.long
302
+ ) if len(proposals_cmm)
303
+ else torch.empty(0)
304
+ )
305
+
306
+ proto_weights = None
307
+
308
+ # get text embeddings (L2 normalized), [C (1203 + 1), embedding dim]
309
+ text_embeddings = predictor.cls_score.zs_weight.t()
310
+ cmm_loss = self.mixup(stage, maxsize_re_per_image, text_embeddings, gt_classes, proto_weights)
311
+ stage_losses["cmm_image_loss"] = (
312
+ cmm_loss * self.cmm_loss_weight
313
+ )
314
+ stage_losses["cmm_loss"] = \
315
+ predictions[0].new_zeros([1])[0]
316
+
317
+ else: # supervised
318
+ stage_losses = predictor.losses(
319
+ (predictions[0], predictions[1]), proposals,
320
+ classifier_info=classifier_info)
321
+ if self.with_image_labels:
322
+ stage_losses['image_loss'] = \
323
+ predictions[0].new_zeros([1])[0]
324
+
325
+ if len(self.cmm_stage) > 0 and stage in self.cmm_stage:
326
+ assert self.use_regional_embedding
327
+
328
+ # get gt classes per proposal
329
+ # e.g. dtype: torch.int64, value: tensor([ 142, 111, 142, ..., 1203, 1203, 1203], device='cuda:6')
330
+ gt_classes = (
331
+ cat([p.gt_classes for p in proposals_cmm], dim=0)
332
+ if len(proposals_cmm)
333
+ else torch.empty(0)
334
+ )
335
+
336
+ if self.cmm_prototype in ["obj_score"]:
337
+ proto_weights = (
338
+ cat([p.objectness_logits for p in proposals_cmm], dim=0)
339
+ if len(proposals_cmm)
340
+ else torch.empty(0)
341
+ )
342
+ elif self.cmm_prototype in ["iou"]:
343
+ gt_boxes = (
344
+ cat([p.gt_boxes.tensor for p in proposals_cmm], dim=0)
345
+ if len(proposals_cmm)
346
+ else torch.empty(0)
347
+ )
348
+ proposal_boxes = (
349
+ cat([p.proposal_boxes.tensor for p in proposals_cmm], dim=0)
350
+ if len(proposals_cmm)
351
+ else torch.empty(0)
352
+ )
353
+
354
+ proto_weights = 1 - giou_loss(proposal_boxes, gt_boxes, reduction="none") # GIoU. (-1 < x < 1)
355
+ else:
356
+ proto_weights = None
357
+
358
+ if self.cmm_separated_branch:
359
+ # get regional embeddings (l2 normalized) from separated branch
360
+ regional_embeddings = head_outputs_cmm[stage][1][2]
361
+ else:
362
+ # get regional embeddings (l2 normalized)
363
+ regional_embeddings = predictions[2]
364
+
365
+ # get text embeddings (L2 normalized), [C (1203 + 1), embedding dim]
366
+ text_embeddings = predictor.cls_score.zs_weight.t()
367
+
368
+ cmm_loss = self.mixup(stage, regional_embeddings, text_embeddings, gt_classes, proto_weights)
369
+ stage_losses["cmm_loss"] = (
370
+ cmm_loss * self.cmm_loss_weight
371
+ )
372
+
373
+ if self.cmm_use_inl:
374
+ stage_losses["cmm_image_loss"] = \
375
+ predictions[0].new_zeros([1])[0]
376
+
377
+ losses.update({k + "_stage{}".format(stage): v \
378
+ for k, v in stage_losses.items()})
379
+ return losses
380
+ else:
381
+ # Each is a list[Tensor] of length #image. Each tensor is Ri x (K+1)
382
+ scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs]
383
+ scores = [
384
+ sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages)
385
+ for scores_per_image in zip(*scores_per_stage)
386
+ ]
387
+
388
+ if self.cmm_separated_branch:
389
+ # scores from separated branch
390
+ if self.cmm_stage_test is None:
391
+ # average all stage's classification scores
392
+ scores_per_stage_cmm = [h[0].predict_probs(h[1], h[2]) for h in head_outputs_cmm]
393
+ scores_cmm = [
394
+ sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages)
395
+ for scores_per_image in zip(*scores_per_stage_cmm)
396
+ ]
397
+ else:
398
+ # only using specific stages
399
+ scores_per_stage_cmm = [h[0].predict_probs(h[1], h[2]) for k, h in enumerate(head_outputs_cmm) if k in self.cmm_stage_test]
400
+ scores_cmm = [
401
+ sum(list(scores_per_image)) * (1.0 / len(self.cmm_stage_test))
402
+ for scores_per_image in zip(*scores_per_stage_cmm)
403
+ ]
404
+
405
+ base_cat_mask = self.base_cat_mask
406
+ assert len(scores) == 1
407
+ bg_score = scores[0][:, -1].clone()
408
+ scores[0][:, base_cat_mask] = scores[0][:, base_cat_mask].pow(
409
+ 1.0 - self.cmm_base_alpha
410
+ ) * scores_cmm[0][:, base_cat_mask].pow(self.cmm_base_alpha)
411
+ scores[0][:, ~base_cat_mask] = scores[0][:, ~base_cat_mask].pow(
412
+ 1.0 - self.cmm_novel_beta
413
+ ) * scores_cmm[0][:, ~base_cat_mask].pow(self.cmm_novel_beta)
414
+ scores[0][:, -1] = bg_score
415
+
416
+ if self.mult_proposal_score:
417
+ scores = [(s * ps[:, None]) ** 0.5 \
418
+ for s, ps in zip(scores, proposal_scores)]
419
+ if self.one_class_per_proposal:
420
+ scores = [s * (s == s[:, :-1].max(dim=1)[0][:, None]).float() for s in scores]
421
+ predictor, predictions, proposals = head_outputs[-1]
422
+ boxes = predictor.predict_boxes(
423
+ (predictions[0], predictions[1]), proposals)
424
+ pred_instances, _ = fast_rcnn_inference(
425
+ boxes,
426
+ scores,
427
+ image_sizes,
428
+ predictor.test_score_thresh,
429
+ predictor.test_nms_thresh,
430
+ predictor.test_topk_per_image,
431
+ )
432
+ return pred_instances
433
+
434
+
435
+ def forward(self, images, features, proposals, targets=None,
436
+ ann_type='box', classifier_info=(None,None,None)):
437
+ '''
438
+ enable debug and image labels
439
+ classifier_info is shared across the batch
440
+ '''
441
+ if self.training:
442
+ if ann_type in ['box', 'prop', 'proptag']:
443
+ proposals = self.label_and_sample_proposals(
444
+ proposals, targets)
445
+ else:
446
+ proposals = self.get_top_proposals(proposals)
447
+
448
+ losses = self._forward_box(features, proposals, targets, \
449
+ ann_type=ann_type, classifier_info=classifier_info)
450
+ if ann_type == 'box' and targets[0].has('gt_masks'):
451
+ mask_losses = self._forward_mask(features, proposals)
452
+ losses.update({k: v * self.mask_weight \
453
+ for k, v in mask_losses.items()})
454
+ losses.update(self._forward_keypoint(features, proposals))
455
+ else:
456
+ losses.update(self._get_empty_mask_loss(
457
+ features, proposals,
458
+ device=proposals[0].objectness_logits.device))
459
+ return proposals, losses
460
+ else:
461
+ pred_instances = self._forward_box(
462
+ features, proposals, classifier_info=classifier_info)
463
+ pred_instances = self.forward_with_given_boxes(features, pred_instances)
464
+ return pred_instances, {}
465
+
466
+
467
+ def get_top_proposals(self, proposals):
468
+ for i in range(len(proposals)):
469
+ proposals[i].proposal_boxes.clip(proposals[i].image_size)
470
+ proposals = [p[:self.ws_num_props] for p in proposals]
471
+ for i, p in enumerate(proposals):
472
+ p.proposal_boxes.tensor = p.proposal_boxes.tensor.detach()
473
+ if self.add_image_box:
474
+ proposals[i] = self._add_image_box(p)
475
+ return proposals
476
+
477
+
478
+ def _add_image_box(self, p):
479
+ image_box = Instances(p.image_size)
480
+ n = 1
481
+ h, w = p.image_size
482
+ f = self.image_box_size
483
+ image_box.proposal_boxes = Boxes(
484
+ p.proposal_boxes.tensor.new_tensor(
485
+ [w * (1. - f) / 2.,
486
+ h * (1. - f) / 2.,
487
+ w * (1. - (1. - f) / 2.),
488
+ h * (1. - (1. - f) / 2.)]
489
+ ).view(n, 4))
490
+ image_box.objectness_logits = p.objectness_logits.new_ones(n)
491
+ return Instances.cat([p, image_box])
492
+
493
+
494
+ def _get_empty_mask_loss(self, features, proposals, device):
495
+ if self.mask_on:
496
+ return {'loss_mask': torch.zeros(
497
+ (1, ), device=device, dtype=torch.float32)[0]}
498
+ else:
499
+ return {}
500
+
501
+
502
+ def _create_proposals_from_boxes(self, boxes, image_sizes, logits):
503
+ """
504
+ Add objectness_logits
505
+ """
506
+ boxes = [Boxes(b.detach()) for b in boxes]
507
+ proposals = []
508
+ for boxes_per_image, image_size, logit in zip(
509
+ boxes, image_sizes, logits):
510
+ boxes_per_image.clip(image_size)
511
+ if self.training:
512
+ inds = boxes_per_image.nonempty()
513
+ boxes_per_image = boxes_per_image[inds]
514
+ logit = logit[inds]
515
+ prop = Instances(image_size)
516
+ prop.proposal_boxes = boxes_per_image
517
+ prop.objectness_logits = logit
518
+ proposals.append(prop)
519
+ return proposals
520
+
521
+
522
+ def _run_stage(self, features, proposals, stage, \
523
+ classifier_info=(None,None,None)):
524
+ """
525
+ Support classifier_info and add_feature_to_prop
526
+ """
527
+ pool_boxes = [x.proposal_boxes for x in proposals]
528
+ box_features = self.box_pooler(features, pool_boxes)
529
+ box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages)
530
+ box_features = self.box_head[stage](box_features)
531
+ if self.add_feature_to_prop:
532
+ feats_per_image = box_features.split(
533
+ [len(p) for p in proposals], dim=0)
534
+ for feat, p in zip(feats_per_image, proposals):
535
+ p.feat = feat
536
+ return self.box_predictor[stage](
537
+ box_features,
538
+ classifier_info=classifier_info)
539
+
540
+ def _run_stage_cmm(self, features, proposals, stage, \
541
+ classifier_info=(None,None,None)):
542
+ """
543
+ Support classifier_info and add_feature_to_prop
544
+ """
545
+ pool_boxes = [x.proposal_boxes for x in proposals]
546
+ box_features = self.box_pooler(features, pool_boxes)
547
+ box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages)
548
+ box_features = self.box_head_cmm[stage](box_features)
549
+ if self.add_feature_to_prop:
550
+ feats_per_image = box_features.split(
551
+ [len(p) for p in proposals], dim=0)
552
+ for feat, p in zip(feats_per_image, proposals):
553
+ p.feat = feat
554
+ return self.box_predictor_cmm[stage](
555
+ box_features,
556
+ classifier_info=classifier_info)
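The core of the ROI heads above is the class-wise multi-modal mixup in `mixup`: per-class prototypes of the L2-normalized regional embeddings are interpolated with the matching CLIP text embeddings using a shared Beta-distributed coefficient and a shared random pairing, and an L1 loss pulls the mixed regional prototype toward the mixed text prototype. Below is a standalone sketch of that loss under the default "center" prototype and L1 settings; the function name `class_wise_mixup_loss` and the toy inputs are illustrative, not the exact training path.

import numpy as np
import torch
import torch.nn.functional as F

def class_wise_mixup_loss(region_embs, text_embs, gt_classes, beta=1.0):
    C = text_embs.shape[0] - 1                   # last row is the background class
    pos = torch.unique(gt_classes)
    pos = pos[pos != C]                          # keep only foreground classes
    if len(pos) == 0:
        return text_embs.new_zeros(())
    # "center" prototype: mean regional embedding per positive class, re-normalized.
    proto_re = torch.stack([region_embs[gt_classes == c].mean(0) for c in pos])
    proto_re = F.normalize(proto_re, p=2, dim=1)
    proto_te = text_embs[pos]
    lam = float(np.random.beta(beta, beta))
    perm = torch.randperm(len(pos))
    # Mix image-side and text-side prototypes with the same lambda and pairing,
    # then align the mixed regional prototype with the mixed text prototype.
    mixed_re = F.normalize(lam * proto_re + (1 - lam) * proto_re[perm], p=2, dim=1)
    mixed_te = F.normalize(lam * proto_te + (1 - lam) * proto_te[perm], p=2, dim=1)
    return F.l1_loss(mixed_re, mixed_te)

# Toy usage with random embeddings.
D, C = 512, 10
region_embs = F.normalize(torch.randn(32, D), dim=1)
text_embs = F.normalize(torch.randn(C + 1, D), dim=1)
gt_classes = torch.randint(0, C + 1, (32,))
loss = class_wise_mixup_loss(region_embs, text_embs, gt_classes)

At test time, `_forward_box` above fuses the scores of this CMM branch with the standard branch through a per-category geometric mean, using `cmm_base_alpha` as the exponent for base categories and `cmm_novel_beta` for novel ones, while leaving the background score untouched.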
proxydet/modeling/roi_heads/zero_shot_classifier.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ '''
3
+ Modifications Copyright (c) 2024-present NAVER Corp, Apache License v2.0
4
+ original source: https://github.com/facebookresearch/Detic/blob/main/detic/modeling/roi_heads/zero_shot_classifier.py
5
+ '''
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from detectron2.config import configurable
11
+ from detectron2.layers import Linear, ShapeSpec
12
+
13
+ class ZeroShotClassifier(nn.Module):
14
+ @configurable
15
+ def __init__(
16
+ self,
17
+ input_shape: ShapeSpec,
18
+ *,
19
+ num_classes: int,
20
+ zs_weight_path: str,
21
+ zs_weight_dim: int = 512,
22
+ use_bias: float = 0.0,
23
+ norm_weight: bool = True,
24
+ norm_temperature: float = 50.0,
25
+ use_regional_embedding: bool = False
26
+ ):
27
+ super().__init__()
28
+ if isinstance(input_shape, int): # some backward compatibility
29
+ input_shape = ShapeSpec(channels=input_shape)
30
+ input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1)
31
+ self.norm_weight = norm_weight
32
+ self.norm_temperature = norm_temperature
33
+
34
+ self.use_bias = use_bias < 0
35
+ if self.use_bias:
36
+ self.cls_bias = nn.Parameter(torch.ones(1) * use_bias)
37
+
38
+ self.linear = nn.Linear(input_size, zs_weight_dim)
39
+
40
+ if zs_weight_path == 'rand':
41
+ zs_weight = torch.randn((zs_weight_dim, num_classes))
42
+ nn.init.normal_(zs_weight, std=0.01)
43
+ else:
44
+ zs_weight = torch.tensor(
45
+ np.load(zs_weight_path),
46
+ dtype=torch.float32).permute(1, 0).contiguous() # D x C
47
+ zs_weight = torch.cat(
48
+ [zs_weight, zs_weight.new_zeros((zs_weight_dim, 1))],
49
+ dim=1) # D x (C + 1)
50
+
51
+ if self.norm_weight:
52
+ zs_weight = F.normalize(zs_weight, p=2, dim=0)
53
+
54
+ if zs_weight_path == 'rand':
55
+ self.zs_weight = nn.Parameter(zs_weight)
56
+ else:
57
+ self.register_buffer('zs_weight', zs_weight)
58
+
59
+ assert self.zs_weight.shape[1] == num_classes + 1, self.zs_weight.shape
60
+
61
+ self.use_regional_embedding = use_regional_embedding
62
+ if self.use_regional_embedding:
63
+ assert (
64
+ self.norm_weight
65
+ ), "norm_weight should be True for using regional embedding."
66
+
67
+
68
+ @classmethod
69
+ def from_config(cls, cfg, input_shape):
70
+ return {
71
+ 'input_shape': input_shape,
72
+ 'num_classes': cfg.MODEL.ROI_HEADS.NUM_CLASSES,
73
+ 'zs_weight_path': cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH,
74
+ 'zs_weight_dim': cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_DIM,
75
+ 'use_bias': cfg.MODEL.ROI_BOX_HEAD.USE_BIAS,
76
+ 'norm_weight': cfg.MODEL.ROI_BOX_HEAD.NORM_WEIGHT,
77
+ 'norm_temperature': cfg.MODEL.ROI_BOX_HEAD.NORM_TEMP,
78
+ "use_regional_embedding": cfg.MODEL.ROI_BOX_HEAD.USE_REGIONAL_EMBEDDING,
79
+ }
80
+
81
+ def forward(self, x, classifier=None):
82
+ '''
83
+ Inputs:
84
+ x: B x D'
85
+ classifier: C' x D (optional external classifier weights)
86
+ '''
87
+ x = self.linear(x)
88
+ if classifier is not None:
89
+ zs_weight = classifier.permute(1, 0).contiguous() # D x C'
90
+ zs_weight = F.normalize(zs_weight, p=2, dim=0) \
91
+ if self.norm_weight else zs_weight
92
+ else:
93
+ zs_weight = self.zs_weight
94
+ if self.norm_weight:
95
+ # NOTE: x shape is [Batch size * # proposals for each image, 512 (embedding dim)]
96
+ x = F.normalize(x, p=2, dim=1)
97
+ if self.use_regional_embedding:
98
+ # NOTE: gradient of cloned tensor will be propagated to the original tensor (x): https://discuss.pytorch.org/t/how-does-clone-interact-with-backpropagation/8247/6?u=111368
99
+ regional_embedding = x.clone()
100
+ # NOTE: apply normalizing temperature
101
+ x *= self.norm_temperature
102
+
103
+ x = torch.mm(x, zs_weight)
104
+
105
+ if self.use_bias:
106
+ x = x + self.cls_bias
107
+
108
+ if not self.use_regional_embedding:
109
+ return x # class logits
110
+ else:
111
+ return x, regional_embedding # class logits & regional embeddings
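At its core, `ZeroShotClassifier.forward` computes a temperature-scaled cosine similarity between the projected box feature and a fixed bank of CLIP class embeddings, with an all-zero background column appended. A toy sketch of that computation follows; the dimensions are illustrative, and the temperature 50.0 matches the class default `norm_temperature`.

import torch
import torch.nn.functional as F

B, D_in, D_emb, C = 4, 1024, 512, 6
box_feats = torch.randn(B, D_in)                          # pooled per-box features
linear = torch.nn.Linear(D_in, D_emb)                     # corresponds to self.linear

# D x (C + 1) classifier: class embeddings plus an all-zero background column,
# L2-normalized per column (the zero column stays zero thanks to the eps in normalize).
zs_weight = torch.cat([torch.randn(D_emb, C), torch.zeros(D_emb, 1)], dim=1)
zs_weight = F.normalize(zs_weight, p=2, dim=0)

region_emb = F.normalize(linear(box_feats), p=2, dim=1)   # region embedding on the unit sphere
logits = 50.0 * region_emb @ zs_weight                    # temperature-scaled cosine similarity
print(logits.shape)                                       # torch.Size([4, 7])

When `use_regional_embedding` is enabled, the normalized region embedding is returned alongside the logits so the ROI heads can feed it into the multi-modal mixup loss.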
proxydet/modeling/text/text_encoder.py ADDED
@@ -0,0 +1,189 @@
1
+ # This code is modified from https://github.com/openai/CLIP/blob/main/clip/clip.py
2
+ # Modified by Xingyi Zhou
3
+ # The original code is under MIT license
4
+ # Copyright (c) Facebook, Inc. and its affiliates.
5
+ from typing import Union, List
6
+ from collections import OrderedDict
7
+ import torch
8
+ from torch import nn
9
+ import torch
10
+
11
+ from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
12
+
13
+ __all__ = ["tokenize"]
14
+
15
+ count = 0
16
+
17
+ class LayerNorm(nn.LayerNorm):
18
+ """Subclass torch's LayerNorm to handle fp16."""
19
+
20
+ def forward(self, x: torch.Tensor):
21
+ orig_type = x.dtype
22
+ ret = super().forward(x.type(torch.float32))
23
+ return ret.type(orig_type)
24
+
25
+
26
+ class QuickGELU(nn.Module):
27
+ def forward(self, x: torch.Tensor):
28
+ return x * torch.sigmoid(1.702 * x)
29
+
30
+
31
+ class ResidualAttentionBlock(nn.Module):
32
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
33
+ super().__init__()
34
+
35
+ self.attn = nn.MultiheadAttention(d_model, n_head)
36
+ self.ln_1 = LayerNorm(d_model)
37
+ self.mlp = nn.Sequential(OrderedDict([
38
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
39
+ ("gelu", QuickGELU()),
40
+ ("c_proj", nn.Linear(d_model * 4, d_model))
41
+ ]))
42
+ self.ln_2 = LayerNorm(d_model)
43
+ self.attn_mask = attn_mask
44
+
45
+ def attention(self, x: torch.Tensor):
46
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
47
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
48
+
49
+ def forward(self, x: torch.Tensor):
50
+ x = x + self.attention(self.ln_1(x))
51
+ x = x + self.mlp(self.ln_2(x))
52
+ return x
53
+
54
+
55
+ class Transformer(nn.Module):
56
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
57
+ super().__init__()
58
+ self.width = width
59
+ self.layers = layers
60
+ self.resblocks = nn.Sequential(
61
+ *[ResidualAttentionBlock(width, heads, attn_mask) \
62
+ for _ in range(layers)])
63
+
64
+ def forward(self, x: torch.Tensor):
65
+ return self.resblocks(x)
66
+
67
+ class CLIPTEXT(nn.Module):
68
+ def __init__(self,
69
+ embed_dim=512,
70
+ # text
71
+ context_length=77,
72
+ vocab_size=49408,
73
+ transformer_width=512,
74
+ transformer_heads=8,
75
+ transformer_layers=12
76
+ ):
77
+ super().__init__()
78
+
79
+ self._tokenizer = _Tokenizer()
80
+ self.context_length = context_length
81
+
82
+ self.transformer = Transformer(
83
+ width=transformer_width,
84
+ layers=transformer_layers,
85
+ heads=transformer_heads,
86
+ attn_mask=self.build_attention_mask()
87
+ )
88
+
89
+ self.vocab_size = vocab_size
90
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
91
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
92
+ self.ln_final = LayerNorm(transformer_width)
93
+
94
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
95
+ # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
96
+
97
+ self.initialize_parameters()
98
+
99
+ def initialize_parameters(self):
100
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
101
+ nn.init.normal_(self.positional_embedding, std=0.01)
102
+
103
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
104
+ attn_std = self.transformer.width ** -0.5
105
+ fc_std = (2 * self.transformer.width) ** -0.5
106
+ for block in self.transformer.resblocks:
107
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
108
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
109
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
110
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
111
+
112
+ if self.text_projection is not None:
113
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
114
+
115
+ def build_attention_mask(self):
116
+ # lazily create causal attention mask, with full attention between the vision tokens
117
+ # pytorch uses additive attention mask; fill with -inf
118
+ mask = torch.empty(self.context_length, self.context_length)
119
+ mask.fill_(float("-inf"))
120
+ mask.triu_(1) # zero out the lower diagonal
121
+ return mask
122
+
123
+ @property
124
+ def device(self):
125
+ return self.text_projection.device
126
+
127
+ @property
128
+ def dtype(self):
129
+ return self.text_projection.dtype
130
+
131
+ def tokenize(self,
132
+ texts: Union[str, List[str]], \
133
+ context_length: int = 77) -> torch.LongTensor:
134
+ """
135
+ Tokenize a string or a list of strings into CLIP token-id tensors,
+ randomly cropping sequences longer than `context_length`.
+ """
136
+ if isinstance(texts, str):
137
+ texts = [texts]
138
+
139
+ sot_token = self._tokenizer.encoder["<|startoftext|>"]
140
+ eot_token = self._tokenizer.encoder["<|endoftext|>"]
141
+ all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
142
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
143
+
144
+ for i, tokens in enumerate(all_tokens):
145
+ if len(tokens) > context_length:
146
+ st = torch.randint(
147
+ len(tokens) - context_length + 1, (1,))[0].item()
148
+ tokens = tokens[st: st + context_length]
149
+ # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
150
+ result[i, :len(tokens)] = torch.tensor(tokens)
151
+
152
+ return result
153
+
154
+ def encode_text(self, text):
155
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
156
+ x = x + self.positional_embedding.type(self.dtype)
157
+ x = x.permute(1, 0, 2) # NLD -> LND
158
+ x = self.transformer(x)
159
+ x = x.permute(1, 0, 2) # LND -> NLD
160
+ x = self.ln_final(x).type(self.dtype)
161
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
162
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
163
+ return x
164
+
165
+ def forward(self, captions):
166
+ '''
167
+ captions: list of strings
168
+ '''
169
+ text = self.tokenize(captions).to(self.device) # B x L x D
170
+ features = self.encode_text(text) # B x D
171
+ return features
172
+
173
+
174
+ def build_text_encoder(pretrain=True):
175
+ text_encoder = CLIPTEXT()
176
+ if pretrain:
177
+ import clip
178
+ pretrained_model, _ = clip.load("ViT-B/32", device='cpu')
179
+ state_dict = pretrained_model.state_dict()
180
+ to_delete_keys = ["logit_scale", "input_resolution", \
181
+ "context_length", "vocab_size"] + \
182
+ [k for k in state_dict.keys() if k.startswith('visual.')]
183
+ for k in to_delete_keys:
184
+ if k in state_dict:
185
+ del state_dict[k]
186
+ print('Loading pretrained CLIP')
187
+ text_encoder.load_state_dict(state_dict)
188
+ # import pdb; pdb.set_trace()
189
+ return text_encoder
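`build_text_encoder` is what lets the demo embed an arbitrary vocabulary into the classifier space. A usage sketch, assuming the `clip` package is installed and following the `'a '` prompt used in `proxydet/predictor.py`; the helper name `build_custom_classifier` is illustrative and not part of the repository.

import torch
# from proxydet.modeling.text.text_encoder import build_text_encoder  # defined above

def build_custom_classifier(vocabulary, text_encoder, prompt="a "):
    # Encode "a <class name>" for each entry and return a D x C weight matrix,
    # the layout expected by reset_cls_test / ZeroShotClassifier.
    with torch.no_grad():
        emb = text_encoder([prompt + name for name in vocabulary])   # C x D text features
    return emb.permute(1, 0).contiguous()

# text_encoder = build_text_encoder(pretrain=True).eval()
# classifier = build_custom_classifier(["pikachu", "beach umbrella"], text_encoder)

Column-wise L2 normalization is not needed here, since `reset_cls_test` re-normalizes the weights when `norm_weight` is enabled.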
proxydet/modeling/utils.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import torch
3
+ import json
4
+ import numpy as np
5
+ from torch.nn import functional as F
6
+
7
+ def load_class_freq(
8
+ path='datasets/metadata/lvis_v1_train_cat_info.json', freq_weight=1.0):
9
+ cat_info = json.load(open(path, 'r'))
10
+ cat_info = torch.tensor(
11
+ [c['image_count'] for c in sorted(cat_info, key=lambda x: x['id'])])
12
+ freq_weight = cat_info.float() ** freq_weight
13
+ return freq_weight
14
+
15
+
16
+ def get_fed_loss_inds(gt_classes, num_sample_cats, C, weight=None):
17
+ appeared = torch.unique(gt_classes) # C'
18
+ prob = appeared.new_ones(C + 1).float()
19
+ prob[-1] = 0
20
+ if len(appeared) < num_sample_cats:
21
+ if weight is not None:
22
+ prob[:C] = weight.float().clone()
23
+ prob[appeared] = 0
24
+ more_appeared = torch.multinomial(
25
+ prob, num_sample_cats - len(appeared),
26
+ replacement=False)
27
+ appeared = torch.cat([appeared, more_appeared])
28
+ return appeared
29
+
30
+
31
+
32
+ def reset_cls_test(model, cls_path, num_classes):
33
+ model.roi_heads.num_classes = num_classes
34
+ if type(cls_path) == str:
35
+ print('Resetting zs_weight', cls_path)
36
+ zs_weight = torch.tensor(
37
+ np.load(cls_path),
38
+ dtype=torch.float32).permute(1, 0).contiguous() # D x C
39
+ else:
40
+ zs_weight = cls_path
41
+ zs_weight = torch.cat(
42
+ [zs_weight, zs_weight.new_zeros((zs_weight.shape[0], 1))],
43
+ dim=1) # D x (C + 1)
44
+ if model.roi_heads.box_predictor[0].cls_score.norm_weight:
45
+ zs_weight = F.normalize(zs_weight, p=2, dim=0)
46
+ zs_weight = zs_weight.to(model.device)
47
+ for k in range(len(model.roi_heads.box_predictor)):
48
+ del model.roi_heads.box_predictor[k].cls_score.zs_weight
49
+ model.roi_heads.box_predictor[k].cls_score.zs_weight = zs_weight
50
+
51
+ if hasattr(model.roi_heads, "box_predictor_cmm"):
52
+ for k in range(len(model.roi_heads.box_predictor_cmm)):
53
+ del model.roi_heads.box_predictor_cmm[k].cls_score.zs_weight
54
+ model.roi_heads.box_predictor_cmm[k].cls_score.zs_weight = zs_weight
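Before installing a new classifier, `reset_cls_test` appends a zero background column, L2-normalizes each column, and writes the result into `cls_score.zs_weight` of every cascade stage (and of the separated CMM branch when present). A toy sketch of that weight preparation, with illustrative sizes:

import torch
import torch.nn.functional as F

D, C = 512, 3
zs_weight = torch.randn(D, C)                                          # e.g. CLIP embeddings for 3 custom classes
zs_weight = torch.cat([zs_weight, zs_weight.new_zeros(D, 1)], dim=1)   # D x (C + 1); last column = background
zs_weight = F.normalize(zs_weight, p=2, dim=0)                         # applied when norm_weight is True
print(zs_weight.shape)                                                 # torch.Size([512, 4])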
proxydet/predictor.py ADDED
@@ -0,0 +1,295 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ '''
3
+ Modifications Copyright (c) 2024-present NAVER Corp, Apache License v2.0
4
+ original source: https://github.com/facebookresearch/Detic/blob/main/detic/predictor.py
5
+ '''
6
+ import atexit
7
+ import bisect
8
+ import numpy as np
9
+ import multiprocessing as mp
10
+ from collections import deque
11
+ import cv2
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ from detectron2.data import MetadataCatalog
16
+ from detectron2.engine.defaults import DefaultPredictor
17
+ from detectron2.utils.video_visualizer import VideoVisualizer
18
+ from detectron2.utils.visualizer import ColorMode, Visualizer
19
+
20
+ from .modeling.utils import reset_cls_test
21
+
22
+ def get_text_encoder():
23
+ from proxydet.modeling.text.text_encoder import build_text_encoder
24
+ text_encoder = build_text_encoder(pretrain=True)
25
+ text_encoder.eval()
26
+ return text_encoder
27
+ def get_clip_embeddings(vocabulary, text_encoder, prompt='a '):
28
+ texts = [prompt + x for x in vocabulary]
29
+ emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
30
+ return emb
31
+
32
+ BUILDIN_CLASSIFIER = {
33
+ 'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy',
34
+ 'objects365': 'datasets/metadata/o365_clip_a+cnamefix.npy',
35
+ 'openimages': 'datasets/metadata/oid_clip_a+cname.npy',
36
+ 'coco': 'datasets/metadata/coco_clip_a+cname.npy',
37
+ }
38
+
39
+ BUILDIN_METADATA_PATH = {
40
+ 'lvis': 'lvis_v1_val',
41
+ 'objects365': 'objects365_v2_val',
42
+ 'openimages': 'oid_val_expanded',
43
+ 'coco': 'coco_2017_val',
44
+ }
45
+
46
+ class VisualizationDemo(object):
47
+ def __init__(self, cfg, args,
48
+ instance_mode=ColorMode.IMAGE, parallel=False):
49
+ """
50
+ Args:
51
+ cfg (CfgNode):
52
+ instance_mode (ColorMode):
53
+ parallel (bool): whether to run the model in different processes from visualization.
54
+ Useful since the visualization logic can be slow.
55
+ """
56
+ if args.vocabulary == 'custom':
57
+ self.text_encoder = get_text_encoder()
58
+ self.metadata = MetadataCatalog.get("__unused")
59
+ self.metadata.thing_classes = args.custom_vocabulary.split(',')
60
+ classifier = get_clip_embeddings(self.metadata.thing_classes, self.text_encoder)
61
+ else:
62
+ self.metadata = MetadataCatalog.get(
63
+ BUILDIN_METADATA_PATH[args.vocabulary])
64
+ classifier = BUILDIN_CLASSIFIER[args.vocabulary]
65
+
66
+ num_classes = len(self.metadata.thing_classes)
67
+ self.cpu_device = torch.device("cpu")
68
+ self.instance_mode = instance_mode
69
+
70
+ self.parallel = parallel
71
+ if parallel:
72
+ num_gpu = torch.cuda.device_count()
73
+ self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
74
+ else:
75
+ self.predictor = DefaultPredictor(cfg)
76
+ reset_cls_test(self.predictor.model, classifier, num_classes)
77
+
78
+ # reset base category mask, based on the similarity
79
+ self.zeroshot_weight_path = args.zeroshot_weight_path
80
+ self.base_cat_threshold = args.base_cat_threshold
81
+ if self.zeroshot_weight_path is not None:
82
+ self.trained_zs_classifier = torch.tensor(
83
+ np.load(self.zeroshot_weight_path),
84
+ dtype=torch.float32
85
+ ).permute(1, 0).contiguous().to(cfg.MODEL.DEVICE) # D x C
86
+ self.trained_zs_classifier = F.normalize(self.trained_zs_classifier, p=2, dim=0)
87
+
88
+ self.base_cat_indices = self.predictor.model.roi_heads.base_cat_mask.nonzero().squeeze(-1)
89
+ self.reset_base_cat_mask()
90
+
91
+ def reset_base_cat_mask(self, new_base_cat_mask=None):
92
+ if new_base_cat_mask is None:
93
+ # L2 normalized
94
+ if hasattr(self.predictor.model.roi_heads.box_predictor, "__getitem__"):
95
+ custom_classifier = self.predictor.model.roi_heads.box_predictor[0].cls_score.zs_weight
96
+ else:
97
+ custom_classifier = self.predictor.model.roi_heads.box_predictor.cls_score.zs_weight
98
+
99
+ base_cat_sim = custom_classifier.T[:-1, :] @ self.trained_zs_classifier[:, self.base_cat_indices]
100
+
101
+ # reset base cat mask
102
+ new_base_cat_mask = torch.cat([base_cat_sim.max(dim=1)[0] > self.base_cat_threshold, base_cat_sim.new_zeros((1))], dim=0) # for bg
103
+ new_base_cat_mask = new_base_cat_mask.bool()
104
+ self.predictor.model.roi_heads.base_cat_mask = new_base_cat_mask
105
+ return new_base_cat_mask
106
+
107
+ def reset_classifier(self, vocabs):
108
+ thing_classes = vocabs.split(',')
109
+ classifier = get_clip_embeddings(thing_classes, self.text_encoder)
110
+ reset_cls_test(self.predictor.model, classifier, len(thing_classes))
111
+
112
+ def run_on_image(self, image):
113
+ """
114
+ Args:
115
+ image (np.ndarray): an image of shape (H, W, C) (in BGR order).
116
+ This is the format used by OpenCV.
117
+
118
+ Returns:
119
+ predictions (dict): the output of the model.
120
+ vis_output (VisImage): the visualized image output.
121
+ """
122
+ vis_output = None
123
+ predictions = self.predictor(image)
124
+ # Convert image from OpenCV BGR format to Matplotlib RGB format.
125
+ image = image[:, :, ::-1]
126
+ visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
127
+ if "panoptic_seg" in predictions:
128
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
129
+ vis_output = visualizer.draw_panoptic_seg_predictions(
130
+ panoptic_seg.to(self.cpu_device), segments_info
131
+ )
132
+ else:
133
+ if "sem_seg" in predictions:
134
+ vis_output = visualizer.draw_sem_seg(
135
+ predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
136
+ )
137
+ if "instances" in predictions:
138
+ instances = predictions["instances"].to(self.cpu_device)
139
+ vis_output = visualizer.draw_instance_predictions(predictions=instances)
140
+
141
+ return predictions, vis_output
142
+
143
+ def _frame_from_video(self, video):
144
+ while video.isOpened():
145
+ success, frame = video.read()
146
+ if success:
147
+ yield frame
148
+ else:
149
+ break
150
+
151
+ def run_on_video(self, video):
152
+ """
153
+ Visualizes predictions on frames of the input video.
154
+
155
+ Args:
156
+ video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
157
+ either a webcam or a video file.
158
+
159
+ Yields:
160
+ ndarray: BGR visualizations of each video frame.
161
+ """
162
+ video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)
163
+
164
+ def process_predictions(frame, predictions):
165
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
166
+ if "panoptic_seg" in predictions:
167
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
168
+ vis_frame = video_visualizer.draw_panoptic_seg_predictions(
169
+ frame, panoptic_seg.to(self.cpu_device), segments_info
170
+ )
171
+ elif "instances" in predictions:
172
+ predictions = predictions["instances"].to(self.cpu_device)
173
+ vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
174
+ elif "sem_seg" in predictions:
175
+ vis_frame = video_visualizer.draw_sem_seg(
176
+ frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
177
+ )
178
+
179
+ # Converts Matplotlib RGB format to OpenCV BGR format
180
+ vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
181
+ return vis_frame
182
+
183
+ frame_gen = self._frame_from_video(video)
184
+ if self.parallel:
185
+ buffer_size = self.predictor.default_buffer_size
186
+
187
+ frame_data = deque()
188
+
189
+ for cnt, frame in enumerate(frame_gen):
190
+ frame_data.append(frame)
191
+ self.predictor.put(frame)
192
+
193
+ if cnt >= buffer_size:
194
+ frame = frame_data.popleft()
195
+ predictions = self.predictor.get()
196
+ yield process_predictions(frame, predictions)
197
+
198
+ while len(frame_data):
199
+ frame = frame_data.popleft()
200
+ predictions = self.predictor.get()
201
+ yield process_predictions(frame, predictions)
202
+ else:
203
+ for frame in frame_gen:
204
+ yield process_predictions(frame, self.predictor(frame))
205
+
206
+
207
+ class AsyncPredictor:
208
+ """
209
+ A predictor that runs the model asynchronously, possibly on >1 GPUs.
210
+ Because rendering the visualization takes considerably amount of time,
211
+ this helps improve throughput a little bit when rendering videos.
212
+ """
213
+
214
+ class _StopToken:
215
+ pass
216
+
217
+ class _PredictWorker(mp.Process):
218
+ def __init__(self, cfg, task_queue, result_queue):
219
+ self.cfg = cfg
220
+ self.task_queue = task_queue
221
+ self.result_queue = result_queue
222
+ super().__init__()
223
+
224
+ def run(self):
225
+ predictor = DefaultPredictor(self.cfg)
226
+
227
+ while True:
228
+ task = self.task_queue.get()
229
+ if isinstance(task, AsyncPredictor._StopToken):
230
+ break
231
+ idx, data = task
232
+ result = predictor(data)
233
+ self.result_queue.put((idx, result))
234
+
235
+ def __init__(self, cfg, num_gpus: int = 1):
236
+ """
237
+ Args:
238
+ cfg (CfgNode):
239
+ num_gpus (int): if 0, will run on CPU
240
+ """
241
+ num_workers = max(num_gpus, 1)
242
+ self.task_queue = mp.Queue(maxsize=num_workers * 3)
243
+ self.result_queue = mp.Queue(maxsize=num_workers * 3)
244
+ self.procs = []
245
+ for gpuid in range(max(num_gpus, 1)):
246
+ cfg = cfg.clone()
247
+ cfg.defrost()
248
+ cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
249
+ self.procs.append(
250
+ AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
251
+ )
252
+
253
+ self.put_idx = 0
254
+ self.get_idx = 0
255
+ self.result_rank = []
256
+ self.result_data = []
257
+
258
+ for p in self.procs:
259
+ p.start()
260
+ atexit.register(self.shutdown)
261
+
262
+ def put(self, image):
263
+ self.put_idx += 1
264
+ self.task_queue.put((self.put_idx, image))
265
+
266
+ def get(self):
267
+ self.get_idx += 1 # the index needed for this request
268
+ if len(self.result_rank) and self.result_rank[0] == self.get_idx:
269
+ res = self.result_data[0]
270
+ del self.result_data[0], self.result_rank[0]
271
+ return res
272
+
273
+ while True:
274
+ # make sure the results are returned in the correct order
275
+ idx, res = self.result_queue.get()
276
+ if idx == self.get_idx:
277
+ return res
278
+ insert = bisect.bisect(self.result_rank, idx)
279
+ self.result_rank.insert(insert, idx)
280
+ self.result_data.insert(insert, res)
281
+
282
+ def __len__(self):
283
+ return self.put_idx - self.get_idx
284
+
285
+ def __call__(self, image):
286
+ self.put(image)
287
+ return self.get()
288
+
289
+ def shutdown(self):
290
+ for _ in self.procs:
291
+ self.task_queue.put(AsyncPredictor._StopToken())
292
+
293
+ @property
294
+ def default_buffer_size(self):
295
+ return len(self.procs) * 5
requirements.txt CHANGED
@@ -7,10 +7,4 @@ mss<=6.1.0
7
  timm<=0.5.4
8
  lvis
9
  nltk<=3.7
10
- #git+https://github.com/facebookresearch/detectron2.git
11
-
12
  numpy>=1.18.5
13
- #torch>=1.7.0
14
- #torchvision>=0.8.1
15
- git+https://github.com/huggingface/transformers.git
16
- opencv-python
 
7
  timm<=0.5.4
8
  lvis
9
  nltk<=3.7
 
 
10
  numpy>=1.18.5
 
 
 
 
third_party/CenterNet2/.github/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
4
+ Please read the [full text](https://code.fb.com/codeofconduct/)
5
+ so that you can understand what actions will and will not be tolerated.
third_party/CenterNet2/.github/CONTRIBUTING.md ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to detectron2
2
+
3
+ ## Issues
4
+ We use GitHub issues to track public bugs and questions.
5
+ Please make sure to follow one of the
6
+ [issue templates](https://github.com/facebookresearch/detectron2/issues/new/choose)
7
+ when reporting any issues.
8
+
9
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
10
+ disclosure of security bugs. In those cases, please go through the process
11
+ outlined on that page and do not file a public issue.
12
+
13
+ ## Pull Requests
14
+ We actively welcome pull requests.
15
+
16
+ However, if you're adding any significant features (e.g. > 50 lines), please
17
+ make sure to discuss with maintainers about your motivation and proposals in an issue
18
+ before sending a PR. This is to save your time so you don't spend time on a PR that we'll not accept.
19
+
20
+ We do not always accept new features, and we take the following
21
+ factors into consideration:
22
+
23
+ 1. Whether the same feature can be achieved without modifying detectron2.
24
+ Detectron2 is designed so that you can implement many extensions from the outside, e.g.
25
+ those in [projects](https://github.com/facebookresearch/detectron2/tree/master/projects).
26
+ * If some part of detectron2 is not extensible enough, you can also bring up a more general issue to
27
+ improve it. Such feature request may be useful to more users.
28
+ 2. Whether the feature is potentially useful to a large audience (e.g. an impactful detection paper, a popular dataset,
29
+ a significant speedup, a widely useful utility),
30
+ or only to a small portion of users (e.g., a less-known paper, an improvement not in the object
31
+ detection field, a trick that's not very popular in the community, code to handle a non-standard type of data)
32
+ * Adoption of additional models, datasets, new task are by default not added to detectron2 before they
33
+ receive significant popularity in the community.
34
+ We sometimes accept such features in `projects/`, or as a link in `projects/README.md`.
35
+ 3. Whether the proposed solution has a good design / interface. This can be discussed in the issue prior to PRs, or
36
+ in the form of a draft PR.
37
+ 4. Whether the proposed solution adds extra mental/practical overhead to users who don't
38
+ need such feature.
39
+ 5. Whether the proposed solution breaks existing APIs.
40
+
41
+ To add a feature to an existing function/class `Func`, there are always two approaches:
42
+ (1) add new arguments to `Func`; (2) write a new `Func_with_new_feature`.
43
+ To meet the above criteria, we often prefer approach (2), because:
44
+
45
+ 1. It does not involve modifying or potentially breaking existing code.
46
+ 2. It does not add overhead to users who do not need the new feature.
47
+ 3. Adding new arguments to a function/class is not scalable w.r.t. all the possible new research ideas in the future.
48
+
49
+ When sending a PR, please do:
50
+
51
+ 1. If a PR contains multiple orthogonal changes, split it to several PRs.
52
+ 2. If you've added code that should be tested, add tests.
53
+ 3. For PRs that need experiments (e.g. adding a new model or new methods),
54
+ you don't need to update model zoo, but do provide experiment results in the description of the PR.
55
+ 4. If APIs are changed, update the documentation.
56
+ 5. We use the [Google style docstrings](https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html) in python.
57
+ 6. Make sure your code lints with `./dev/linter.sh`.
58
+
59
+
60
+ ## Contributor License Agreement ("CLA")
61
+ In order to accept your pull request, we need you to submit a CLA. You only need
62
+ to do this once to work on any of Facebook's open source projects.
63
+
64
+ Complete your CLA here: <https://code.facebook.com/cla>
65
+
66
+ ## License
67
+ By contributing to detectron2, you agree that your contributions will be licensed
68
+ under the LICENSE file in the root directory of this source tree.
third_party/CenterNet2/.github/Detectron2-Logo-Horz.svg ADDED
third_party/CenterNet2/.github/ISSUE_TEMPLATE.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ Please select an issue template from
3
+ https://github.com/facebookresearch/detectron2/issues/new/choose .
4
+
5
+ Otherwise your issue will be closed.
third_party/CenterNet2/.github/ISSUE_TEMPLATE/bugs.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "πŸ› Bugs"
3
+ about: Report bugs in detectron2
4
+ title: Please read & provide the following
5
+
6
+ ---
7
+
8
+ ## Instructions To Reproduce the πŸ› Bug:
9
+ 1. Full runnable code or full changes you made:
10
+ ```
11
+ If making changes to the project itself, please use output of the following command:
12
+ git rev-parse HEAD; git diff
13
+
14
+ <put code or diff here>
15
+ ```
16
+ 2. What exact command you run:
17
+ 3. __Full logs__ or other relevant observations:
18
+ ```
19
+ <put logs here>
20
+ ```
21
+ 4. please simplify the steps as much as possible so they do not require additional resources to
22
+ run, such as a private dataset.
23
+
24
+ ## Expected behavior:
25
+
26
+ If there are no obvious error in "full logs" provided above,
27
+ please tell us the expected behavior.
28
+
29
+ ## Environment:
30
+
31
+ Provide your environment information using the following command:
32
+ ```
33
+ wget -nc -q https://github.com/facebookresearch/detectron2/raw/main/detectron2/utils/collect_env.py && python collect_env.py
34
+ ```
35
+
36
+ If your issue looks like an installation issue / environment issue,
37
+ please first try to solve it yourself with the instructions in
38
+ https://detectron2.readthedocs.io/tutorials/install.html#common-installation-issues
third_party/CenterNet2/.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # require an issue template to be chosen
2
+ blank_issues_enabled: false
3
+
4
+ contact_links:
5
+ - name: How-To / All Other Questions
6
+ url: https://github.com/facebookresearch/detectron2/discussions
7
+ about: Use "github discussions" for community support on general questions that don't belong to the above issue categories
8
+ - name: Detectron2 Documentation
9
+ url: https://detectron2.readthedocs.io/index.html
10
+ about: Check if your question is answered in tutorials or API docs
11
+
12
+ # Unexpected behaviors & bugs are split to two templates.
13
+ # When they are one template, users think "it's not a bug" and don't choose the template.
14
+ #
15
+ # But the file name is still "unexpected-problems-bugs.md" so that old references
16
+ # to this issue template still works.
17
+ # It's ok since this template should be a superset of "bugs.md" (unexpected behaviors is a superset of bugs)
third_party/CenterNet2/.github/ISSUE_TEMPLATE/documentation.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "\U0001F4DA Documentation Issue"
3
+ about: Report a problem about existing documentation, comments, website or tutorials.
4
+ labels: documentation
5
+
6
+ ---
7
+
8
+ ## πŸ“š Documentation Issue
9
+
10
+ This issue category is for problems about existing documentation, not for asking how-to questions.
11
+
12
+ * Provide a link to an existing documentation/comment/tutorial:
13
+
14
+ * How should the above documentation/comment/tutorial improve: