Spaces · Runtime error
Commit · 8075387
1 parent: 5f752ee
open proxydet demo

This view is limited to 50 files because it contains too many changes. See raw diff.
- app.py +56 -39
- assets/beach.jpg +0 -0
- assets/desk.jpg +0 -0
- assets/pikachu.jpg +0 -0
- configs/Base-C2_L_R5021k_640b64_4x.yaml +83 -0
- configs/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.yaml +7 -0
- configs/ProxyDet_R50_Lbase_INL.yaml +59 -0
- configs/ProxyDet_SwinB_Lbase_INL.yaml +51 -0
- datasets/metadata/__init__.py +0 -0
- demo.py +245 -0
- packages.txt +3 -0
- proxydet/__init__.py +18 -0
- proxydet/cat_names.py +1 -0
- proxydet/config.py +156 -0
- proxydet/custom_solver.py +78 -0
- proxydet/data/custom_build_augmentation.py +51 -0
- proxydet/data/custom_dataset_dataloader.py +331 -0
- proxydet/data/custom_dataset_mapper.py +280 -0
- proxydet/data/datasets/cc.py +23 -0
- proxydet/data/datasets/coco_zeroshot.py +121 -0
- proxydet/data/datasets/imagenet.py +41 -0
- proxydet/data/datasets/lvis_22k_categories.py +0 -0
- proxydet/data/datasets/lvis_v1.py +155 -0
- proxydet/data/datasets/objects365.py +770 -0
- proxydet/data/datasets/oid.py +535 -0
- proxydet/data/datasets/register_oid.py +122 -0
- proxydet/data/tar_dataset.py +138 -0
- proxydet/data/transforms/custom_augmentation_impl.py +60 -0
- proxydet/data/transforms/custom_transform.py +114 -0
- proxydet/evaluation/custom_coco_eval.py +124 -0
- proxydet/evaluation/oideval.py +699 -0
- proxydet/modeling/backbone/swintransformer.py +750 -0
- proxydet/modeling/backbone/timm.py +221 -0
- proxydet/modeling/debug.py +334 -0
- proxydet/modeling/meta_arch/custom_rcnn.py +232 -0
- proxydet/modeling/meta_arch/d2_deformable_detr.py +308 -0
- proxydet/modeling/roi_heads/proxydet_fast_rcnn.py +618 -0
- proxydet/modeling/roi_heads/proxydet_roi_heads.py +556 -0
- proxydet/modeling/roi_heads/zero_shot_classifier.py +111 -0
- proxydet/modeling/text/text_encoder.py +189 -0
- proxydet/modeling/utils.py +54 -0
- proxydet/predictor.py +295 -0
- requirements.txt +0 -6
- third_party/CenterNet2/.github/CODE_OF_CONDUCT.md +5 -0
- third_party/CenterNet2/.github/CONTRIBUTING.md +68 -0
- third_party/CenterNet2/.github/Detectron2-Logo-Horz.svg +1 -0
- third_party/CenterNet2/.github/ISSUE_TEMPLATE.md +5 -0
- third_party/CenterNet2/.github/ISSUE_TEMPLATE/bugs.md +38 -0
- third_party/CenterNet2/.github/ISSUE_TEMPLATE/config.yml +17 -0
- third_party/CenterNet2/.github/ISSUE_TEMPLATE/documentation.md +14 -0
app.py
CHANGED

@@ -3,10 +3,14 @@ import cv2
 import os
 import gradio as gr
 import numpy as np
-from
-
-
+from argparse import Namespace
+try:
+    import detectron2
+except:
 os.system("python3 -m pip install 'git+https://github.com/facebookresearch/detectron2.git'")
+    import detectron2
+from demo import setup_cfg
+from proxydet.predictor import VisualizationDemo
 
 # Use GPU if available
 if torch.cuda.is_available():

@@ -14,26 +18,42 @@ if torch.cuda.is_available():
 else:
     device = torch.device("cpu")
 
-
-
-
-
-
-def query_image(img, text_queries, score_threshold):
-    text_queries = text_queries
-    text_queries = text_queries.split(",")
+# download metadata
+zs_weight_path = 'datasets/metadata/lvis_v1_clip_a+cname.npy'
+if not os.path.exists(zs_weight_path):
+    wget.download("https://github.com/facebookresearch/Detic/raw/main/datasets/metadata/lvis_v1_clip_a+cname.npy", out=zs_weight_path)
 
-
-
+args = Namespace(
+    base_cat_threshold=0.9,
+    confidence_threshold=0.0,
+    config_file='configs/ProxyDet_SwinB_Lbase_INL.yaml',
+    cpu=not torch.cuda.is_available(),
+    custom_vocabulary='headphone,webcam,paper,coffe',
+    input=['.assets/desk.jpg'],
+    opts=['MODEL.WEIGHTS', 'models/proxydet_swinb_w_inl.pth'],
+    output='out.jpg',
+    pred_all_class=False,
+    video_input=None,
+    vocabulary='custom',
+    webcam=None,
+    zeroshot_weight_path='datasets/metadata/lvis_v1_clip_a+cname.npy'
+)
+cfg = setup_cfg(args)
+ovd_demo = VisualizationDemo(cfg, args)
 
+def query_image(img, text_queries, score_threshold, base_alpha, novel_beta):
+    text_queries_split = text_queries.split(",")
+    ovd_demo.reset_classifier(text_queries)
+    ovd_demo.reset_base_cat_mask()
+    ovd_demo.predictor.model.roi_heads.cmm_base_alpha = base_alpha
+    ovd_demo.predictor.model.roi_heads.cmm_novel_beta = novel_beta
+    img_bgr = img[:, :, ::-1]
     with torch.no_grad():
-
-
-
-
-
-        boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
-
+        predictions, visualized_output = ovd_demo.run_on_image(img_bgr)
+        output_instances = predictions["instances"].to(device)
+        boxes = output_instances.pred_boxes.tensor
+        scores = output_instances.scores
+        labels = output_instances.pred_classes.tolist()
     font = cv2.FONT_HERSHEY_SIMPLEX
 
     for box, score, label in zip(boxes, scores, labels):

@@ -47,35 +67,32 @@ def query_image(img, text_queries, score_threshold):
         y = box[3] + 25
 
         img = cv2.putText(
-            img,
+            img, text_queries_split[label], (box[0], y), font, 1, (255,0,0), 2, cv2.LINE_AA
         )
     return img
 
 if __name__ == "__main__":
-    setup()
-
     description = """
-
-
-
-
-    To use it, simply upload an image and enter comma separated text descriptions of objects you want to query the image for. You
-    can also use the score threshold slider to set a threshold to filter out low probability predictions.
-    \n\nOWL-ViT is trained on text templates,
-    hence you can get better predictions by querying the image with text templates used in training the original model: *"photo of a star-spangled banner"*,
-    *"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data.
-    \n\n<a href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/zeroshot_object_detection_with_owlvit.ipynb">Colab demo</a>
+    Gradio demo for ProxyDet, introduced in <a href="https://arxiv.org/abs/2312.07266">ProxyDet: Synthesizing Proxy Novel Classes via Classwise Mixup for Open-Vocabulary Object Detection</a>.
+    \n\nYou can use ProxyDet to query images with text descriptions of any object.
+    Simply upload an image and enter comma separated objects (e.g., "dog,cat,headphone") which you want to detect within the image.
+    You can also use the score threshold slider to set a threshold to filter out low probability predictions.
     """
     demo = gr.Interface(
         query_image,
-        inputs=[gr.Image(), "text", gr.Slider(0, 1, value=0.1)],
+        inputs=[gr.Image(), "text", gr.Slider(0, 1, value=0.1), gr.Slider(0, 1, value=0.15), gr.Slider(0, 1, value=0.35)],
         outputs="image",
-        title="
+        title="Open-Vocabulary Object Detection with ProxyDet",
         description=description,
         examples=[
-            ["assets/
-            ["assets/
-            ["assets/
+            ["assets/desk.jpg", "headphone,webcam,paper,coffee", 0.11, 0.15, 0.35],
+            ["assets/beach.jpg", "person,kite", 0.1, 0.15, 0.35],
+            ["assets/pikachu.jpg", "pikachu,person", 0.15, 0.15, 0.35],
        ],
    )
-    demo.launch(
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=int(os.environ["NSML_PORT1"]),
+        share=False,
+        debug=True,
+    )
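The rewritten query_image now routes every request through the ProxyDet VisualizationDemo built at module import time, and the two new sliders (base_alpha, novel_beta) are written straight into the ROI heads before inference. A minimal sketch of calling it outside the Gradio UI is shown below; it assumes the Space's environment is already prepared (importing app triggers the detectron2 install check, the metadata download, and model construction), and out_vis.jpg is only an illustrative output path.

import cv2
from app import query_image  # heavy import: builds cfg and the VisualizationDemo at module level

img = cv2.imread("assets/desk.jpg")[:, :, ::-1].copy()   # BGR -> RGB, as gr.Image would pass it
out = query_image(img, "headphone,webcam,paper,coffee",
                  score_threshold=0.11, base_alpha=0.15, novel_beta=0.35)
cv2.imwrite("out_vis.jpg", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))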
assets/beach.jpg
ADDED
assets/desk.jpg
ADDED
assets/pikachu.jpg
ADDED
configs/Base-C2_L_R5021k_640b64_4x.yaml
ADDED
@@ -0,0 +1,83 @@
# Code Copied from https://github.com/facebookresearch/Detic/blob/main/configs/Base-C2_L_R5021k_640b64_4x.yaml
MODEL:
  META_ARCHITECTURE: "CustomRCNN"
  MASK_ON: True
  PROPOSAL_GENERATOR:
    NAME: "CenterNet"
  WEIGHTS: "models/resnet50_miil_21k.pkl"
  BACKBONE:
    NAME: build_p67_timm_fpn_backbone
  TIMM:
    BASE_NAME: resnet50_in21k
  FPN:
    IN_FEATURES: ["layer3", "layer4", "layer5"]
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.12, 57.375]
  ROI_HEADS:
    NAME: ProxydetCascadeROIHeads
    IN_FEATURES: ["p3", "p4", "p5"]
    IOU_THRESHOLDS: [0.6]
    NUM_CLASSES: 1203
    SCORE_THRESH_TEST: 0.02
    NMS_THRESH_TEST: 0.5
  ROI_BOX_CASCADE_HEAD:
    IOUS: [0.6, 0.7, 0.8]
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
    CLS_AGNOSTIC_BBOX_REG: True
    MULT_PROPOSAL_SCORE: True

    USE_SIGMOID_CE: True
    USE_FED_LOSS: True
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
    CLS_AGNOSTIC_MASK: True
  CENTERNET:
    NUM_CLASSES: 1203
    REG_WEIGHT: 1.
    NOT_NORM_REG: True
    ONLY_PROPOSAL: True
    WITH_AGN_HM: True
    INFERENCE_TH: 0.0001
    PRE_NMS_TOPK_TRAIN: 4000
    POST_NMS_TOPK_TRAIN: 2000
    PRE_NMS_TOPK_TEST: 1000
    POST_NMS_TOPK_TEST: 256
    NMS_TH_TRAIN: 0.9
    NMS_TH_TEST: 0.9
    POS_WEIGHT: 0.5
    NEG_WEIGHT: 0.5
    IGNORE_HIGH_FP: 0.85
DATASETS:
  TRAIN: ("lvis_v1_train",)
  TEST: ("lvis_v1_val",)
DATALOADER:
  SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
  REPEAT_THRESHOLD: 0.001
  NUM_WORKERS: 8
TEST:
  DETECTIONS_PER_IMAGE: 300
SOLVER:
  LR_SCHEDULER_NAME: "WarmupCosineLR"
  CHECKPOINT_PERIOD: 1000000000
  WARMUP_ITERS: 10000
  WARMUP_FACTOR: 0.0001
  USE_CUSTOM_SOLVER: True
  OPTIMIZER: "ADAMW"
  MAX_ITER: 90000
  IMS_PER_BATCH: 64
  BASE_LR: 0.0002
  CLIP_GRADIENTS:
    ENABLED: True
INPUT:
  FORMAT: RGB
  CUSTOM_AUG: EfficientDetResizeCrop
  TRAIN_SIZE: 640
OUTPUT_DIR: "./output/ProxyDet/auto"
EVAL_PROPOSAL_AR: False
VERSION: 2
FP16: True
configs/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.yaml
ADDED
@@ -0,0 +1,7 @@
# Code Copied from https://github.com/facebookresearch/Detic/blob/main/configs/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.yaml
_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml"
MODEL:
  ROI_BOX_HEAD:
    USE_ZEROSHOT_CLS: True
DATASETS:
  TRAIN: ("lvis_v1_train_norare",)
configs/ProxyDet_R50_Lbase_INL.yaml
ADDED
@@ -0,0 +1,59 @@
# Code Adapted from https://github.com/facebookresearch/Detic/blob/main/configs/Detic_LbaseI_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml
_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml"
MODEL:
  ROI_BOX_HEAD:
    USE_ZEROSHOT_CLS: True
    IMAGE_LABEL_LOSS: 'max_size'
    USE_REGIONAL_EMBEDDING: True
  ROI_HEADS:
    BASE_CAT_MASK: "datasets/metadata/lvis_v1_base_cat_mask.npy"
    CMM:
      MIXUP_STAGE: [2]
      MIXUP_STAGE_TEST: [2]
      MIXUP_BETA: 1.0
      LOSS: "l1"
      LOSS_WEIGHT: 256.0
      SEPARATED_BRANCH: True
      BASE_ALPHA: 0.15
      NOVEL_BETA: 0.35
      USE_INL: False
      PROTOTYPE: "obj_score"
      PROTOTYPE_TEMP: 1.0
      CLASSIFIER_TEMP: 1.0
      USE_SIGMOID_CE: True
  WEIGHTS: "models/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.pth"
SOLVER:
  MAX_ITER: 90000
  IMS_PER_BATCH: 64
  BASE_LR: 0.0002
  WARMUP_ITERS: 1000
  WARMUP_FACTOR: 0.001
DATASETS:
  TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1")
DATALOADER:
  SAMPLER_TRAIN: "MultiDatasetSampler"
  DATASET_RATIO: [1, 4]
  USE_DIFF_BS_SIZE: True
  DATASET_BS: [8, 32]
  DATASET_INPUT_SIZE: [640, 320]
  USE_RFS: [True, False]
  DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]]
  FILTER_EMPTY_ANNOTATIONS: False
  MULTI_DATASET_GROUPING: True
  DATASET_ANN: ['box', 'image']
  NUM_WORKERS: 8
WITH_IMAGE_LABELS: True
configs/ProxyDet_SwinB_Lbase_INL.yaml
ADDED
@@ -0,0 +1,51 @@
# Code Adapted from https://github.com/facebookresearch/Detic/blob/main/configs/Detic_LbaseI_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml
_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml"
MODEL:
  ROI_BOX_HEAD:
    USE_ZEROSHOT_CLS: True
    IMAGE_LABEL_LOSS: 'max_size'
    USE_REGIONAL_EMBEDDING: True
  ROI_HEADS:
    BASE_CAT_MASK: "datasets/metadata/lvis_v1_base_cat_mask.npy"
    CMM:
      MIXUP_STAGE: [2]
      MIXUP_STAGE_TEST: [2]
      MIXUP_BETA: 1.0
      LOSS: "l1"
      LOSS_WEIGHT: 256.0
      SEPARATED_BRANCH: True
      BASE_ALPHA: 0.15
      NOVEL_BETA: 0.35
      USE_INL: False
      PROTOTYPE: "obj_score"
      PROTOTYPE_TEMP: 1.0
      CLASSIFIER_TEMP: 1.0
      USE_SIGMOID_CE: True
  BACKBONE:
    NAME: build_swintransformer_fpn_backbone
  SWIN:
    SIZE: B-22k
  FPN:
    IN_FEATURES: ["swin1", "swin2", "swin3"]
  WEIGHTS: "models/BoxSup-C2_Lbase_CLIP_SwinB_896b32_4x.pth"
SOLVER:
  MAX_ITER: 180000
  IMS_PER_BATCH: 32
  BASE_LR: 0.0001
  WARMUP_ITERS: 1000
  WARMUP_FACTOR: 0.001
DATASETS:
  TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1")
DATALOADER:
  SAMPLER_TRAIN: "MultiDatasetSampler"
  DATASET_RATIO: [1, 4]
  USE_DIFF_BS_SIZE: True
  DATASET_BS: [4, 16]
  DATASET_INPUT_SIZE: [896, 448]
  USE_RFS: [True, False]
  DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]]
  FILTER_EMPTY_ANNOTATIONS: False
  MULTI_DATASET_GROUPING: True
  DATASET_ANN: ['box', 'image']
  NUM_WORKERS: 8
WITH_IMAGE_LABELS: True
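All four YAMLs are plain detectron2 configs composed through the _BASE_ key, so ProxyDet_SwinB_Lbase_INL.yaml only overrides Base-C2_L_R5021k_640b64_4x.yaml. A minimal sketch of materializing one of them, mirroring setup_cfg() in demo.py further down; it assumes third_party/CenterNet2 is importable:

import sys
sys.path.insert(0, 'third_party/CenterNet2/')

from detectron2.config import get_cfg
from centernet.config import add_centernet_config
from proxydet.config import add_proxydet_config

cfg = get_cfg()
add_centernet_config(cfg)   # registers the MODEL.CENTERNET.* keys used above
add_proxydet_config(cfg)    # registers MODEL.ROI_HEADS.CMM.*, WITH_IMAGE_LABELS, FP16, ...
cfg.merge_from_file('configs/ProxyDet_SwinB_Lbase_INL.yaml')   # pulls in the _BASE_ file first
print(cfg.MODEL.ROI_HEADS.CMM.BASE_ALPHA, cfg.MODEL.ROI_HEADS.CMM.NOVEL_BETA)  # 0.15 0.35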
datasets/metadata/__init__.py
ADDED
File without changes
demo.py
ADDED
@@ -0,0 +1,245 @@
# Copyright (c) Facebook, Inc. and its affiliates.
'''
Modifications Copyright (c) 2024-present NAVER Corp, Apache License v2.0
original source: https://github.com/facebookresearch/Detic/blob/main/demo.py
'''
import argparse
import glob
import multiprocessing as mp
import numpy as np
import os
import tempfile
import time
import warnings
import cv2
import tqdm
import sys
import mss

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger
from detectron2.engine.defaults import _highlight

sys.path.insert(0, 'third_party/CenterNet2/')
from centernet.config import add_centernet_config
from proxydet.config import add_proxydet_config

from proxydet.predictor import VisualizationDemo

# Fake a video capture object OpenCV style - half width, half height of first screen using MSS
class ScreenGrab:
    def __init__(self):
        self.sct = mss.mss()
        m0 = self.sct.monitors[0]
        self.monitor = {'top': 0, 'left': 0, 'width': m0['width'] / 2, 'height': m0['height'] / 2}

    def read(self):
        img = np.array(self.sct.grab(self.monitor))
        nf = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        return (True, nf)

    def isOpened(self):
        return True
    def release(self):
        return True


# constants
WINDOW_NAME = "ProxyDet"

def setup_cfg(args):
    cfg = get_cfg()
    if args.cpu:
        cfg.MODEL.DEVICE="cpu"
    add_centernet_config(cfg)
    add_proxydet_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # Set score_threshold for builtin models
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
    cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'rand' # load later
    if not args.pred_all_class:
        cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True
    cfg.freeze()
    return cfg


def get_parser():
    parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    parser.add_argument(
        "--config-file",
        default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--webcam", help="Take inputs from webcam.")
    parser.add_argument("--cpu", action='store_true', help="Use CPU only.")
    parser.add_argument("--video-input", help="Path to video file.")
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        help="A file or directory to save output visualizations. "
        "If not given, will show output in an OpenCV window.",
    )
    parser.add_argument(
        "--vocabulary",
        default="lvis",
        choices=['lvis', 'openimages', 'objects365', 'coco', 'custom'],
        help="",
    )
    parser.add_argument(
        "--custom_vocabulary",
        default="",
        help="",
    )
    parser.add_argument(
        "--zeroshot_weight_path",
        default=None,
        help="zeroshot text embedding path used during training",
    )
    parser.add_argument("--pred_all_class", action='store_true')
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.5,
        help="Minimum score for instance predictions to be shown",
    )
    parser.add_argument(
        "--base-cat-threshold",
        type=float,
        default=0.9,
        help="Minimum score for similarity with trained base categories",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser


def test_opencv_video_format(codec, file_ext):
    with tempfile.TemporaryDirectory(prefix="video_format_test") as dir:
        filename = os.path.join(dir, "test_file" + file_ext)
        writer = cv2.VideoWriter(
            filename=filename,
            fourcc=cv2.VideoWriter_fourcc(*codec),
            fps=float(30),
            frameSize=(10, 10),
            isColor=True,
        )
        [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)]
        writer.release()
        if os.path.isfile(filename):
            return True
        return False


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    args = get_parser().parse_args()
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)
    print(_highlight(cfg.dump(), ".yaml"))

    demo = VisualizationDemo(cfg, args)

    if args.input:
        if len(args.input) == 1:
            args.input = glob.glob(os.path.expanduser(args.input[0]))
            assert args.input, "The input path(s) was not found"
        for path in tqdm.tqdm(args.input, disable=not args.output):
            img = read_image(path, format="BGR")
            start_time = time.time()
            predictions, visualized_output = demo.run_on_image(img)
            logger.info(
                "{}: {} in {:.2f}s".format(
                    path,
                    "detected {} instances".format(len(predictions["instances"]))
                    if "instances" in predictions
                    else "finished",
                    time.time() - start_time,
                )
            )

            if args.output:
                if os.path.isdir(args.output):
                    assert os.path.isdir(args.output), args.output
                    out_filename = os.path.join(args.output, os.path.basename(path))
                else:
                    assert len(args.input) == 1, "Please specify a directory with args.output"
                    out_filename = args.output
                visualized_output.save(out_filename)
            else:
                cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
                cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
                if cv2.waitKey(0) == 27:
                    break  # esc to quit
    elif args.webcam:
        assert args.input is None, "Cannot have both --input and --webcam!"
        assert args.output is None, "output not yet supported with --webcam!"
        if args.webcam == "screen":
            cam = ScreenGrab()
        else:
            cam = cv2.VideoCapture(int(args.webcam))
        for vis in tqdm.tqdm(demo.run_on_video(cam)):
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
            cv2.imshow(WINDOW_NAME, vis)
            if cv2.waitKey(1) == 27:
                break  # esc to quit
        cam.release()
        cv2.destroyAllWindows()
    elif args.video_input:
        video = cv2.VideoCapture(args.video_input)
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames_per_second = video.get(cv2.CAP_PROP_FPS)
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        basename = os.path.basename(args.video_input)
        codec, file_ext = (
            ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
        )
        if codec == ".mp4v":
            warnings.warn("x264 codec not available, switching to mp4v")
        if args.output:
            if os.path.isdir(args.output):
                output_fname = os.path.join(args.output, basename)
                output_fname = os.path.splitext(output_fname)[0] + file_ext
            else:
                output_fname = args.output
            assert not os.path.isfile(output_fname), output_fname
            output_file = cv2.VideoWriter(
                filename=output_fname,
                # some installation of opencv may not support x264 (due to its license),
                # you can try other format (e.g. MPEG)
                fourcc=cv2.VideoWriter_fourcc(*codec),
                fps=float(frames_per_second),
                frameSize=(width, height),
                isColor=True,
            )
        assert os.path.isfile(args.video_input)
        for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
            if args.output:
                output_file.write(vis_frame)
            else:
                cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
                cv2.imshow(basename, vis_frame)
                if cv2.waitKey(1) == 27:
                    break  # esc to quit
        video.release()
        if args.output:
            output_file.release()
        else:
            cv2.destroyAllWindows()
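demo.py is the same CLI harness as Detic's demo, extended with ProxyDet's config hooks and the extra --base-cat-threshold / --zeroshot_weight_path flags. A hedged sketch of driving it programmatically with the same argument values app.py hard-codes; the weight file under models/ is assumed to exist, and "--cpu" should be appended when no GPU is available:

from detectron2.data.detection_utils import read_image
from demo import get_parser, setup_cfg
from proxydet.predictor import VisualizationDemo

args = get_parser().parse_args([
    "--config-file", "configs/ProxyDet_SwinB_Lbase_INL.yaml",
    "--vocabulary", "custom",
    "--custom_vocabulary", "headphone,webcam,paper,coffee",
    "--zeroshot_weight_path", "datasets/metadata/lvis_v1_clip_a+cname.npy",
    "--confidence-threshold", "0.1",
    "--input", "assets/desk.jpg",
    "--output", "out.jpg",
    "--opts", "MODEL.WEIGHTS", "models/proxydet_swinb_w_inl.pth",
])
cfg = setup_cfg(args)
demo = VisualizationDemo(cfg, args)
predictions, visualized_output = demo.run_on_image(read_image("assets/desk.jpg", format="BGR"))
visualized_output.save("out.jpg")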
packages.txt
ADDED
@@ -0,0 +1,3 @@
ffmpeg
libsm6
libxext6
proxydet/__init__.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from .modeling.meta_arch import custom_rcnn
from .modeling.roi_heads import proxydet_roi_heads
from .modeling.backbone import swintransformer
from .modeling.backbone import timm


from .data.datasets import lvis_v1
from .data.datasets import imagenet
from .data.datasets import cc
from .data.datasets import objects365
from .data.datasets import oid
from .data.datasets import coco_zeroshot

try:
    from .modeling.meta_arch import d2_deformable_detr
except:
    pass
proxydet/cat_names.py
ADDED
@@ -0,0 +1 @@
lvis_cat_names = ['aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', 'antenna', 'apple', 'applesauce', 'apricot', 'apron', 'aquarium', 'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor', 'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy', 'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap', 'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath', 'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card', 'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket', 'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry', 'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase', 'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle', 'bottle_opener', 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie', 'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'box', 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'bread', 'breechcloth', 'bridal_gown', 'briefcase', 'broccoli', 'broach', 'broom', 'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull', 'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board', 'bulletproof_vest', 'bullhorn', 'bun', 'bunk_bed', 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf', 'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)', 'can', 'can_opener', 'candle', 'candle_holder', 'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast', 'cat', 'cauliflower', 'cayenne_(spice)', 'CD_player', 'celery', 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue', 'chalice', 'chandelier', 'chap', 'checkbook', 'checkerboard', 'cherry', 'chessboard', 'chicken_(animal)', 'chickpea', 'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker', 'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent', 'cleat_(for_securing_rope)', 'clementine', 'clip', 'clipboard', 'clippers_(for_plants)', 'cloak', 'clock', 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat', 'coat_hanger', 
'coatrack', 'cock', 'cockroach', 'cocoa_(beverage)', 'coconut', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil', 'coin', 'colander', 'coleslaw', 'coloring_material', 'combination_lock', 'pacifier', 'comic_book', 'compass', 'computer_keyboard', 'condiment', 'cone', 'control', 'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie', 'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)', 'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall', 'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker', 'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib', 'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown', 'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain', 'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard', 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk', 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher', 'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup', 'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin', 'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly', 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit', 'dresser', 'drill', 'drone', 'dropper', 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling', 'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'eagle', 'earphone', 'earplug', 'earring', 'easel', 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk', 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan', 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm', 'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace', 'fireplug', 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', 'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flap', 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal', 'folding_chair', 'food_processor', 'football_(American)', 'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car', 'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice', 'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage', 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic', 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'generator', 'giant_panda', 'gift_wrap', 'ginger', 'giraffe', 'cincture', 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles', 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'grape', 'grater', 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle', 'grill', 'grits', 'grizzly', 'grocery_bag', 'guitar', 'gull', 'gun', 'hairbrush', 'hairnet', 'hairpin', 'halter_top', 'ham', 'hamburger', 'hammer', 'hammock', 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel', 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw', 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil', 'headband', 'headboard', 'headlight', 'headscarf', 'headset', 'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog', 'home_plate_(baseball)', 
'honey', 'fume_hood', 'hook', 'hookah', 'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce', 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear', 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate', 'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit', 'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor', 'lizard', 'log', 'lollipop', 'speaker_(stero_equipment)', 'loveseat', 'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini', 'mascot', 'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup', 'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone', 'microscope', 'microwave_oven', 'milestone', 'milk', 'milk_can', 'milkshake', 'minivan', 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money', 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor', 'motor_scooter', 'motor_vehicle', 'motorcycle', 'mound_(baseball)', 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom', 'music_stool', 'musical_instrument', 'nailfile', 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newspaper', 'newsstand', 'nightshirt', 'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'ostrich', 'ottoman', 'oven', 'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle', 'padlock', 'paintbrush', 'painting', 'pajamas', 'palette', 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose', 'papaya', 'paper_plate', 'paper_towel', 'paperback_book', 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', 'parasol', 'parchment', 'parka', 'parking_meter', 'parrot', 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport', 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter', 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg', 'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet', 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano', 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow', 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball', 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)', 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat', 'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)', 'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)', 'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)', 'postbox_(public)', 
'postcard', 'poster', 'pot', 'flowerpot', 'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'pretzel', 'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune', 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish', 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat', 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt', 'recliner', 'record_player', 'reflector', 'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map', 'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade', 'rolling_pin', 'root_beer', 'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)', 'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin', 'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver', 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane', 'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass', 'shoulder_bag', 'shovel', 'shower_head', 'shower_cap', 'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink', 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole', 'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)', 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman', 'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball', 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon', 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)', 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish', 'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)', 'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish', 'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel', 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer', 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove', 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry', 'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer', 'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', 'sunglasses', 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)', 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure', 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup', 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth', 'telephone_pole', 'telephoto_lens', 'television_camera', 'television_set', 'tennis_ball', 'tennis_racket', 'tequila', 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread', 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven', 'toilet', 'toilet_tissue', 
'tomato', 'tongs', 'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray', 'trench_coat', 'triangle_(musical_instrument)', 'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', 'turban', 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest', 'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe', 'washbasin', 'automatic_washer', 'watch', 'water_bottle', 'water_cooler', 'water_faucet', 'water_heater', 'water_jug', 'water_gun', 'water_scooter', 'water_ski', 'water_tower', 'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake', 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream', 'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)', 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket', 'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon', 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt', 'yoke_(animal_equipment)', 'zebra', 'zucchini']
proxydet/config.py
ADDED
@@ -0,0 +1,156 @@
# Copyright (c) Facebook, Inc. and its affiliates.
'''
Modifications Copyright (c) 2024-present NAVER Corp, Apache License v2.0
original source: https://github.com/facebookresearch/Detic/blob/main/detic/config.py
'''
from detectron2.config import CfgNode as CN

def add_proxydet_config(cfg):
    _C = cfg

    _C.WITH_IMAGE_LABELS = False # Turn on co-training with classification data

    # Open-vocabulary classifier
    _C.MODEL.ROI_BOX_HEAD.USE_ZEROSHOT_CLS = False # Use fixed classifier for open-vocabulary detection
    _C.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'datasets/metadata/lvis_v1_clip_a+cname.npy'
    _C.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_DIM = 512
    _C.MODEL.ROI_BOX_HEAD.NORM_WEIGHT = True
    _C.MODEL.ROI_BOX_HEAD.NORM_TEMP = 50.0
    _C.MODEL.ROI_BOX_HEAD.IGNORE_ZERO_CATS = False
    _C.MODEL.ROI_BOX_HEAD.USE_BIAS = 0.0 # >= 0: not use

    _C.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False # CenterNet2
    _C.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE = False
    _C.MODEL.ROI_BOX_HEAD.PRIOR_PROB = 0.01
    _C.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False # Federated Loss
    _C.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = \
        'datasets/metadata/lvis_v1_train_cat_info.json'
    _C.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT = 50
    _C.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT = 0.5

    # Classification data configs
    _C.MODEL.ROI_BOX_HEAD.IMAGE_LABEL_LOSS = 'max_size' # max, softmax, sum
    _C.MODEL.ROI_BOX_HEAD.IMAGE_LOSS_WEIGHT = 0.1
    _C.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE = 1.0
    _C.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX = False # Used for image-box loss and caption loss
    _C.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS = 128 # num proposals for image-labeled data
    _C.MODEL.ROI_BOX_HEAD.WITH_SOFTMAX_PROP = False # Used for WSDDN
    _C.MODEL.ROI_BOX_HEAD.CAPTION_WEIGHT = 1.0 # Caption loss weight
    _C.MODEL.ROI_BOX_HEAD.NEG_CAP_WEIGHT = 0.125 # Caption loss hyper-parameter
    _C.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP = False # Used for WSDDN
    _C.MODEL.ROI_BOX_HEAD.SOFTMAX_WEAK_LOSS = False # Used when USE_SIGMOID_CE is False

    _C.MODEL.ROI_HEADS.MASK_WEIGHT = 1.0
    _C.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = False # For demo only

    # Class-wise Multi-Modal Mixup
    _C.MODEL.ROI_BOX_HEAD.USE_REGIONAL_EMBEDDING = False
    _C.MODEL.ROI_HEADS.BASE_CAT_MASK = "datasets/metadata/lvis_v1_base_cat_mask.npy"
    _C.MODEL.ROI_HEADS.CMM = CN()
    _C.MODEL.ROI_HEADS.CMM.MIXUP_STAGE = []
    _C.MODEL.ROI_HEADS.CMM.MIXUP_STAGE_TEST = None
    _C.MODEL.ROI_HEADS.CMM.MIXUP_BETA = 1.0
    _C.MODEL.ROI_HEADS.CMM.LOSS = "l1"
    _C.MODEL.ROI_HEADS.CMM.LOSS_WEIGHT = 1.0
    _C.MODEL.ROI_HEADS.CMM.SEPARATED_BRANCH = False
    _C.MODEL.ROI_HEADS.CMM.BASE_ALPHA = 0.5
    _C.MODEL.ROI_HEADS.CMM.NOVEL_BETA = 0.5
    _C.MODEL.ROI_HEADS.CMM.USE_INL = False
    _C.MODEL.ROI_HEADS.CMM.PROTOTYPE = "center"
    _C.MODEL.ROI_HEADS.CMM.PROTOTYPE_TEMP = 1.0
    _C.MODEL.ROI_HEADS.CMM.CLASSIFIER_TEMP = None
    _C.MODEL.ROI_HEADS.CMM.USE_SIGMOID_CE = True

    # Caption losses
    _C.MODEL.CAP_BATCH_RATIO = 4 # Ratio between detection data and caption data
    _C.MODEL.WITH_CAPTION = False
    _C.MODEL.SYNC_CAPTION_BATCH = False # synchronize across GPUs to enlarge # "classes"

    # dynamic class sampling when training with 21K classes
    _C.MODEL.DYNAMIC_CLASSIFIER = False
    _C.MODEL.NUM_SAMPLE_CATS = 50

    # Different classifiers in testing, used in cross-dataset evaluation
    _C.MODEL.RESET_CLS_TESTS = False
    _C.MODEL.TEST_CLASSIFIERS = []
    _C.MODEL.TEST_NUM_CLASSES = []

    # Backbones
    _C.MODEL.SWIN = CN()
    _C.MODEL.SWIN.SIZE = 'T' # 'T', 'S', 'B'
    _C.MODEL.SWIN.USE_CHECKPOINT = False
    _C.MODEL.SWIN.OUT_FEATURES = (1, 2, 3) # FPN stride 8 - 32

    _C.MODEL.TIMM = CN()
    _C.MODEL.TIMM.BASE_NAME = 'resnet50'
    _C.MODEL.TIMM.OUT_LEVELS = (3, 4, 5)
    _C.MODEL.TIMM.NORM = 'FrozenBN'
    _C.MODEL.TIMM.FREEZE_AT = 0
    _C.MODEL.TIMM.PRETRAINED = False
    _C.MODEL.DATASET_LOSS_WEIGHT = []

    # Multi-dataset dataloader
    _C.DATALOADER.DATASET_RATIO = [1, 1] # sample ratio
    _C.DATALOADER.USE_RFS = [False, False]
    _C.DATALOADER.MULTI_DATASET_GROUPING = False # Always true when multi-dataset is enabled
    _C.DATALOADER.DATASET_ANN = ['box', 'box'] # Annotation type of each dataset
    _C.DATALOADER.USE_DIFF_BS_SIZE = False # Use different batchsize for each dataset
    _C.DATALOADER.DATASET_BS = [8, 32] # Used when USE_DIFF_BS_SIZE is on
    _C.DATALOADER.DATASET_INPUT_SIZE = [896, 384] # Used when USE_DIFF_BS_SIZE is on
    _C.DATALOADER.DATASET_INPUT_SCALE = [(0.1, 2.0), (0.5, 1.5)] # Used when USE_DIFF_BS_SIZE is on
    _C.DATALOADER.DATASET_MIN_SIZES = [(640, 800), (320, 400)] # Used when USE_DIFF_BS_SIZE is on
    _C.DATALOADER.DATASET_MAX_SIZES = [1333, 667] # Used when USE_DIFF_BS_SIZE is on
    _C.DATALOADER.USE_TAR_DATASET = False # for ImageNet-21K, directly reading from unziped files
    _C.DATALOADER.TARFILE_PATH = 'datasets/imagenet/metadata-22k/tar_files.npy'
    _C.DATALOADER.TAR_INDEX_DIR = 'datasets/imagenet/metadata-22k/tarindex_npy'

    _C.SOLVER.USE_CUSTOM_SOLVER = False
    _C.SOLVER.OPTIMIZER = 'SGD'
    _C.SOLVER.BACKBONE_MULTIPLIER = 1.0 # Used in DETR
    _C.SOLVER.CUSTOM_MULTIPLIER = 1.0 # Used in DETR
    _C.SOLVER.CUSTOM_MULTIPLIER_NAME = [] # Used in DETR

    # Deformable DETR
    _C.MODEL.DETR = CN()
    _C.MODEL.DETR.NUM_CLASSES = 80
    _C.MODEL.DETR.FROZEN_WEIGHTS = '' # For Segmentation
    _C.MODEL.DETR.GIOU_WEIGHT = 2.0
    _C.MODEL.DETR.L1_WEIGHT = 5.0
    _C.MODEL.DETR.DEEP_SUPERVISION = True
    _C.MODEL.DETR.NO_OBJECT_WEIGHT = 0.1
    _C.MODEL.DETR.CLS_WEIGHT = 2.0
    _C.MODEL.DETR.NUM_FEATURE_LEVELS = 4
    _C.MODEL.DETR.TWO_STAGE = False
    _C.MODEL.DETR.WITH_BOX_REFINE = False
    _C.MODEL.DETR.FOCAL_ALPHA = 0.25
    _C.MODEL.DETR.NHEADS = 8
    _C.MODEL.DETR.DROPOUT = 0.1
    _C.MODEL.DETR.DIM_FEEDFORWARD = 2048
    _C.MODEL.DETR.ENC_LAYERS = 6
    _C.MODEL.DETR.DEC_LAYERS = 6
    _C.MODEL.DETR.PRE_NORM = False
    _C.MODEL.DETR.HIDDEN_DIM = 256
    _C.MODEL.DETR.NUM_OBJECT_QUERIES = 100

    _C.MODEL.DETR.USE_FED_LOSS = False
    _C.MODEL.DETR.WEAK_WEIGHT = 0.1

    _C.INPUT.CUSTOM_AUG = ''
    _C.INPUT.TRAIN_SIZE = 640
    _C.INPUT.TEST_SIZE = 640
    _C.INPUT.SCALE_RANGE = (0.1, 2.)
    # 'default' for fixed short/ long edge, 'square' for max size=INPUT.SIZE
    _C.INPUT.TEST_INPUT_TYPE = 'default'

    _C.FIND_UNUSED_PARAM = True
    _C.EVAL_PRED_AR = False
    _C.EVAL_PROPOSAL_AR = False
    _C.EVAL_CAT_SPEC_AR = False
    _C.IS_DEBUG = False
    _C.QUICK_DEBUG = False
    _C.FP16 = False
    _C.EVAL_AP_FIX = False
    _C.GEN_PSEDO_LABELS = False
    _C.SAVE_DEBUG_PATH = 'output/save_debug/'

    _C.EVAL_START = 0
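These keys do not exist in a stock detectron2 CfgNode, so add_proxydet_config has to run before merge_from_file; otherwise the YAMLs above would fail with non-existent-key errors. A minimal sketch, with the BASE_ALPHA override purely illustrative:

from detectron2.config import get_cfg
from proxydet.config import add_proxydet_config

cfg = get_cfg()
add_proxydet_config(cfg)
assert cfg.MODEL.ROI_HEADS.CMM.PROTOTYPE == "center"   # default before any YAML is merged
cfg.MODEL.ROI_HEADS.CMM.BASE_ALPHA = 0.15              # overridden per experiment / by the configs above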
proxydet/custom_solver.py
ADDED
|
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from enum import Enum
import itertools
from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union
import torch

from detectron2.config import CfgNode

from detectron2.solver.build import maybe_add_gradient_clipping

def match_name_keywords(n, name_keywords):
    out = False
    for b in name_keywords:
        if b in n:
            out = True
            break
    return out

def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an optimizer from config.
    """
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    custom_multiplier_name = cfg.SOLVER.CUSTOM_MULTIPLIER_NAME
    optimizer_type = cfg.SOLVER.OPTIMIZER
    for key, value in model.named_parameters(recurse=True):
        if not value.requires_grad:
            continue
        # Avoid duplicating parameters
        if value in memo:
            continue
        memo.add(value)
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "backbone" in key:
            lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER
        if match_name_keywords(key, custom_multiplier_name):
            lr = lr * cfg.SOLVER.CUSTOM_MULTIPLIER
            print('Costum LR', key, lr)
        param = {"params": [value], "lr": lr}
        if optimizer_type != 'ADAMW':
            param['weight_decay'] = weight_decay
        params += [param]

    def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
        # detectron2 doesn't have full model gradient clipping now
        clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
        enable = (
            cfg.SOLVER.CLIP_GRADIENTS.ENABLED
            and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
            and clip_norm_val > 0.0
        )

        class FullModelGradientClippingOptimizer(optim):
            def step(self, closure=None):
                all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                super().step(closure=closure)

        return FullModelGradientClippingOptimizer if enable else optim


    if optimizer_type == 'SGD':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
            params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV
        )
    elif optimizer_type == 'ADAMW':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
            params, cfg.SOLVER.BASE_LR,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY
        )
    else:
        raise NotImplementedError(f"no optimizer type {optimizer_type}")
    if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
        optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer

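A minimal usage sketch for build_custom_optimizer (not part of the commit). It assumes the four non-standard SOLVER keys referenced above are normally declared in the project's config.py; here they are set by hand on a toy model so the call is self-contained.

# usage sketch, assuming detectron2 is installed and proxydet is on PYTHONPATH
import torch.nn as nn
from detectron2.config import get_cfg
from proxydet.custom_solver import build_custom_optimizer

cfg = get_cfg()
cfg.SOLVER.OPTIMIZER = 'SGD'              # or 'ADAMW'
cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1      # params whose name contains "backbone" get 0.1x LR
cfg.SOLVER.CUSTOM_MULTIPLIER = 1.0
cfg.SOLVER.CUSTOM_MULTIPLIER_NAME = []    # substrings of parameter names to rescale

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
optimizer = build_custom_optimizer(cfg, model)
print(len(optimizer.param_groups))        # one param group per parameter tensor
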
proxydet/data/custom_build_augmentation.py
ADDED
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import numpy as np
import pycocotools.mask as mask_util
import torch
from fvcore.common.file_io import PathManager
from PIL import Image


from detectron2.data import transforms as T
from .transforms.custom_augmentation_impl import EfficientDetResizeCrop

def build_custom_augmentation(cfg, is_train, scale=None, size=None, \
    min_size=None, max_size=None):
    """
    Create a list of default :class:`Augmentation` from config.
    Now it includes resizing and flipping.

    Returns:
        list[Augmentation]
    """
    if cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge':
        if is_train:
            min_size = cfg.INPUT.MIN_SIZE_TRAIN if min_size is None else min_size
            max_size = cfg.INPUT.MAX_SIZE_TRAIN if max_size is None else max_size
            sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
        else:
            min_size = cfg.INPUT.MIN_SIZE_TEST
            max_size = cfg.INPUT.MAX_SIZE_TEST
            sample_style = "choice"
        augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
    elif cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop':
        if is_train:
            scale = cfg.INPUT.SCALE_RANGE if scale is None else scale
            size = cfg.INPUT.TRAIN_SIZE if size is None else size
        else:
            scale = (1, 1)
            size = cfg.INPUT.TEST_SIZE
        augmentation = [EfficientDetResizeCrop(size, scale)]
    else:
        assert 0, cfg.INPUT.CUSTOM_AUG

    if is_train:
        augmentation.append(T.RandomFlip())
    return augmentation


build_custom_transform_gen = build_custom_augmentation
"""
Alias for backward-compatibility.
"""

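A short sketch (not part of the commit) of how this builder is driven by the INPUT keys added in config.py; the size and scale values below mirror those defaults, while CUSTOM_AUG must be set explicitly.

# usage sketch, assuming detectron2 and the proxydet package are importable
from detectron2.config import get_cfg
from proxydet.data.custom_build_augmentation import build_custom_augmentation

cfg = get_cfg()
cfg.INPUT.CUSTOM_AUG = 'EfficientDetResizeCrop'
cfg.INPUT.TRAIN_SIZE = 640
cfg.INPUT.TEST_SIZE = 640
cfg.INPUT.SCALE_RANGE = (0.1, 2.)

augs = build_custom_augmentation(cfg, is_train=True)
# -> [EfficientDetResizeCrop(640, (0.1, 2.)), RandomFlip()]
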
proxydet/data/custom_dataset_dataloader.py
ADDED
@@ -0,0 +1,331 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/multi_dataset_dataloader.py (Apache-2.0 License)
import copy
import logging
import numpy as np
import operator
import torch
import torch.utils.data
import json
from detectron2.utils.comm import get_world_size
from detectron2.utils.logger import _log_api_usage, log_first_n

from detectron2.config import configurable
from detectron2.data import samplers
from torch.utils.data.sampler import BatchSampler, Sampler
from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader
from detectron2.data.samplers import TrainingSampler, RepeatFactorTrainingSampler
from detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram
from detectron2.data.build import filter_images_with_only_crowd_annotations
from detectron2.data.build import filter_images_with_few_keypoints
from detectron2.data.build import check_metadata_consistency
from detectron2.data.catalog import MetadataCatalog, DatasetCatalog
from detectron2.utils import comm
import itertools
import math
from collections import defaultdict
from typing import Optional


def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    if 'MultiDataset' in sampler_name:
        dataset_dicts = get_detection_dataset_dicts_with_source(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )
    else:
        dataset_dicts = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is not None:
        pass
    elif sampler_name == "TrainingSampler":
        sampler = TrainingSampler(len(dataset))
    elif sampler_name == "MultiDatasetSampler":
        sampler = MultiDatasetSampler(
            dataset_dicts,
            dataset_ratio = cfg.DATALOADER.DATASET_RATIO,
            use_rfs = cfg.DATALOADER.USE_RFS,
            dataset_ann = cfg.DATALOADER.DATASET_ANN,
            repeat_threshold = cfg.DATALOADER.REPEAT_THRESHOLD,
        )
    elif sampler_name == "RepeatFactorTrainingSampler":
        repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
            dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
        )
        sampler = RepeatFactorTrainingSampler(repeat_factors)
    else:
        raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset_dicts,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
        'multi_dataset_grouping': cfg.DATALOADER.MULTI_DATASET_GROUPING,
        'use_diff_bs_size': cfg.DATALOADER.USE_DIFF_BS_SIZE,
        'dataset_bs': cfg.DATALOADER.DATASET_BS,
        'num_datasets': len(cfg.DATASETS.TRAIN)
    }


@configurable(from_config=_custom_train_loader_from_config)
def build_custom_train_loader(
        dataset, *, mapper, sampler,
        total_batch_size=16,
        aspect_ratio_grouping=True,
        num_workers=0,
        num_datasets=1,
        multi_dataset_grouping=False,
        use_diff_bs_size=False,
        dataset_bs=[]
    ):
    """
    Modified from detectron2.data.build.build_custom_train_loader, but supports
    different samplers
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    if multi_dataset_grouping:
        return build_multi_dataset_batch_data_loader(
            use_diff_bs_size,
            dataset_bs,
            dataset,
            sampler,
            total_batch_size,
            num_datasets=num_datasets,
            num_workers=num_workers,
        )
    else:
        return build_batch_data_loader(
            dataset,
            sampler,
            total_batch_size,
            aspect_ratio_grouping=aspect_ratio_grouping,
            num_workers=num_workers,
        )


def build_multi_dataset_batch_data_loader(
        use_diff_bs_size, dataset_bs,
        dataset, sampler, total_batch_size, num_datasets, num_workers=0
    ):
    """
    """
    world_size = get_world_size()
    assert (
        total_batch_size > 0 and total_batch_size % world_size == 0
    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
        total_batch_size, world_size
    )

    batch_size = total_batch_size // world_size
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        num_workers=num_workers,
        batch_sampler=None,
        collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
        worker_init_fn=worker_init_reset_seed,
    )  # yield individual mapped dict
    if use_diff_bs_size:
        return DIFFMDAspectRatioGroupedDataset(
            data_loader, dataset_bs, num_datasets)
    else:
        return MDAspectRatioGroupedDataset(
            data_loader, batch_size, num_datasets)


def get_detection_dataset_dicts_with_source(
        dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
    ):
    assert len(dataset_names)
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
    for dataset_name, dicts in zip(dataset_names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    for source_id, (dataset_name, dicts) in \
        enumerate(zip(dataset_names, dataset_dicts)):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
        for d in dicts:
            d['dataset_source'] = source_id

        if "annotations" in dicts[0]:
            try:
                class_names = MetadataCatalog.get(dataset_name).thing_classes
                check_metadata_consistency("thing_classes", dataset_name)
                print_instances_class_histogram(dicts, class_names)
            except AttributeError:  # class names are not available for this dataset
                pass

    assert proposal_files is None

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    return dataset_dicts


class MultiDatasetSampler(Sampler):
    def __init__(
        self,
        dataset_dicts,
        dataset_ratio,
        use_rfs,
        dataset_ann,
        repeat_threshold=0.001,
        seed: Optional[int] = None,
        ):
        """
        """
        sizes = [0 for _ in range(len(dataset_ratio))]
        for d in dataset_dicts:
            sizes[d['dataset_source']] += 1
        print('dataset sizes', sizes)
        self.sizes = sizes
        assert len(dataset_ratio) == len(sizes), \
            'length of dataset ratio {} should be equal to number if dataset {}'.format(
                len(dataset_ratio), len(sizes)
            )
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        self.dataset_ids = torch.tensor(
            [d['dataset_source'] for d in dataset_dicts], dtype=torch.long)

        dataset_weight = [torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) \
            for i, (r, s) in enumerate(zip(dataset_ratio, sizes))]
        dataset_weight = torch.cat(dataset_weight)

        rfs_factors = []
        st = 0
        for i, s in enumerate(sizes):
            if use_rfs[i]:
                if dataset_ann[i] == 'box':
                    rfs_func = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency
                else:
                    rfs_func = repeat_factors_from_tag_frequency
                rfs_factor = rfs_func(
                    dataset_dicts[st: st + s],
                    repeat_thresh=repeat_threshold)
                rfs_factor = rfs_factor * (s / rfs_factor.sum())
            else:
                rfs_factor = torch.ones(s)
            rfs_factors.append(rfs_factor)
            st = st + s
        rfs_factors = torch.cat(rfs_factors)

        self.weights = dataset_weight * rfs_factors
        self.sample_epoch_size = len(self.weights)

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(
            self._infinite_indices(), start, None, self._world_size)


    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            ids = torch.multinomial(
                self.weights, self.sample_epoch_size, generator=g,
                replacement=True)
            nums = [(self.dataset_ids[ids] == i).sum().int().item() \
                for i in range(len(self.sizes))]
            yield from ids


class MDAspectRatioGroupedDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, batch_size, num_datasets):
        """
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self._buckets = [[] for _ in range(2 * num_datasets)]

    def __iter__(self):
        for d in self.dataset:
            w, h = d["width"], d["height"]
            aspect_ratio_bucket_id = 0 if w > h else 1
            bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id
            bucket = self._buckets[bucket_id]
            bucket.append(d)
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]


class DIFFMDAspectRatioGroupedDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, batch_sizes, num_datasets):
        """
        """
        self.dataset = dataset
        self.batch_sizes = batch_sizes
        self._buckets = [[] for _ in range(2 * num_datasets)]

    def __iter__(self):
        for d in self.dataset:
            w, h = d["width"], d["height"]
            aspect_ratio_bucket_id = 0 if w > h else 1
            bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id
            bucket = self._buckets[bucket_id]
            bucket.append(d)
            if len(bucket) == self.batch_sizes[d['dataset_source']]:
                yield bucket[:]
                del bucket[:]


def repeat_factors_from_tag_frequency(dataset_dicts, repeat_thresh):
    """
    """
    category_freq = defaultdict(int)
    for dataset_dict in dataset_dicts:
        cat_ids = dataset_dict['pos_category_ids']
        for cat_id in cat_ids:
            category_freq[cat_id] += 1
    num_images = len(dataset_dicts)
    for k, v in category_freq.items():
        category_freq[k] = v / num_images

    category_rep = {
        cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq))
        for cat_id, cat_freq in category_freq.items()
    }

    rep_factors = []
    for dataset_dict in dataset_dicts:
        cat_ids = dataset_dict['pos_category_ids']
        rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0)
        rep_factors.append(rep_factor)

    return torch.tensor(rep_factors, dtype=torch.float32)

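A toy illustration (not part of the commit) of repeat_factors_from_tag_frequency: image-level tags play the role of box categories, and images carrying rare tags are over-sampled. The dicts below are made up for the example.

from proxydet.data.custom_dataset_dataloader import repeat_factors_from_tag_frequency

toy_dicts = [
    {'pos_category_ids': [0]},      # common tag only
    {'pos_category_ids': [0, 5]},   # also carries a rare tag
    {'pos_category_ids': [7]},      # rare tag only
]
factors = repeat_factors_from_tag_frequency(toy_dicts, repeat_thresh=0.5)
# tag 0 appears in 2/3 of images -> factor 1.0; tags 5 and 7 appear in 1/3 -> sqrt(0.5 / (1/3)) ~ 1.22
print(factors)  # tensor([1.0000, 1.2247, 1.2247])
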
proxydet/data/custom_dataset_mapper.py
ADDED
@@ -0,0 +1,280 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import numpy as np
from typing import List, Optional, Union
import torch
import pycocotools.mask as mask_util

from detectron2.config import configurable

from detectron2.data import detection_utils as utils
from detectron2.data.detection_utils import transform_keypoint_annotations
from detectron2.data import transforms as T
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.structures import Boxes, BoxMode, Instances
from detectron2.structures import Keypoints, PolygonMasks, BitMasks
from fvcore.transforms.transform import TransformList
from .custom_build_augmentation import build_custom_augmentation
from .tar_dataset import DiskTarDataset

__all__ = ["CustomDatasetMapper"]

class CustomDatasetMapper(DatasetMapper):
    @configurable
    def __init__(self, is_train: bool,
        with_ann_type=False,
        dataset_ann=[],
        use_diff_bs_size=False,
        dataset_augs=[],
        is_debug=False,
        use_tar_dataset=False,
        tarfile_path='',
        tar_index_dir='',
        **kwargs):
        """
        add image labels
        """
        self.with_ann_type = with_ann_type
        self.dataset_ann = dataset_ann
        self.use_diff_bs_size = use_diff_bs_size
        if self.use_diff_bs_size and is_train:
            self.dataset_augs = [T.AugmentationList(x) for x in dataset_augs]
        self.is_debug = is_debug
        self.use_tar_dataset = use_tar_dataset
        if self.use_tar_dataset:
            print('Using tar dataset')
            self.tar_dataset = DiskTarDataset(tarfile_path, tar_index_dir)
        super().__init__(is_train, **kwargs)


    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        ret = super().from_config(cfg, is_train)
        ret.update({
            'with_ann_type': cfg.WITH_IMAGE_LABELS,
            'dataset_ann': cfg.DATALOADER.DATASET_ANN,
            'use_diff_bs_size': cfg.DATALOADER.USE_DIFF_BS_SIZE,
            'is_debug': cfg.IS_DEBUG,
            'use_tar_dataset': cfg.DATALOADER.USE_TAR_DATASET,
            'tarfile_path': cfg.DATALOADER.TARFILE_PATH,
            'tar_index_dir': cfg.DATALOADER.TAR_INDEX_DIR,
        })
        if ret['use_diff_bs_size'] and is_train:
            if cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop':
                dataset_scales = cfg.DATALOADER.DATASET_INPUT_SCALE
                dataset_sizes = cfg.DATALOADER.DATASET_INPUT_SIZE
                ret['dataset_augs'] = [
                    build_custom_augmentation(cfg, True, scale, size) \
                    for scale, size in zip(dataset_scales, dataset_sizes)]
            else:
                assert cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge'
                min_sizes = cfg.DATALOADER.DATASET_MIN_SIZES
                max_sizes = cfg.DATALOADER.DATASET_MAX_SIZES
                ret['dataset_augs'] = [
                    build_custom_augmentation(
                        cfg, True, min_size=mi, max_size=ma) \
                    for mi, ma in zip(min_sizes, max_sizes)]
        else:
            ret['dataset_augs'] = []

        return ret

    def __call__(self, dataset_dict):
        """
        include image labels
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        if 'file_name' in dataset_dict:
            ori_image = utils.read_image(
                dataset_dict["file_name"], format=self.image_format)
        else:
            ori_image, _, _ = self.tar_dataset[dataset_dict["tar_index"]]
            ori_image = utils._apply_exif_orientation(ori_image)
            ori_image = utils.convert_PIL_to_numpy(ori_image, self.image_format)
        utils.check_image_size(dataset_dict, ori_image)

        # USER: Remove if you don't do semantic/panoptic segmentation.
        if "sem_seg_file_name" in dataset_dict:
            sem_seg_gt = utils.read_image(
                dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
        else:
            sem_seg_gt = None

        if self.is_debug:
            dataset_dict['dataset_source'] = 0

        not_full_labeled = 'dataset_source' in dataset_dict and \
            self.with_ann_type and \
            self.dataset_ann[dataset_dict['dataset_source']] != 'box'

        aug_input = T.AugInput(copy.deepcopy(ori_image), sem_seg=sem_seg_gt)
        if self.use_diff_bs_size and self.is_train:
            transforms = \
                self.dataset_augs[dataset_dict['dataset_source']](aug_input)
        else:
            transforms = self.augmentations(aug_input)
        image, sem_seg_gt = aug_input.image, aug_input.sem_seg

        image_shape = image.shape[:2]  # h, w
        dataset_dict["image"] = torch.as_tensor(
            np.ascontiguousarray(image.transpose(2, 0, 1)))

        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

        # USER: Remove if you don't use pre-computed proposals.
        # Most users would not need this feature.
        if self.proposal_topk is not None:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms,
                proposal_topk=self.proposal_topk
            )

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                if not self.use_instance_mask:
                    anno.pop("segmentation", None)
                if not self.use_keypoint:
                    anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            all_annos = [
                (utils.transform_instance_annotations(
                    obj, transforms, image_shape,
                    keypoint_hflip_indices=self.keypoint_hflip_indices,
                ), obj.get("iscrowd", 0))
                for obj in dataset_dict.pop("annotations")
            ]
            annos = [ann[0] for ann in all_annos if ann[1] == 0]
            instances = utils.annotations_to_instances(
                annos, image_shape, mask_format=self.instance_mask_format
            )

            del all_annos
            if self.recompute_boxes:
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)
        if self.with_ann_type:
            dataset_dict["pos_category_ids"] = dataset_dict.get(
                'pos_category_ids', [])
            dataset_dict["ann_type"] = \
                self.dataset_ann[dataset_dict['dataset_source']]
        if self.is_debug and (('pos_category_ids' not in dataset_dict) or \
            (dataset_dict['pos_category_ids'] == [])):
            dataset_dict['pos_category_ids'] = [x for x in sorted(set(
                dataset_dict['instances'].gt_classes.tolist()
            ))]
        return dataset_dict

# DETR augmentation
def build_transform_gen(cfg, is_train):
    """
    """
    if is_train:
        min_size = cfg.INPUT.MIN_SIZE_TRAIN
        max_size = cfg.INPUT.MAX_SIZE_TRAIN
        sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
    else:
        min_size = cfg.INPUT.MIN_SIZE_TEST
        max_size = cfg.INPUT.MAX_SIZE_TEST
        sample_style = "choice"
    if sample_style == "range":
        assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size))

    logger = logging.getLogger(__name__)
    tfm_gens = []
    if is_train:
        tfm_gens.append(T.RandomFlip())
    tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
    if is_train:
        logger.info("TransformGens used in training: " + str(tfm_gens))
    return tfm_gens


class DetrDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and map it into a format used by DETR.
    The callable currently does the following:
    1. Read the image from "file_name"
    2. Applies geometric transforms to the image and annotation
    3. Find and applies suitable cropping to the image and annotation
    4. Prepare image and annotation to Tensors
    """

    def __init__(self, cfg, is_train=True):
        if cfg.INPUT.CROP.ENABLED and is_train:
            self.crop_gen = [
                T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
                T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
            ]
        else:
            self.crop_gen = None

        self.mask_on = cfg.MODEL.MASK_ON
        self.tfm_gens = build_transform_gen(cfg, is_train)
        logging.getLogger(__name__).info(
            "Full TransformGens used in training: {}, crop: {}".format(str(self.tfm_gens), str(self.crop_gen))
        )

        self.img_format = cfg.INPUT.FORMAT
        self.is_train = is_train

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        if self.crop_gen is None:
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        else:
            if np.random.rand() > 0.5:
                image, transforms = T.apply_transform_gens(self.tfm_gens, image)
            else:
                image, transforms = T.apply_transform_gens(
                    self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image
                )

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                if not self.mask_on:
                    anno.pop("segmentation", None)
                anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(obj, transforms, image_shape)
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = utils.annotations_to_instances(annos, image_shape)
            dataset_dict["instances"] = utils.filter_empty_instances(instances)
        return dataset_dict

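A brief construction sketch (not part of the commit): the DETR-style mapper only reads standard detectron2 INPUT/MODEL keys, so it can be built from a default config as-is. The file path in the comment is illustrative only.

from detectron2.config import get_cfg
from proxydet.data.custom_dataset_mapper import DetrDatasetMapper

cfg = get_cfg()
mapper = DetrDatasetMapper(cfg, is_train=False)
# mapper({'file_name': 'some.jpg', 'height': h, 'width': w}) would return a dict
# whose "image" entry is a CHW uint8 tensor ready for the model.
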
proxydet/data/datasets/cc.py
ADDED
@@ -0,0 +1,23 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import os

from detectron2.data.datasets.builtin_meta import _get_builtin_metadata
from detectron2.data.datasets.lvis import get_lvis_instances_meta
from .lvis_v1 import custom_register_lvis_instances

_CUSTOM_SPLITS = {
    "cc3m_v1_val": ("cc3m/validation/", "cc3m/val_image_info.json"),
    "cc3m_v1_train": ("cc3m/training/", "cc3m/train_image_info.json"),
    "cc3m_v1_train_tags": ("cc3m/training/", "cc3m/train_image_info_tags.json"),

}

for key, (image_root, json_file) in _CUSTOM_SPLITS.items():
    custom_register_lvis_instances(
        key,
        get_lvis_instances_meta('lvis_v1'),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )

proxydet/data/datasets/coco_zeroshot.py
ADDED
@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import os

from detectron2.data.datasets.register_coco import register_coco_instances
from detectron2.data.datasets.builtin_meta import _get_builtin_metadata
from .lvis_v1 import custom_register_lvis_instances

categories_seen = [
    {'id': 1, 'name': 'person'},
    {'id': 2, 'name': 'bicycle'},
    {'id': 3, 'name': 'car'},
    {'id': 4, 'name': 'motorcycle'},
    {'id': 7, 'name': 'train'},
    {'id': 8, 'name': 'truck'},
    {'id': 9, 'name': 'boat'},
    {'id': 15, 'name': 'bench'},
    {'id': 16, 'name': 'bird'},
    {'id': 19, 'name': 'horse'},
    {'id': 20, 'name': 'sheep'},
    {'id': 23, 'name': 'bear'},
    {'id': 24, 'name': 'zebra'},
    {'id': 25, 'name': 'giraffe'},
    {'id': 27, 'name': 'backpack'},
    {'id': 31, 'name': 'handbag'},
    {'id': 33, 'name': 'suitcase'},
    {'id': 34, 'name': 'frisbee'},
    {'id': 35, 'name': 'skis'},
    {'id': 38, 'name': 'kite'},
    {'id': 42, 'name': 'surfboard'},
    {'id': 44, 'name': 'bottle'},
    {'id': 48, 'name': 'fork'},
    {'id': 50, 'name': 'spoon'},
    {'id': 51, 'name': 'bowl'},
    {'id': 52, 'name': 'banana'},
    {'id': 53, 'name': 'apple'},
    {'id': 54, 'name': 'sandwich'},
    {'id': 55, 'name': 'orange'},
    {'id': 56, 'name': 'broccoli'},
    {'id': 57, 'name': 'carrot'},
    {'id': 59, 'name': 'pizza'},
    {'id': 60, 'name': 'donut'},
    {'id': 62, 'name': 'chair'},
    {'id': 65, 'name': 'bed'},
    {'id': 70, 'name': 'toilet'},
    {'id': 72, 'name': 'tv'},
    {'id': 73, 'name': 'laptop'},
    {'id': 74, 'name': 'mouse'},
    {'id': 75, 'name': 'remote'},
    {'id': 78, 'name': 'microwave'},
    {'id': 79, 'name': 'oven'},
    {'id': 80, 'name': 'toaster'},
    {'id': 82, 'name': 'refrigerator'},
    {'id': 84, 'name': 'book'},
    {'id': 85, 'name': 'clock'},
    {'id': 86, 'name': 'vase'},
    {'id': 90, 'name': 'toothbrush'},
]

categories_unseen = [
    {'id': 5, 'name': 'airplane'},
    {'id': 6, 'name': 'bus'},
    {'id': 17, 'name': 'cat'},
    {'id': 18, 'name': 'dog'},
    {'id': 21, 'name': 'cow'},
    {'id': 22, 'name': 'elephant'},
    {'id': 28, 'name': 'umbrella'},
    {'id': 32, 'name': 'tie'},
    {'id': 36, 'name': 'snowboard'},
    {'id': 41, 'name': 'skateboard'},
    {'id': 47, 'name': 'cup'},
    {'id': 49, 'name': 'knife'},
    {'id': 61, 'name': 'cake'},
    {'id': 63, 'name': 'couch'},
    {'id': 76, 'name': 'keyboard'},
    {'id': 81, 'name': 'sink'},
    {'id': 87, 'name': 'scissors'},
]

def _get_metadata(cat):
    if cat == 'all':
        return _get_builtin_metadata('coco')
    elif cat == 'seen':
        id_to_name = {x['id']: x['name'] for x in categories_seen}
    else:
        assert cat == 'unseen'
        id_to_name = {x['id']: x['name'] for x in categories_unseen}

    thing_dataset_id_to_contiguous_id = {
        x: i for i, x in enumerate(sorted(id_to_name))}
    thing_classes = [id_to_name[k] for k in sorted(id_to_name)]
    return {
        "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
        "thing_classes": thing_classes}

_PREDEFINED_SPLITS_COCO = {
    "coco_zeroshot_train": ("coco/train2017", "coco/zero-shot/instances_train2017_seen_2.json", 'seen'),
    "coco_zeroshot_val": ("coco/val2017", "coco/zero-shot/instances_val2017_unseen_2.json", 'unseen'),
    "coco_not_zeroshot_val": ("coco/val2017", "coco/zero-shot/instances_val2017_seen_2.json", 'seen'),
    "coco_generalized_zeroshot_val": ("coco/val2017", "coco/zero-shot/instances_val2017_all_2_oriorder.json", 'all'),
    "coco_zeroshot_train_oriorder": ("coco/train2017", "coco/zero-shot/instances_train2017_seen_2_oriorder.json", 'all'),
}

for key, (image_root, json_file, cat) in _PREDEFINED_SPLITS_COCO.items():
    register_coco_instances(
        key,
        _get_metadata(cat),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )

_CUSTOM_SPLITS_COCO = {
    "cc3m_coco_train_tags": ("cc3m/training/", "cc3m/coco_train_image_info_tags.json"),
    "coco_caption_train_tags": ("coco/train2017/", "coco/annotations/captions_train2017_tags_allcaps.json"),}

for key, (image_root, json_file) in _CUSTOM_SPLITS_COCO.items():
    custom_register_lvis_instances(
        key,
        _get_builtin_metadata('coco'),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )

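A small sketch (not part of the commit): once this module has been imported, directly or via the proxydet package, the open-vocabulary COCO splits are registered and their metadata can be inspected without the json files being present on disk.

from detectron2.data import MetadataCatalog
import proxydet.data.datasets.coco_zeroshot  # noqa: F401  (runs the register_* loops above)

meta = MetadataCatalog.get("coco_zeroshot_val")
print(len(meta.thing_classes))  # 17 unseen ("novel") class names
# DatasetCatalog.get("coco_zeroshot_val") would additionally require datasets/coco on disk.
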
proxydet/data/datasets/imagenet.py
ADDED
@@ -0,0 +1,41 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import os

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.lvis import get_lvis_instances_meta
from .lvis_v1 import custom_load_lvis_json, get_lvis_22k_meta
def custom_register_imagenet_instances(name, metadata, json_file, image_root):
    """
    """
    DatasetCatalog.register(name, lambda: custom_load_lvis_json(
        json_file, image_root, name))
    MetadataCatalog.get(name).set(
        json_file=json_file, image_root=image_root,
        evaluator_type="imagenet", **metadata
    )

_CUSTOM_SPLITS_IMAGENET = {
    "imagenet_lvis_v1": ("imagenet/ImageNet-LVIS/", "imagenet/annotations/imagenet_lvis_image_info.json"),
}

for key, (image_root, json_file) in _CUSTOM_SPLITS_IMAGENET.items():
    custom_register_imagenet_instances(
        key,
        get_lvis_instances_meta('lvis_v1'),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )


_CUSTOM_SPLITS_IMAGENET_22K = {
    "imagenet_lvis-22k": ("imagenet/ImageNet-LVIS/", "imagenet/annotations/imagenet-22k_image_info_lvis-22k.json"),
}

for key, (image_root, json_file) in _CUSTOM_SPLITS_IMAGENET_22K.items():
    custom_register_imagenet_instances(
        key,
        get_lvis_22k_meta(),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )

proxydet/data/datasets/lvis_22k_categories.py
ADDED
The diff for this file is too large to render.
proxydet/data/datasets/lvis_v1.py
ADDED
@@ -0,0 +1,155 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import os

from fvcore.common.timer import Timer
from detectron2.structures import BoxMode
from fvcore.common.file_io import PathManager
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.lvis import get_lvis_instances_meta

logger = logging.getLogger(__name__)

__all__ = ["custom_load_lvis_json", "custom_register_lvis_instances"]


def custom_register_lvis_instances(name, metadata, json_file, image_root):
    """
    """
    DatasetCatalog.register(name, lambda: custom_load_lvis_json(
        json_file, image_root, name))
    MetadataCatalog.get(name).set(
        json_file=json_file, image_root=image_root,
        evaluator_type="lvis", **metadata
    )


def custom_load_lvis_json(json_file, image_root, dataset_name=None):
    '''
    Modifications:
        use `file_name`
        convert neg_category_ids
        add pos_category_ids
    '''
    from lvis import LVIS

    json_file = PathManager.get_local_path(json_file)

    timer = Timer()
    lvis_api = LVIS(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(
            json_file, timer.seconds()))

    catid2contid = {x['id']: i for i, x in enumerate(
        sorted(lvis_api.dataset['categories'], key=lambda x: x['id']))}
    if len(lvis_api.dataset['categories']) == 1203:
        for x in lvis_api.dataset['categories']:
            assert catid2contid[x['id']] == x['id'] - 1
    img_ids = sorted(lvis_api.imgs.keys())
    imgs = lvis_api.load_imgs(img_ids)
    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
    assert len(set(ann_ids)) == len(ann_ids), \
        "Annotation ids in '{}' are not unique".format(json_file)

    imgs_anns = list(zip(imgs, anns))
    logger.info("Loaded {} images in the LVIS v1 format from {}".format(
        len(imgs_anns), json_file))

    dataset_dicts = []

    for (img_dict, anno_dict_list) in imgs_anns:
        record = {}
        if "file_name" in img_dict:
            file_name = img_dict["file_name"]
            if img_dict["file_name"].startswith("COCO"):
                file_name = file_name[-16:]
            record["file_name"] = os.path.join(image_root, file_name)
        elif 'coco_url' in img_dict:
            # e.g., http://images.cocodataset.org/train2017/000000391895.jpg
            file_name = img_dict["coco_url"][30:]
            record["file_name"] = os.path.join(image_root, file_name)
        elif 'tar_index' in img_dict:
            record['tar_index'] = img_dict['tar_index']

        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        record["not_exhaustive_category_ids"] = img_dict.get(
            "not_exhaustive_category_ids", [])
        record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
        # NOTE: modified by Xingyi: convert to 0-based
        record["neg_category_ids"] = [
            catid2contid[x] for x in record["neg_category_ids"]]
        if 'pos_category_ids' in img_dict:
            record['pos_category_ids'] = [
                catid2contid[x] for x in img_dict.get("pos_category_ids", [])]
        if 'captions' in img_dict:
            record['captions'] = img_dict['captions']
        if 'caption_features' in img_dict:
            record['caption_features'] = img_dict['caption_features']
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            assert anno["image_id"] == image_id
            if anno.get('iscrowd', 0) > 0:
                continue
            obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
            obj["category_id"] = catid2contid[anno['category_id']]
            if 'segmentation' in anno:
                segm = anno["segmentation"]
                valid_segm = [poly for poly in segm \
                    if len(poly) % 2 == 0 and len(poly) >= 6]
                # assert len(segm) == len(
                #     valid_segm
                # ), "Annotation contains an invalid polygon with < 3 points"
                if not len(segm) == len(valid_segm):
                    print('Annotation contains an invalid polygon with < 3 points')
                assert len(segm) > 0
                obj["segmentation"] = segm
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)

    return dataset_dicts

_CUSTOM_SPLITS_LVIS = {
    "lvis_v1_train+coco": ("coco/", "lvis/lvis_v1_train+coco_mask.json"),
    "lvis_v1_train_norare": ("coco/", "lvis/lvis_v1_train_norare.json"),
}


for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items():
    custom_register_lvis_instances(
        key,
        get_lvis_instances_meta(key),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )


def get_lvis_22k_meta():
    from .lvis_22k_categories import CATEGORIES
    cat_ids = [k["id"] for k in CATEGORIES]
    assert min(cat_ids) == 1 and max(cat_ids) == len(
        cat_ids
    ), "Category ids are not in [1, #categories], as expected"
    # Ensure that the category list is sorted by id
    lvis_categories = sorted(CATEGORIES, key=lambda x: x["id"])
    thing_classes = [k["name"] for k in lvis_categories]
    meta = {"thing_classes": thing_classes}
    return meta

_CUSTOM_SPLITS_LVIS_22K = {
    "lvis_v1_train_22k": ("coco/", "lvis/lvis_v1_train_lvis-22k.json"),
}

for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS_22K.items():
    custom_register_lvis_instances(
        key,
        get_lvis_22k_meta(),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )

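A usage sketch (not part of the commit) for registering one more LVIS-format split with the helper above. The split name and file paths are placeholders, not files shipped by this commit.

from detectron2.data.datasets.lvis import get_lvis_instances_meta
from proxydet.data.datasets.lvis_v1 import custom_register_lvis_instances

custom_register_lvis_instances(
    "my_lvis_subset_train",              # hypothetical split name
    get_lvis_instances_meta("lvis_v1"),
    "datasets/lvis/my_subset.json",      # hypothetical annotation file
    "datasets/coco/",
)
# DatasetCatalog.get("my_lvis_subset_train") now calls custom_load_lvis_json lazily.
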
proxydet/data/datasets/objects365.py
ADDED
@@ -0,0 +1,770 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from detectron2.data.datasets.register_coco import register_coco_instances
import os

# categories_v2 = [
# {'id': 1, 'name': 'Person'},
# {'id': 2, 'name': 'Sneakers'},
# {'id': 3, 'name': 'Chair'},
# {'id': 4, 'name': 'Other Shoes'},
# {'id': 5, 'name': 'Hat'},
# {'id': 6, 'name': 'Car'},
# {'id': 7, 'name': 'Lamp'},
# {'id': 8, 'name': 'Glasses'},
# {'id': 9, 'name': 'Bottle'},
# {'id': 10, 'name': 'Desk'},
# {'id': 11, 'name': 'Cup'},
# {'id': 12, 'name': 'Street Lights'},
# {'id': 13, 'name': 'Cabinet/shelf'},
# {'id': 14, 'name': 'Handbag/Satchel'},
# {'id': 15, 'name': 'Bracelet'},
# {'id': 16, 'name': 'Plate'},
# {'id': 17, 'name': 'Picture/Frame'},
# {'id': 18, 'name': 'Helmet'},
# {'id': 19, 'name': 'Book'},
# {'id': 20, 'name': 'Gloves'},
# {'id': 21, 'name': 'Storage box'},
# {'id': 22, 'name': 'Boat'},
# {'id': 23, 'name': 'Leather Shoes'},
# {'id': 24, 'name': 'Flower'},
# {'id': 25, 'name': 'Bench'},
# {'id': 26, 'name': 'Potted Plant'},
# {'id': 27, 'name': 'Bowl/Basin'},
# {'id': 28, 'name': 'Flag'},
# {'id': 29, 'name': 'Pillow'},
# {'id': 30, 'name': 'Boots'},
# {'id': 31, 'name': 'Vase'},
# {'id': 32, 'name': 'Microphone'},
# {'id': 33, 'name': 'Necklace'},
# {'id': 34, 'name': 'Ring'},
# {'id': 35, 'name': 'SUV'},
# {'id': 36, 'name': 'Wine Glass'},
# {'id': 37, 'name': 'Belt'},
# {'id': 38, 'name': 'Moniter/TV'},
# {'id': 39, 'name': 'Backpack'},
# {'id': 40, 'name': 'Umbrella'},
# {'id': 41, 'name': 'Traffic Light'},
# {'id': 42, 'name': 'Speaker'},
# {'id': 43, 'name': 'Watch'},
# {'id': 44, 'name': 'Tie'},
# {'id': 45, 'name': 'Trash bin Can'},
# {'id': 46, 'name': 'Slippers'},
# {'id': 47, 'name': 'Bicycle'},
# {'id': 48, 'name': 'Stool'},
# {'id': 49, 'name': 'Barrel/bucket'},
# {'id': 50, 'name': 'Van'},
# {'id': 51, 'name': 'Couch'},
# {'id': 52, 'name': 'Sandals'},
# {'id': 53, 'name': 'Bakset'},
# {'id': 54, 'name': 'Drum'},
# {'id': 55, 'name': 'Pen/Pencil'},
# {'id': 56, 'name': 'Bus'},
# {'id': 57, 'name': 'Wild Bird'},
# {'id': 58, 'name': 'High Heels'},
# {'id': 59, 'name': 'Motorcycle'},
# {'id': 60, 'name': 'Guitar'},
# {'id': 61, 'name': 'Carpet'},
# {'id': 62, 'name': 'Cell Phone'},
# {'id': 63, 'name': 'Bread'},
# {'id': 64, 'name': 'Camera'},
# {'id': 65, 'name': 'Canned'},
# {'id': 66, 'name': 'Truck'},
# {'id': 67, 'name': 'Traffic cone'},
# {'id': 68, 'name': 'Cymbal'},
# {'id': 69, 'name': 'Lifesaver'},
# {'id': 70, 'name': 'Towel'},
# {'id': 71, 'name': 'Stuffed Toy'},
# {'id': 72, 'name': 'Candle'},
# {'id': 73, 'name': 'Sailboat'},
# {'id': 74, 'name': 'Laptop'},
# {'id': 75, 'name': 'Awning'},
# {'id': 76, 'name': 'Bed'},
# {'id': 77, 'name': 'Faucet'},
# {'id': 78, 'name': 'Tent'},
# {'id': 79, 'name': 'Horse'},
# {'id': 80, 'name': 'Mirror'},
# {'id': 81, 'name': 'Power outlet'},
# {'id': 82, 'name': 'Sink'},
# {'id': 83, 'name': 'Apple'},
# {'id': 84, 'name': 'Air Conditioner'},
# {'id': 85, 'name': 'Knife'},
# {'id': 86, 'name': 'Hockey Stick'},
# {'id': 87, 'name': 'Paddle'},
# {'id': 88, 'name': 'Pickup Truck'},
# {'id': 89, 'name': 'Fork'},
# {'id': 90, 'name': 'Traffic Sign'},
# {'id': 91, 'name': 'Ballon'},
# {'id': 92, 'name': 'Tripod'},
# {'id': 93, 'name': 'Dog'},
# {'id': 94, 'name': 'Spoon'},
# {'id': 95, 'name': 'Clock'},
# {'id': 96, 'name': 'Pot'},
# {'id': 97, 'name': 'Cow'},
# {'id': 98, 'name': 'Cake'},
# {'id': 99, 'name': 'Dinning Table'},
# {'id': 100, 'name': 'Sheep'},
# {'id': 101, 'name': 'Hanger'},
# {'id': 102, 'name': 'Blackboard/Whiteboard'},
|
| 108 |
+
# {'id': 103, 'name': 'Napkin'},
|
| 109 |
+
# {'id': 104, 'name': 'Other Fish'},
|
| 110 |
+
# {'id': 105, 'name': 'Orange/Tangerine'},
|
| 111 |
+
# {'id': 106, 'name': 'Toiletry'},
|
| 112 |
+
# {'id': 107, 'name': 'Keyboard'},
|
| 113 |
+
# {'id': 108, 'name': 'Tomato'},
|
| 114 |
+
# {'id': 109, 'name': 'Lantern'},
|
| 115 |
+
# {'id': 110, 'name': 'Machinery Vehicle'},
|
| 116 |
+
# {'id': 111, 'name': 'Fan'},
|
| 117 |
+
# {'id': 112, 'name': 'Green Vegetables'},
|
| 118 |
+
# {'id': 113, 'name': 'Banana'},
|
| 119 |
+
# {'id': 114, 'name': 'Baseball Glove'},
|
| 120 |
+
# {'id': 115, 'name': 'Airplane'},
|
| 121 |
+
# {'id': 116, 'name': 'Mouse'},
|
| 122 |
+
# {'id': 117, 'name': 'Train'},
|
| 123 |
+
# {'id': 118, 'name': 'Pumpkin'},
|
| 124 |
+
# {'id': 119, 'name': 'Soccer'},
|
| 125 |
+
# {'id': 120, 'name': 'Skiboard'},
|
| 126 |
+
# {'id': 121, 'name': 'Luggage'},
|
| 127 |
+
# {'id': 122, 'name': 'Nightstand'},
|
| 128 |
+
# {'id': 123, 'name': 'Tea pot'},
|
| 129 |
+
# {'id': 124, 'name': 'Telephone'},
|
| 130 |
+
# {'id': 125, 'name': 'Trolley'},
|
| 131 |
+
# {'id': 126, 'name': 'Head Phone'},
|
| 132 |
+
# {'id': 127, 'name': 'Sports Car'},
|
| 133 |
+
# {'id': 128, 'name': 'Stop Sign'},
|
| 134 |
+
# {'id': 129, 'name': 'Dessert'},
|
| 135 |
+
# {'id': 130, 'name': 'Scooter'},
|
| 136 |
+
# {'id': 131, 'name': 'Stroller'},
|
| 137 |
+
# {'id': 132, 'name': 'Crane'},
|
| 138 |
+
# {'id': 133, 'name': 'Remote'},
|
| 139 |
+
# {'id': 134, 'name': 'Refrigerator'},
|
| 140 |
+
# {'id': 135, 'name': 'Oven'},
|
| 141 |
+
# {'id': 136, 'name': 'Lemon'},
|
| 142 |
+
# {'id': 137, 'name': 'Duck'},
|
| 143 |
+
# {'id': 138, 'name': 'Baseball Bat'},
|
| 144 |
+
# {'id': 139, 'name': 'Surveillance Camera'},
|
| 145 |
+
# {'id': 140, 'name': 'Cat'},
|
| 146 |
+
# {'id': 141, 'name': 'Jug'},
|
| 147 |
+
# {'id': 142, 'name': 'Broccoli'},
|
| 148 |
+
# {'id': 143, 'name': 'Piano'},
|
| 149 |
+
# {'id': 144, 'name': 'Pizza'},
|
| 150 |
+
# {'id': 145, 'name': 'Elephant'},
|
| 151 |
+
# {'id': 146, 'name': 'Skateboard'},
|
| 152 |
+
# {'id': 147, 'name': 'Surfboard'},
|
| 153 |
+
# {'id': 148, 'name': 'Gun'},
|
| 154 |
+
# {'id': 149, 'name': 'Skating and Skiing shoes'},
|
| 155 |
+
# {'id': 150, 'name': 'Gas stove'},
|
| 156 |
+
# {'id': 151, 'name': 'Donut'},
|
| 157 |
+
# {'id': 152, 'name': 'Bow Tie'},
|
| 158 |
+
# {'id': 153, 'name': 'Carrot'},
|
| 159 |
+
# {'id': 154, 'name': 'Toilet'},
|
| 160 |
+
# {'id': 155, 'name': 'Kite'},
|
| 161 |
+
# {'id': 156, 'name': 'Strawberry'},
|
| 162 |
+
# {'id': 157, 'name': 'Other Balls'},
|
| 163 |
+
# {'id': 158, 'name': 'Shovel'},
|
| 164 |
+
# {'id': 159, 'name': 'Pepper'},
|
| 165 |
+
# {'id': 160, 'name': 'Computer Box'},
|
| 166 |
+
# {'id': 161, 'name': 'Toilet Paper'},
|
| 167 |
+
# {'id': 162, 'name': 'Cleaning Products'},
|
| 168 |
+
# {'id': 163, 'name': 'Chopsticks'},
|
| 169 |
+
# {'id': 164, 'name': 'Microwave'},
|
| 170 |
+
# {'id': 165, 'name': 'Pigeon'},
|
| 171 |
+
# {'id': 166, 'name': 'Baseball'},
|
| 172 |
+
# {'id': 167, 'name': 'Cutting/chopping Board'},
|
| 173 |
+
# {'id': 168, 'name': 'Coffee Table'},
|
| 174 |
+
# {'id': 169, 'name': 'Side Table'},
|
| 175 |
+
# {'id': 170, 'name': 'Scissors'},
|
| 176 |
+
# {'id': 171, 'name': 'Marker'},
|
| 177 |
+
# {'id': 172, 'name': 'Pie'},
|
| 178 |
+
# {'id': 173, 'name': 'Ladder'},
|
| 179 |
+
# {'id': 174, 'name': 'Snowboard'},
|
| 180 |
+
# {'id': 175, 'name': 'Cookies'},
|
| 181 |
+
# {'id': 176, 'name': 'Radiator'},
|
| 182 |
+
# {'id': 177, 'name': 'Fire Hydrant'},
|
| 183 |
+
# {'id': 178, 'name': 'Basketball'},
|
| 184 |
+
# {'id': 179, 'name': 'Zebra'},
|
| 185 |
+
# {'id': 180, 'name': 'Grape'},
|
| 186 |
+
# {'id': 181, 'name': 'Giraffe'},
|
| 187 |
+
# {'id': 182, 'name': 'Potato'},
|
| 188 |
+
# {'id': 183, 'name': 'Sausage'},
|
| 189 |
+
# {'id': 184, 'name': 'Tricycle'},
|
| 190 |
+
# {'id': 185, 'name': 'Violin'},
|
| 191 |
+
# {'id': 186, 'name': 'Egg'},
|
| 192 |
+
# {'id': 187, 'name': 'Fire Extinguisher'},
|
| 193 |
+
# {'id': 188, 'name': 'Candy'},
|
| 194 |
+
# {'id': 189, 'name': 'Fire Truck'},
|
| 195 |
+
# {'id': 190, 'name': 'Billards'},
|
| 196 |
+
# {'id': 191, 'name': 'Converter'},
|
| 197 |
+
# {'id': 192, 'name': 'Bathtub'},
|
| 198 |
+
# {'id': 193, 'name': 'Wheelchair'},
|
| 199 |
+
# {'id': 194, 'name': 'Golf Club'},
|
| 200 |
+
# {'id': 195, 'name': 'Briefcase'},
|
| 201 |
+
# {'id': 196, 'name': 'Cucumber'},
|
| 202 |
+
# {'id': 197, 'name': 'Cigar/Cigarette '},
|
| 203 |
+
# {'id': 198, 'name': 'Paint Brush'},
|
| 204 |
+
# {'id': 199, 'name': 'Pear'},
|
| 205 |
+
# {'id': 200, 'name': 'Heavy Truck'},
|
| 206 |
+
# {'id': 201, 'name': 'Hamburger'},
|
| 207 |
+
# {'id': 202, 'name': 'Extractor'},
|
| 208 |
+
# {'id': 203, 'name': 'Extention Cord'},
|
| 209 |
+
# {'id': 204, 'name': 'Tong'},
|
| 210 |
+
# {'id': 205, 'name': 'Tennis Racket'},
|
| 211 |
+
# {'id': 206, 'name': 'Folder'},
|
| 212 |
+
# {'id': 207, 'name': 'American Football'},
|
| 213 |
+
# {'id': 208, 'name': 'earphone'},
|
| 214 |
+
# {'id': 209, 'name': 'Mask'},
|
| 215 |
+
# {'id': 210, 'name': 'Kettle'},
|
| 216 |
+
# {'id': 211, 'name': 'Tennis'},
|
| 217 |
+
# {'id': 212, 'name': 'Ship'},
|
| 218 |
+
# {'id': 213, 'name': 'Swing'},
|
| 219 |
+
# {'id': 214, 'name': 'Coffee Machine'},
|
| 220 |
+
# {'id': 215, 'name': 'Slide'},
|
| 221 |
+
# {'id': 216, 'name': 'Carriage'},
|
| 222 |
+
# {'id': 217, 'name': 'Onion'},
|
| 223 |
+
# {'id': 218, 'name': 'Green beans'},
|
| 224 |
+
# {'id': 219, 'name': 'Projector'},
|
| 225 |
+
# {'id': 220, 'name': 'Frisbee'},
|
| 226 |
+
# {'id': 221, 'name': 'Washing Machine/Drying Machine'},
|
| 227 |
+
# {'id': 222, 'name': 'Chicken'},
|
| 228 |
+
# {'id': 223, 'name': 'Printer'},
|
| 229 |
+
# {'id': 224, 'name': 'Watermelon'},
|
| 230 |
+
# {'id': 225, 'name': 'Saxophone'},
|
| 231 |
+
# {'id': 226, 'name': 'Tissue'},
|
| 232 |
+
# {'id': 227, 'name': 'Toothbrush'},
|
| 233 |
+
# {'id': 228, 'name': 'Ice cream'},
|
| 234 |
+
# {'id': 229, 'name': 'Hotair ballon'},
|
| 235 |
+
# {'id': 230, 'name': 'Cello'},
|
| 236 |
+
# {'id': 231, 'name': 'French Fries'},
|
| 237 |
+
# {'id': 232, 'name': 'Scale'},
|
| 238 |
+
# {'id': 233, 'name': 'Trophy'},
|
| 239 |
+
# {'id': 234, 'name': 'Cabbage'},
|
| 240 |
+
# {'id': 235, 'name': 'Hot dog'},
|
| 241 |
+
# {'id': 236, 'name': 'Blender'},
|
| 242 |
+
# {'id': 237, 'name': 'Peach'},
|
| 243 |
+
# {'id': 238, 'name': 'Rice'},
|
| 244 |
+
# {'id': 239, 'name': 'Wallet/Purse'},
|
| 245 |
+
# {'id': 240, 'name': 'Volleyball'},
|
| 246 |
+
# {'id': 241, 'name': 'Deer'},
|
| 247 |
+
# {'id': 242, 'name': 'Goose'},
|
| 248 |
+
# {'id': 243, 'name': 'Tape'},
|
| 249 |
+
# {'id': 244, 'name': 'Tablet'},
|
| 250 |
+
# {'id': 245, 'name': 'Cosmetics'},
|
| 251 |
+
# {'id': 246, 'name': 'Trumpet'},
|
| 252 |
+
# {'id': 247, 'name': 'Pineapple'},
|
| 253 |
+
# {'id': 248, 'name': 'Golf Ball'},
|
| 254 |
+
# {'id': 249, 'name': 'Ambulance'},
|
| 255 |
+
# {'id': 250, 'name': 'Parking meter'},
|
| 256 |
+
# {'id': 251, 'name': 'Mango'},
|
| 257 |
+
# {'id': 252, 'name': 'Key'},
|
| 258 |
+
# {'id': 253, 'name': 'Hurdle'},
|
| 259 |
+
# {'id': 254, 'name': 'Fishing Rod'},
|
| 260 |
+
# {'id': 255, 'name': 'Medal'},
|
| 261 |
+
# {'id': 256, 'name': 'Flute'},
|
| 262 |
+
# {'id': 257, 'name': 'Brush'},
|
| 263 |
+
# {'id': 258, 'name': 'Penguin'},
|
| 264 |
+
# {'id': 259, 'name': 'Megaphone'},
|
| 265 |
+
# {'id': 260, 'name': 'Corn'},
|
| 266 |
+
# {'id': 261, 'name': 'Lettuce'},
|
| 267 |
+
# {'id': 262, 'name': 'Garlic'},
|
| 268 |
+
# {'id': 263, 'name': 'Swan'},
|
| 269 |
+
# {'id': 264, 'name': 'Helicopter'},
|
| 270 |
+
# {'id': 265, 'name': 'Green Onion'},
|
| 271 |
+
# {'id': 266, 'name': 'Sandwich'},
|
| 272 |
+
# {'id': 267, 'name': 'Nuts'},
|
| 273 |
+
# {'id': 268, 'name': 'Speed Limit Sign'},
|
| 274 |
+
# {'id': 269, 'name': 'Induction Cooker'},
|
| 275 |
+
# {'id': 270, 'name': 'Broom'},
|
| 276 |
+
# {'id': 271, 'name': 'Trombone'},
|
| 277 |
+
# {'id': 272, 'name': 'Plum'},
|
| 278 |
+
# {'id': 273, 'name': 'Rickshaw'},
|
| 279 |
+
# {'id': 274, 'name': 'Goldfish'},
|
| 280 |
+
# {'id': 275, 'name': 'Kiwi fruit'},
|
| 281 |
+
# {'id': 276, 'name': 'Router/modem'},
|
| 282 |
+
# {'id': 277, 'name': 'Poker Card'},
|
| 283 |
+
# {'id': 278, 'name': 'Toaster'},
|
| 284 |
+
# {'id': 279, 'name': 'Shrimp'},
|
| 285 |
+
# {'id': 280, 'name': 'Sushi'},
|
| 286 |
+
# {'id': 281, 'name': 'Cheese'},
|
| 287 |
+
# {'id': 282, 'name': 'Notepaper'},
|
| 288 |
+
# {'id': 283, 'name': 'Cherry'},
|
| 289 |
+
# {'id': 284, 'name': 'Pliers'},
|
| 290 |
+
# {'id': 285, 'name': 'CD'},
|
| 291 |
+
# {'id': 286, 'name': 'Pasta'},
|
| 292 |
+
# {'id': 287, 'name': 'Hammer'},
|
| 293 |
+
# {'id': 288, 'name': 'Cue'},
|
| 294 |
+
# {'id': 289, 'name': 'Avocado'},
|
| 295 |
+
# {'id': 290, 'name': 'Hamimelon'},
|
| 296 |
+
# {'id': 291, 'name': 'Flask'},
|
| 297 |
+
# {'id': 292, 'name': 'Mushroon'},
|
| 298 |
+
# {'id': 293, 'name': 'Screwdriver'},
|
| 299 |
+
# {'id': 294, 'name': 'Soap'},
|
| 300 |
+
# {'id': 295, 'name': 'Recorder'},
|
| 301 |
+
# {'id': 296, 'name': 'Bear'},
|
| 302 |
+
# {'id': 297, 'name': 'Eggplant'},
|
| 303 |
+
# {'id': 298, 'name': 'Board Eraser'},
|
| 304 |
+
# {'id': 299, 'name': 'Coconut'},
|
| 305 |
+
# {'id': 300, 'name': 'Tape Measur/ Ruler'},
|
| 306 |
+
# {'id': 301, 'name': 'Pig'},
|
| 307 |
+
# {'id': 302, 'name': 'Showerhead'},
|
| 308 |
+
# {'id': 303, 'name': 'Globe'},
|
| 309 |
+
# {'id': 304, 'name': 'Chips'},
|
| 310 |
+
# {'id': 305, 'name': 'Steak'},
|
| 311 |
+
# {'id': 306, 'name': 'Crosswalk Sign'},
|
| 312 |
+
# {'id': 307, 'name': 'Stapler'},
|
| 313 |
+
# {'id': 308, 'name': 'Campel'},
|
| 314 |
+
# {'id': 309, 'name': 'Formula 1 '},
|
| 315 |
+
# {'id': 310, 'name': 'Pomegranate'},
|
| 316 |
+
# {'id': 311, 'name': 'Dishwasher'},
|
| 317 |
+
# {'id': 312, 'name': 'Crab'},
|
| 318 |
+
# {'id': 313, 'name': 'Hoverboard'},
|
| 319 |
+
# {'id': 314, 'name': 'Meat ball'},
|
| 320 |
+
# {'id': 315, 'name': 'Rice Cooker'},
|
| 321 |
+
# {'id': 316, 'name': 'Tuba'},
|
| 322 |
+
# {'id': 317, 'name': 'Calculator'},
|
| 323 |
+
# {'id': 318, 'name': 'Papaya'},
|
| 324 |
+
# {'id': 319, 'name': 'Antelope'},
|
| 325 |
+
# {'id': 320, 'name': 'Parrot'},
|
| 326 |
+
# {'id': 321, 'name': 'Seal'},
|
| 327 |
+
# {'id': 322, 'name': 'Buttefly'},
|
| 328 |
+
# {'id': 323, 'name': 'Dumbbell'},
|
| 329 |
+
# {'id': 324, 'name': 'Donkey'},
|
| 330 |
+
# {'id': 325, 'name': 'Lion'},
|
| 331 |
+
# {'id': 326, 'name': 'Urinal'},
|
| 332 |
+
# {'id': 327, 'name': 'Dolphin'},
|
| 333 |
+
# {'id': 328, 'name': 'Electric Drill'},
|
| 334 |
+
# {'id': 329, 'name': 'Hair Dryer'},
|
| 335 |
+
# {'id': 330, 'name': 'Egg tart'},
|
| 336 |
+
# {'id': 331, 'name': 'Jellyfish'},
|
| 337 |
+
# {'id': 332, 'name': 'Treadmill'},
|
| 338 |
+
# {'id': 333, 'name': 'Lighter'},
|
| 339 |
+
# {'id': 334, 'name': 'Grapefruit'},
|
| 340 |
+
# {'id': 335, 'name': 'Game board'},
|
| 341 |
+
# {'id': 336, 'name': 'Mop'},
|
| 342 |
+
# {'id': 337, 'name': 'Radish'},
|
| 343 |
+
# {'id': 338, 'name': 'Baozi'},
|
| 344 |
+
# {'id': 339, 'name': 'Target'},
|
| 345 |
+
# {'id': 340, 'name': 'French'},
|
| 346 |
+
# {'id': 341, 'name': 'Spring Rolls'},
|
| 347 |
+
# {'id': 342, 'name': 'Monkey'},
|
| 348 |
+
# {'id': 343, 'name': 'Rabbit'},
|
| 349 |
+
# {'id': 344, 'name': 'Pencil Case'},
|
| 350 |
+
# {'id': 345, 'name': 'Yak'},
|
| 351 |
+
# {'id': 346, 'name': 'Red Cabbage'},
|
| 352 |
+
# {'id': 347, 'name': 'Binoculars'},
|
| 353 |
+
# {'id': 348, 'name': 'Asparagus'},
|
| 354 |
+
# {'id': 349, 'name': 'Barbell'},
|
| 355 |
+
# {'id': 350, 'name': 'Scallop'},
|
| 356 |
+
# {'id': 351, 'name': 'Noddles'},
|
| 357 |
+
# {'id': 352, 'name': 'Comb'},
|
| 358 |
+
# {'id': 353, 'name': 'Dumpling'},
|
| 359 |
+
# {'id': 354, 'name': 'Oyster'},
|
| 360 |
+
# {'id': 355, 'name': 'Table Teniis paddle'},
|
| 361 |
+
# {'id': 356, 'name': 'Cosmetics Brush/Eyeliner Pencil'},
|
| 362 |
+
# {'id': 357, 'name': 'Chainsaw'},
|
| 363 |
+
# {'id': 358, 'name': 'Eraser'},
|
| 364 |
+
# {'id': 359, 'name': 'Lobster'},
|
| 365 |
+
# {'id': 360, 'name': 'Durian'},
|
| 366 |
+
# {'id': 361, 'name': 'Okra'},
|
| 367 |
+
# {'id': 362, 'name': 'Lipstick'},
|
| 368 |
+
# {'id': 363, 'name': 'Cosmetics Mirror'},
|
| 369 |
+
# {'id': 364, 'name': 'Curling'},
|
| 370 |
+
# {'id': 365, 'name': 'Table Tennis '},
|
| 371 |
+
# ]
|
| 372 |
+
|
'''
The official Objects365 category names contain typos.
Below is a manual fix.
'''
| 377 |
+
categories_v2_fix = [
|
| 378 |
+
{'id': 1, 'name': 'Person'},
|
| 379 |
+
{'id': 2, 'name': 'Sneakers'},
|
| 380 |
+
{'id': 3, 'name': 'Chair'},
|
| 381 |
+
{'id': 4, 'name': 'Other Shoes'},
|
| 382 |
+
{'id': 5, 'name': 'Hat'},
|
| 383 |
+
{'id': 6, 'name': 'Car'},
|
| 384 |
+
{'id': 7, 'name': 'Lamp'},
|
| 385 |
+
{'id': 8, 'name': 'Glasses'},
|
| 386 |
+
{'id': 9, 'name': 'Bottle'},
|
| 387 |
+
{'id': 10, 'name': 'Desk'},
|
| 388 |
+
{'id': 11, 'name': 'Cup'},
|
| 389 |
+
{'id': 12, 'name': 'Street Lights'},
|
| 390 |
+
{'id': 13, 'name': 'Cabinet/shelf'},
|
| 391 |
+
{'id': 14, 'name': 'Handbag/Satchel'},
|
| 392 |
+
{'id': 15, 'name': 'Bracelet'},
|
| 393 |
+
{'id': 16, 'name': 'Plate'},
|
| 394 |
+
{'id': 17, 'name': 'Picture/Frame'},
|
| 395 |
+
{'id': 18, 'name': 'Helmet'},
|
| 396 |
+
{'id': 19, 'name': 'Book'},
|
| 397 |
+
{'id': 20, 'name': 'Gloves'},
|
| 398 |
+
{'id': 21, 'name': 'Storage box'},
|
| 399 |
+
{'id': 22, 'name': 'Boat'},
|
| 400 |
+
{'id': 23, 'name': 'Leather Shoes'},
|
| 401 |
+
{'id': 24, 'name': 'Flower'},
|
| 402 |
+
{'id': 25, 'name': 'Bench'},
|
| 403 |
+
{'id': 26, 'name': 'Potted Plant'},
|
| 404 |
+
{'id': 27, 'name': 'Bowl/Basin'},
|
| 405 |
+
{'id': 28, 'name': 'Flag'},
|
| 406 |
+
{'id': 29, 'name': 'Pillow'},
|
| 407 |
+
{'id': 30, 'name': 'Boots'},
|
| 408 |
+
{'id': 31, 'name': 'Vase'},
|
| 409 |
+
{'id': 32, 'name': 'Microphone'},
|
| 410 |
+
{'id': 33, 'name': 'Necklace'},
|
| 411 |
+
{'id': 34, 'name': 'Ring'},
|
| 412 |
+
{'id': 35, 'name': 'SUV'},
|
| 413 |
+
{'id': 36, 'name': 'Wine Glass'},
|
| 414 |
+
{'id': 37, 'name': 'Belt'},
|
| 415 |
+
{'id': 38, 'name': 'Monitor/TV'},
|
| 416 |
+
{'id': 39, 'name': 'Backpack'},
|
| 417 |
+
{'id': 40, 'name': 'Umbrella'},
|
| 418 |
+
{'id': 41, 'name': 'Traffic Light'},
|
| 419 |
+
{'id': 42, 'name': 'Speaker'},
|
| 420 |
+
{'id': 43, 'name': 'Watch'},
|
| 421 |
+
{'id': 44, 'name': 'Tie'},
|
| 422 |
+
{'id': 45, 'name': 'Trash bin Can'},
|
| 423 |
+
{'id': 46, 'name': 'Slippers'},
|
| 424 |
+
{'id': 47, 'name': 'Bicycle'},
|
| 425 |
+
{'id': 48, 'name': 'Stool'},
|
| 426 |
+
{'id': 49, 'name': 'Barrel/bucket'},
|
| 427 |
+
{'id': 50, 'name': 'Van'},
|
| 428 |
+
{'id': 51, 'name': 'Couch'},
|
| 429 |
+
{'id': 52, 'name': 'Sandals'},
|
| 430 |
+
{'id': 53, 'name': 'Basket'},
|
| 431 |
+
{'id': 54, 'name': 'Drum'},
|
| 432 |
+
{'id': 55, 'name': 'Pen/Pencil'},
|
| 433 |
+
{'id': 56, 'name': 'Bus'},
|
| 434 |
+
{'id': 57, 'name': 'Wild Bird'},
|
| 435 |
+
{'id': 58, 'name': 'High Heels'},
|
| 436 |
+
{'id': 59, 'name': 'Motorcycle'},
|
| 437 |
+
{'id': 60, 'name': 'Guitar'},
|
| 438 |
+
{'id': 61, 'name': 'Carpet'},
|
| 439 |
+
{'id': 62, 'name': 'Cell Phone'},
|
| 440 |
+
{'id': 63, 'name': 'Bread'},
|
| 441 |
+
{'id': 64, 'name': 'Camera'},
|
| 442 |
+
{'id': 65, 'name': 'Canned'},
|
| 443 |
+
{'id': 66, 'name': 'Truck'},
|
| 444 |
+
{'id': 67, 'name': 'Traffic cone'},
|
| 445 |
+
{'id': 68, 'name': 'Cymbal'},
|
| 446 |
+
{'id': 69, 'name': 'Lifesaver'},
|
| 447 |
+
{'id': 70, 'name': 'Towel'},
|
| 448 |
+
{'id': 71, 'name': 'Stuffed Toy'},
|
| 449 |
+
{'id': 72, 'name': 'Candle'},
|
| 450 |
+
{'id': 73, 'name': 'Sailboat'},
|
| 451 |
+
{'id': 74, 'name': 'Laptop'},
|
| 452 |
+
{'id': 75, 'name': 'Awning'},
|
| 453 |
+
{'id': 76, 'name': 'Bed'},
|
| 454 |
+
{'id': 77, 'name': 'Faucet'},
|
| 455 |
+
{'id': 78, 'name': 'Tent'},
|
| 456 |
+
{'id': 79, 'name': 'Horse'},
|
| 457 |
+
{'id': 80, 'name': 'Mirror'},
|
| 458 |
+
{'id': 81, 'name': 'Power outlet'},
|
| 459 |
+
{'id': 82, 'name': 'Sink'},
|
| 460 |
+
{'id': 83, 'name': 'Apple'},
|
| 461 |
+
{'id': 84, 'name': 'Air Conditioner'},
|
| 462 |
+
{'id': 85, 'name': 'Knife'},
|
| 463 |
+
{'id': 86, 'name': 'Hockey Stick'},
|
| 464 |
+
{'id': 87, 'name': 'Paddle'},
|
| 465 |
+
{'id': 88, 'name': 'Pickup Truck'},
|
| 466 |
+
{'id': 89, 'name': 'Fork'},
|
| 467 |
+
{'id': 90, 'name': 'Traffic Sign'},
|
| 468 |
+
{'id': 91, 'name': 'Ballon'},
|
| 469 |
+
{'id': 92, 'name': 'Tripod'},
|
| 470 |
+
{'id': 93, 'name': 'Dog'},
|
| 471 |
+
{'id': 94, 'name': 'Spoon'},
|
| 472 |
+
{'id': 95, 'name': 'Clock'},
|
| 473 |
+
{'id': 96, 'name': 'Pot'},
|
| 474 |
+
{'id': 97, 'name': 'Cow'},
|
| 475 |
+
{'id': 98, 'name': 'Cake'},
|
| 476 |
+
{'id': 99, 'name': 'Dining Table'},
|
| 477 |
+
{'id': 100, 'name': 'Sheep'},
|
| 478 |
+
{'id': 101, 'name': 'Hanger'},
|
| 479 |
+
{'id': 102, 'name': 'Blackboard/Whiteboard'},
|
| 480 |
+
{'id': 103, 'name': 'Napkin'},
|
| 481 |
+
{'id': 104, 'name': 'Other Fish'},
|
| 482 |
+
{'id': 105, 'name': 'Orange/Tangerine'},
|
| 483 |
+
{'id': 106, 'name': 'Toiletry'},
|
| 484 |
+
{'id': 107, 'name': 'Keyboard'},
|
| 485 |
+
{'id': 108, 'name': 'Tomato'},
|
| 486 |
+
{'id': 109, 'name': 'Lantern'},
|
| 487 |
+
{'id': 110, 'name': 'Machinery Vehicle'},
|
| 488 |
+
{'id': 111, 'name': 'Fan'},
|
| 489 |
+
{'id': 112, 'name': 'Green Vegetables'},
|
| 490 |
+
{'id': 113, 'name': 'Banana'},
|
| 491 |
+
{'id': 114, 'name': 'Baseball Glove'},
|
| 492 |
+
{'id': 115, 'name': 'Airplane'},
|
| 493 |
+
{'id': 116, 'name': 'Mouse'},
|
| 494 |
+
{'id': 117, 'name': 'Train'},
|
| 495 |
+
{'id': 118, 'name': 'Pumpkin'},
|
| 496 |
+
{'id': 119, 'name': 'Soccer'},
|
| 497 |
+
{'id': 120, 'name': 'Skiboard'},
|
| 498 |
+
{'id': 121, 'name': 'Luggage'},
|
| 499 |
+
{'id': 122, 'name': 'Nightstand'},
|
| 500 |
+
{'id': 123, 'name': 'Teapot'},
|
| 501 |
+
{'id': 124, 'name': 'Telephone'},
|
| 502 |
+
{'id': 125, 'name': 'Trolley'},
|
| 503 |
+
{'id': 126, 'name': 'Head Phone'},
|
| 504 |
+
{'id': 127, 'name': 'Sports Car'},
|
| 505 |
+
{'id': 128, 'name': 'Stop Sign'},
|
| 506 |
+
{'id': 129, 'name': 'Dessert'},
|
| 507 |
+
{'id': 130, 'name': 'Scooter'},
|
| 508 |
+
{'id': 131, 'name': 'Stroller'},
|
| 509 |
+
{'id': 132, 'name': 'Crane'},
|
| 510 |
+
{'id': 133, 'name': 'Remote'},
|
| 511 |
+
{'id': 134, 'name': 'Refrigerator'},
|
| 512 |
+
{'id': 135, 'name': 'Oven'},
|
| 513 |
+
{'id': 136, 'name': 'Lemon'},
|
| 514 |
+
{'id': 137, 'name': 'Duck'},
|
| 515 |
+
{'id': 138, 'name': 'Baseball Bat'},
|
| 516 |
+
{'id': 139, 'name': 'Surveillance Camera'},
|
| 517 |
+
{'id': 140, 'name': 'Cat'},
|
| 518 |
+
{'id': 141, 'name': 'Jug'},
|
| 519 |
+
{'id': 142, 'name': 'Broccoli'},
|
| 520 |
+
{'id': 143, 'name': 'Piano'},
|
| 521 |
+
{'id': 144, 'name': 'Pizza'},
|
| 522 |
+
{'id': 145, 'name': 'Elephant'},
|
| 523 |
+
{'id': 146, 'name': 'Skateboard'},
|
| 524 |
+
{'id': 147, 'name': 'Surfboard'},
|
| 525 |
+
{'id': 148, 'name': 'Gun'},
|
| 526 |
+
{'id': 149, 'name': 'Skating and Skiing shoes'},
|
| 527 |
+
{'id': 150, 'name': 'Gas stove'},
|
| 528 |
+
{'id': 151, 'name': 'Donut'},
|
| 529 |
+
{'id': 152, 'name': 'Bow Tie'},
|
| 530 |
+
{'id': 153, 'name': 'Carrot'},
|
| 531 |
+
{'id': 154, 'name': 'Toilet'},
|
| 532 |
+
{'id': 155, 'name': 'Kite'},
|
| 533 |
+
{'id': 156, 'name': 'Strawberry'},
|
| 534 |
+
{'id': 157, 'name': 'Other Balls'},
|
| 535 |
+
{'id': 158, 'name': 'Shovel'},
|
| 536 |
+
{'id': 159, 'name': 'Pepper'},
|
| 537 |
+
{'id': 160, 'name': 'Computer Box'},
|
| 538 |
+
{'id': 161, 'name': 'Toilet Paper'},
|
| 539 |
+
{'id': 162, 'name': 'Cleaning Products'},
|
| 540 |
+
{'id': 163, 'name': 'Chopsticks'},
|
| 541 |
+
{'id': 164, 'name': 'Microwave'},
|
| 542 |
+
{'id': 165, 'name': 'Pigeon'},
|
| 543 |
+
{'id': 166, 'name': 'Baseball'},
|
| 544 |
+
{'id': 167, 'name': 'Cutting/chopping Board'},
|
| 545 |
+
{'id': 168, 'name': 'Coffee Table'},
|
| 546 |
+
{'id': 169, 'name': 'Side Table'},
|
| 547 |
+
{'id': 170, 'name': 'Scissors'},
|
| 548 |
+
{'id': 171, 'name': 'Marker'},
|
| 549 |
+
{'id': 172, 'name': 'Pie'},
|
| 550 |
+
{'id': 173, 'name': 'Ladder'},
|
| 551 |
+
{'id': 174, 'name': 'Snowboard'},
|
| 552 |
+
{'id': 175, 'name': 'Cookies'},
|
| 553 |
+
{'id': 176, 'name': 'Radiator'},
|
| 554 |
+
{'id': 177, 'name': 'Fire Hydrant'},
|
| 555 |
+
{'id': 178, 'name': 'Basketball'},
|
| 556 |
+
{'id': 179, 'name': 'Zebra'},
|
| 557 |
+
{'id': 180, 'name': 'Grape'},
|
| 558 |
+
{'id': 181, 'name': 'Giraffe'},
|
| 559 |
+
{'id': 182, 'name': 'Potato'},
|
| 560 |
+
{'id': 183, 'name': 'Sausage'},
|
| 561 |
+
{'id': 184, 'name': 'Tricycle'},
|
| 562 |
+
{'id': 185, 'name': 'Violin'},
|
| 563 |
+
{'id': 186, 'name': 'Egg'},
|
| 564 |
+
{'id': 187, 'name': 'Fire Extinguisher'},
|
| 565 |
+
{'id': 188, 'name': 'Candy'},
|
| 566 |
+
{'id': 189, 'name': 'Fire Truck'},
|
| 567 |
+
{'id': 190, 'name': 'Billards'},
|
| 568 |
+
{'id': 191, 'name': 'Converter'},
|
| 569 |
+
{'id': 192, 'name': 'Bathtub'},
|
| 570 |
+
{'id': 193, 'name': 'Wheelchair'},
|
| 571 |
+
{'id': 194, 'name': 'Golf Club'},
|
| 572 |
+
{'id': 195, 'name': 'Briefcase'},
|
| 573 |
+
{'id': 196, 'name': 'Cucumber'},
|
| 574 |
+
{'id': 197, 'name': 'Cigar/Cigarette '},
|
| 575 |
+
{'id': 198, 'name': 'Paint Brush'},
|
| 576 |
+
{'id': 199, 'name': 'Pear'},
|
| 577 |
+
{'id': 200, 'name': 'Heavy Truck'},
|
| 578 |
+
{'id': 201, 'name': 'Hamburger'},
|
| 579 |
+
{'id': 202, 'name': 'Extractor'},
|
| 580 |
+
{'id': 203, 'name': 'Extension Cord'},
|
| 581 |
+
{'id': 204, 'name': 'Tong'},
|
| 582 |
+
{'id': 205, 'name': 'Tennis Racket'},
|
| 583 |
+
{'id': 206, 'name': 'Folder'},
|
| 584 |
+
{'id': 207, 'name': 'American Football'},
|
| 585 |
+
{'id': 208, 'name': 'earphone'},
|
| 586 |
+
{'id': 209, 'name': 'Mask'},
|
| 587 |
+
{'id': 210, 'name': 'Kettle'},
|
| 588 |
+
{'id': 211, 'name': 'Tennis'},
|
| 589 |
+
{'id': 212, 'name': 'Ship'},
|
| 590 |
+
{'id': 213, 'name': 'Swing'},
|
| 591 |
+
{'id': 214, 'name': 'Coffee Machine'},
|
| 592 |
+
{'id': 215, 'name': 'Slide'},
|
| 593 |
+
{'id': 216, 'name': 'Carriage'},
|
| 594 |
+
{'id': 217, 'name': 'Onion'},
|
| 595 |
+
{'id': 218, 'name': 'Green beans'},
|
| 596 |
+
{'id': 219, 'name': 'Projector'},
|
| 597 |
+
{'id': 220, 'name': 'Frisbee'},
|
| 598 |
+
{'id': 221, 'name': 'Washing Machine/Drying Machine'},
|
| 599 |
+
{'id': 222, 'name': 'Chicken'},
|
| 600 |
+
{'id': 223, 'name': 'Printer'},
|
| 601 |
+
{'id': 224, 'name': 'Watermelon'},
|
| 602 |
+
{'id': 225, 'name': 'Saxophone'},
|
| 603 |
+
{'id': 226, 'name': 'Tissue'},
|
| 604 |
+
{'id': 227, 'name': 'Toothbrush'},
|
| 605 |
+
{'id': 228, 'name': 'Ice cream'},
|
| 606 |
+
{'id': 229, 'name': 'Hot air balloon'},
|
| 607 |
+
{'id': 230, 'name': 'Cello'},
|
| 608 |
+
{'id': 231, 'name': 'French Fries'},
|
| 609 |
+
{'id': 232, 'name': 'Scale'},
|
| 610 |
+
{'id': 233, 'name': 'Trophy'},
|
| 611 |
+
{'id': 234, 'name': 'Cabbage'},
|
| 612 |
+
{'id': 235, 'name': 'Hot dog'},
|
| 613 |
+
{'id': 236, 'name': 'Blender'},
|
| 614 |
+
{'id': 237, 'name': 'Peach'},
|
| 615 |
+
{'id': 238, 'name': 'Rice'},
|
| 616 |
+
{'id': 239, 'name': 'Wallet/Purse'},
|
| 617 |
+
{'id': 240, 'name': 'Volleyball'},
|
| 618 |
+
{'id': 241, 'name': 'Deer'},
|
| 619 |
+
{'id': 242, 'name': 'Goose'},
|
| 620 |
+
{'id': 243, 'name': 'Tape'},
|
| 621 |
+
{'id': 244, 'name': 'Tablet'},
|
| 622 |
+
{'id': 245, 'name': 'Cosmetics'},
|
| 623 |
+
{'id': 246, 'name': 'Trumpet'},
|
| 624 |
+
{'id': 247, 'name': 'Pineapple'},
|
| 625 |
+
{'id': 248, 'name': 'Golf Ball'},
|
| 626 |
+
{'id': 249, 'name': 'Ambulance'},
|
| 627 |
+
{'id': 250, 'name': 'Parking meter'},
|
| 628 |
+
{'id': 251, 'name': 'Mango'},
|
| 629 |
+
{'id': 252, 'name': 'Key'},
|
| 630 |
+
{'id': 253, 'name': 'Hurdle'},
|
| 631 |
+
{'id': 254, 'name': 'Fishing Rod'},
|
| 632 |
+
{'id': 255, 'name': 'Medal'},
|
| 633 |
+
{'id': 256, 'name': 'Flute'},
|
| 634 |
+
{'id': 257, 'name': 'Brush'},
|
| 635 |
+
{'id': 258, 'name': 'Penguin'},
|
| 636 |
+
{'id': 259, 'name': 'Megaphone'},
|
| 637 |
+
{'id': 260, 'name': 'Corn'},
|
| 638 |
+
{'id': 261, 'name': 'Lettuce'},
|
| 639 |
+
{'id': 262, 'name': 'Garlic'},
|
| 640 |
+
{'id': 263, 'name': 'Swan'},
|
| 641 |
+
{'id': 264, 'name': 'Helicopter'},
|
| 642 |
+
{'id': 265, 'name': 'Green Onion'},
|
| 643 |
+
{'id': 266, 'name': 'Sandwich'},
|
| 644 |
+
{'id': 267, 'name': 'Nuts'},
|
| 645 |
+
{'id': 268, 'name': 'Speed Limit Sign'},
|
| 646 |
+
{'id': 269, 'name': 'Induction Cooker'},
|
| 647 |
+
{'id': 270, 'name': 'Broom'},
|
| 648 |
+
{'id': 271, 'name': 'Trombone'},
|
| 649 |
+
{'id': 272, 'name': 'Plum'},
|
| 650 |
+
{'id': 273, 'name': 'Rickshaw'},
|
| 651 |
+
{'id': 274, 'name': 'Goldfish'},
|
| 652 |
+
{'id': 275, 'name': 'Kiwi fruit'},
|
| 653 |
+
{'id': 276, 'name': 'Router/modem'},
|
| 654 |
+
{'id': 277, 'name': 'Poker Card'},
|
| 655 |
+
{'id': 278, 'name': 'Toaster'},
|
| 656 |
+
{'id': 279, 'name': 'Shrimp'},
|
| 657 |
+
{'id': 280, 'name': 'Sushi'},
|
| 658 |
+
{'id': 281, 'name': 'Cheese'},
|
| 659 |
+
{'id': 282, 'name': 'Notepaper'},
|
| 660 |
+
{'id': 283, 'name': 'Cherry'},
|
| 661 |
+
{'id': 284, 'name': 'Pliers'},
|
| 662 |
+
{'id': 285, 'name': 'CD'},
|
| 663 |
+
{'id': 286, 'name': 'Pasta'},
|
| 664 |
+
{'id': 287, 'name': 'Hammer'},
|
| 665 |
+
{'id': 288, 'name': 'Cue'},
|
| 666 |
+
{'id': 289, 'name': 'Avocado'},
|
| 667 |
+
{'id': 290, 'name': 'Hami melon'},
|
| 668 |
+
{'id': 291, 'name': 'Flask'},
|
| 669 |
+
{'id': 292, 'name': 'Mushroom'},
|
| 670 |
+
{'id': 293, 'name': 'Screwdriver'},
|
| 671 |
+
{'id': 294, 'name': 'Soap'},
|
| 672 |
+
{'id': 295, 'name': 'Recorder'},
|
| 673 |
+
{'id': 296, 'name': 'Bear'},
|
| 674 |
+
{'id': 297, 'name': 'Eggplant'},
|
| 675 |
+
{'id': 298, 'name': 'Board Eraser'},
|
| 676 |
+
{'id': 299, 'name': 'Coconut'},
|
| 677 |
+
{'id': 300, 'name': 'Tape Measure/ Ruler'},
|
| 678 |
+
{'id': 301, 'name': 'Pig'},
|
| 679 |
+
{'id': 302, 'name': 'Showerhead'},
|
| 680 |
+
{'id': 303, 'name': 'Globe'},
|
| 681 |
+
{'id': 304, 'name': 'Chips'},
|
| 682 |
+
{'id': 305, 'name': 'Steak'},
|
| 683 |
+
{'id': 306, 'name': 'Crosswalk Sign'},
|
| 684 |
+
{'id': 307, 'name': 'Stapler'},
|
| 685 |
+
{'id': 308, 'name': 'Camel'},
|
| 686 |
+
{'id': 309, 'name': 'Formula 1 '},
|
| 687 |
+
{'id': 310, 'name': 'Pomegranate'},
|
| 688 |
+
{'id': 311, 'name': 'Dishwasher'},
|
| 689 |
+
{'id': 312, 'name': 'Crab'},
|
| 690 |
+
{'id': 313, 'name': 'Hoverboard'},
|
| 691 |
+
{'id': 314, 'name': 'Meatball'},
|
| 692 |
+
{'id': 315, 'name': 'Rice Cooker'},
|
| 693 |
+
{'id': 316, 'name': 'Tuba'},
|
| 694 |
+
{'id': 317, 'name': 'Calculator'},
|
| 695 |
+
{'id': 318, 'name': 'Papaya'},
|
| 696 |
+
{'id': 319, 'name': 'Antelope'},
|
| 697 |
+
{'id': 320, 'name': 'Parrot'},
|
| 698 |
+
{'id': 321, 'name': 'Seal'},
|
| 699 |
+
{'id': 322, 'name': 'Butterfly'},
|
| 700 |
+
{'id': 323, 'name': 'Dumbbell'},
|
| 701 |
+
{'id': 324, 'name': 'Donkey'},
|
| 702 |
+
{'id': 325, 'name': 'Lion'},
|
| 703 |
+
{'id': 326, 'name': 'Urinal'},
|
| 704 |
+
{'id': 327, 'name': 'Dolphin'},
|
| 705 |
+
{'id': 328, 'name': 'Electric Drill'},
|
| 706 |
+
{'id': 329, 'name': 'Hair Dryer'},
|
| 707 |
+
{'id': 330, 'name': 'Egg tart'},
|
| 708 |
+
{'id': 331, 'name': 'Jellyfish'},
|
| 709 |
+
{'id': 332, 'name': 'Treadmill'},
|
| 710 |
+
{'id': 333, 'name': 'Lighter'},
|
| 711 |
+
{'id': 334, 'name': 'Grapefruit'},
|
| 712 |
+
{'id': 335, 'name': 'Game board'},
|
| 713 |
+
{'id': 336, 'name': 'Mop'},
|
| 714 |
+
{'id': 337, 'name': 'Radish'},
|
| 715 |
+
{'id': 338, 'name': 'Baozi'},
|
| 716 |
+
{'id': 339, 'name': 'Target'},
|
| 717 |
+
{'id': 340, 'name': 'French'},
|
| 718 |
+
{'id': 341, 'name': 'Spring Rolls'},
|
| 719 |
+
{'id': 342, 'name': 'Monkey'},
|
| 720 |
+
{'id': 343, 'name': 'Rabbit'},
|
| 721 |
+
{'id': 344, 'name': 'Pencil Case'},
|
| 722 |
+
{'id': 345, 'name': 'Yak'},
|
| 723 |
+
{'id': 346, 'name': 'Red Cabbage'},
|
| 724 |
+
{'id': 347, 'name': 'Binoculars'},
|
| 725 |
+
{'id': 348, 'name': 'Asparagus'},
|
| 726 |
+
{'id': 349, 'name': 'Barbell'},
|
| 727 |
+
{'id': 350, 'name': 'Scallop'},
|
| 728 |
+
{'id': 351, 'name': 'Noddles'},
|
| 729 |
+
{'id': 352, 'name': 'Comb'},
|
| 730 |
+
{'id': 353, 'name': 'Dumpling'},
|
| 731 |
+
{'id': 354, 'name': 'Oyster'},
|
| 732 |
+
{'id': 355, 'name': 'Table Tennis paddle'},
|
| 733 |
+
{'id': 356, 'name': 'Cosmetics Brush/Eyeliner Pencil'},
|
| 734 |
+
{'id': 357, 'name': 'Chainsaw'},
|
| 735 |
+
{'id': 358, 'name': 'Eraser'},
|
| 736 |
+
{'id': 359, 'name': 'Lobster'},
|
| 737 |
+
{'id': 360, 'name': 'Durian'},
|
| 738 |
+
{'id': 361, 'name': 'Okra'},
|
| 739 |
+
{'id': 362, 'name': 'Lipstick'},
|
| 740 |
+
{'id': 363, 'name': 'Cosmetics Mirror'},
|
| 741 |
+
{'id': 364, 'name': 'Curling'},
|
| 742 |
+
{'id': 365, 'name': 'Table Tennis '},
|
]


def _get_builtin_metadata():
    id_to_name = {x['id']: x['name'] for x in categories_v2_fix}
    thing_dataset_id_to_contiguous_id = {
        x['id']: i for i, x in enumerate(
            sorted(categories_v2_fix, key=lambda x: x['id']))}
    thing_classes = [id_to_name[k] for k in sorted(id_to_name)]
    return {
        "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
        "thing_classes": thing_classes}


_PREDEFINED_SPLITS_OBJECTS365 = {
    "objects365_v2_train": ("objects365/train", "objects365/annotations/zhiyuan_objv2_train_fixname_fixmiss.json"),
    # 80,000 images, 1,240,587 annotations
    "objects365_v2_val": ("objects365/val", "objects365/annotations/zhiyuan_objv2_val_fixname.json"),
    "objects365_v2_val_rare": ("objects365/val", "objects365/annotations/zhiyuan_objv2_val_fixname_rare.json"),
}

for key, (image_root, json_file) in _PREDEFINED_SPLITS_OBJECTS365.items():
    register_coco_instances(
        key,
        _get_builtin_metadata(),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )
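For orientation, a minimal usage sketch (not part of this commit) of how a split registered by the loop above could be consumed through detectron2's DatasetCatalog/MetadataCatalog; the module import path and the presence of the Objects365 files under datasets/ are assumptions.

# Hedged usage sketch; assumes proxydet.data.datasets.objects365 is importable and
# that datasets/objects365/... exists so the annotation JSON can actually be loaded.
from detectron2.data import DatasetCatalog, MetadataCatalog

import proxydet.data.datasets.objects365  # noqa: F401  (executes the registration loop above)

metadata = MetadataCatalog.get("objects365_v2_val")
print(len(metadata.thing_classes))           # 365 category names built from categories_v2_fix
print(metadata.thing_classes[:3])            # ['Person', 'Sneakers', 'Chair']

dataset_dicts = DatasetCatalog.get("objects365_v2_val")  # loads the COCO-style JSON from disk
print(len(dataset_dicts), "validation images")
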
proxydet/data/datasets/oid.py
ADDED
|
@@ -0,0 +1,535 @@
# Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/datasets/oid.py
# Copyright (c) Facebook, Inc. and its affiliates.
from .register_oid import register_oid_instances
import os

| 6 |
+
categories = [
|
| 7 |
+
{'id': 1, 'name': 'Infant bed', 'freebase_id': '/m/061hd_'},
|
| 8 |
+
{'id': 2, 'name': 'Rose', 'freebase_id': '/m/06m11'},
|
| 9 |
+
{'id': 3, 'name': 'Flag', 'freebase_id': '/m/03120'},
|
| 10 |
+
{'id': 4, 'name': 'Flashlight', 'freebase_id': '/m/01kb5b'},
|
| 11 |
+
{'id': 5, 'name': 'Sea turtle', 'freebase_id': '/m/0120dh'},
|
| 12 |
+
{'id': 6, 'name': 'Camera', 'freebase_id': '/m/0dv5r'},
|
| 13 |
+
{'id': 7, 'name': 'Animal', 'freebase_id': '/m/0jbk'},
|
| 14 |
+
{'id': 8, 'name': 'Glove', 'freebase_id': '/m/0174n1'},
|
| 15 |
+
{'id': 9, 'name': 'Crocodile', 'freebase_id': '/m/09f_2'},
|
| 16 |
+
{'id': 10, 'name': 'Cattle', 'freebase_id': '/m/01xq0k1'},
|
| 17 |
+
{'id': 11, 'name': 'House', 'freebase_id': '/m/03jm5'},
|
| 18 |
+
{'id': 12, 'name': 'Guacamole', 'freebase_id': '/m/02g30s'},
|
| 19 |
+
{'id': 13, 'name': 'Penguin', 'freebase_id': '/m/05z6w'},
|
| 20 |
+
{'id': 14, 'name': 'Vehicle registration plate', 'freebase_id': '/m/01jfm_'},
|
| 21 |
+
{'id': 15, 'name': 'Bench', 'freebase_id': '/m/076lb9'},
|
| 22 |
+
{'id': 16, 'name': 'Ladybug', 'freebase_id': '/m/0gj37'},
|
| 23 |
+
{'id': 17, 'name': 'Human nose', 'freebase_id': '/m/0k0pj'},
|
| 24 |
+
{'id': 18, 'name': 'Watermelon', 'freebase_id': '/m/0kpqd'},
|
| 25 |
+
{'id': 19, 'name': 'Flute', 'freebase_id': '/m/0l14j_'},
|
| 26 |
+
{'id': 20, 'name': 'Butterfly', 'freebase_id': '/m/0cyf8'},
|
| 27 |
+
{'id': 21, 'name': 'Washing machine', 'freebase_id': '/m/0174k2'},
|
| 28 |
+
{'id': 22, 'name': 'Raccoon', 'freebase_id': '/m/0dq75'},
|
| 29 |
+
{'id': 23, 'name': 'Segway', 'freebase_id': '/m/076bq'},
|
| 30 |
+
{'id': 24, 'name': 'Taco', 'freebase_id': '/m/07crc'},
|
| 31 |
+
{'id': 25, 'name': 'Jellyfish', 'freebase_id': '/m/0d8zb'},
|
| 32 |
+
{'id': 26, 'name': 'Cake', 'freebase_id': '/m/0fszt'},
|
| 33 |
+
{'id': 27, 'name': 'Pen', 'freebase_id': '/m/0k1tl'},
|
| 34 |
+
{'id': 28, 'name': 'Cannon', 'freebase_id': '/m/020kz'},
|
| 35 |
+
{'id': 29, 'name': 'Bread', 'freebase_id': '/m/09728'},
|
| 36 |
+
{'id': 30, 'name': 'Tree', 'freebase_id': '/m/07j7r'},
|
| 37 |
+
{'id': 31, 'name': 'Shellfish', 'freebase_id': '/m/0fbdv'},
|
| 38 |
+
{'id': 32, 'name': 'Bed', 'freebase_id': '/m/03ssj5'},
|
| 39 |
+
{'id': 33, 'name': 'Hamster', 'freebase_id': '/m/03qrc'},
|
| 40 |
+
{'id': 34, 'name': 'Hat', 'freebase_id': '/m/02dl1y'},
|
| 41 |
+
{'id': 35, 'name': 'Toaster', 'freebase_id': '/m/01k6s3'},
|
| 42 |
+
{'id': 36, 'name': 'Sombrero', 'freebase_id': '/m/02jfl0'},
|
| 43 |
+
{'id': 37, 'name': 'Tiara', 'freebase_id': '/m/01krhy'},
|
| 44 |
+
{'id': 38, 'name': 'Bowl', 'freebase_id': '/m/04kkgm'},
|
| 45 |
+
{'id': 39, 'name': 'Dragonfly', 'freebase_id': '/m/0ft9s'},
|
| 46 |
+
{'id': 40, 'name': 'Moths and butterflies', 'freebase_id': '/m/0d_2m'},
|
| 47 |
+
{'id': 41, 'name': 'Antelope', 'freebase_id': '/m/0czz2'},
|
| 48 |
+
{'id': 42, 'name': 'Vegetable', 'freebase_id': '/m/0f4s2w'},
|
| 49 |
+
{'id': 43, 'name': 'Torch', 'freebase_id': '/m/07dd4'},
|
| 50 |
+
{'id': 44, 'name': 'Building', 'freebase_id': '/m/0cgh4'},
|
| 51 |
+
{'id': 45, 'name': 'Power plugs and sockets', 'freebase_id': '/m/03bbps'},
|
| 52 |
+
{'id': 46, 'name': 'Blender', 'freebase_id': '/m/02pjr4'},
|
| 53 |
+
{'id': 47, 'name': 'Billiard table', 'freebase_id': '/m/04p0qw'},
|
| 54 |
+
{'id': 48, 'name': 'Cutting board', 'freebase_id': '/m/02pdsw'},
|
| 55 |
+
{'id': 49, 'name': 'Bronze sculpture', 'freebase_id': '/m/01yx86'},
|
| 56 |
+
{'id': 50, 'name': 'Turtle', 'freebase_id': '/m/09dzg'},
|
| 57 |
+
{'id': 51, 'name': 'Broccoli', 'freebase_id': '/m/0hkxq'},
|
| 58 |
+
{'id': 52, 'name': 'Tiger', 'freebase_id': '/m/07dm6'},
|
| 59 |
+
{'id': 53, 'name': 'Mirror', 'freebase_id': '/m/054_l'},
|
| 60 |
+
{'id': 54, 'name': 'Bear', 'freebase_id': '/m/01dws'},
|
| 61 |
+
{'id': 55, 'name': 'Zucchini', 'freebase_id': '/m/027pcv'},
|
| 62 |
+
{'id': 56, 'name': 'Dress', 'freebase_id': '/m/01d40f'},
|
| 63 |
+
{'id': 57, 'name': 'Volleyball', 'freebase_id': '/m/02rgn06'},
|
| 64 |
+
{'id': 58, 'name': 'Guitar', 'freebase_id': '/m/0342h'},
|
| 65 |
+
{'id': 59, 'name': 'Reptile', 'freebase_id': '/m/06bt6'},
|
| 66 |
+
{'id': 60, 'name': 'Golf cart', 'freebase_id': '/m/0323sq'},
|
| 67 |
+
{'id': 61, 'name': 'Tart', 'freebase_id': '/m/02zvsm'},
|
| 68 |
+
{'id': 62, 'name': 'Fedora', 'freebase_id': '/m/02fq_6'},
|
| 69 |
+
{'id': 63, 'name': 'Carnivore', 'freebase_id': '/m/01lrl'},
|
| 70 |
+
{'id': 64, 'name': 'Car', 'freebase_id': '/m/0k4j'},
|
| 71 |
+
{'id': 65, 'name': 'Lighthouse', 'freebase_id': '/m/04h7h'},
|
| 72 |
+
{'id': 66, 'name': 'Coffeemaker', 'freebase_id': '/m/07xyvk'},
|
| 73 |
+
{'id': 67, 'name': 'Food processor', 'freebase_id': '/m/03y6mg'},
|
| 74 |
+
{'id': 68, 'name': 'Truck', 'freebase_id': '/m/07r04'},
|
| 75 |
+
{'id': 69, 'name': 'Bookcase', 'freebase_id': '/m/03__z0'},
|
| 76 |
+
{'id': 70, 'name': 'Surfboard', 'freebase_id': '/m/019w40'},
|
| 77 |
+
{'id': 71, 'name': 'Footwear', 'freebase_id': '/m/09j5n'},
|
| 78 |
+
{'id': 72, 'name': 'Bench', 'freebase_id': '/m/0cvnqh'},
|
| 79 |
+
{'id': 73, 'name': 'Necklace', 'freebase_id': '/m/01llwg'},
|
| 80 |
+
{'id': 74, 'name': 'Flower', 'freebase_id': '/m/0c9ph5'},
|
| 81 |
+
{'id': 75, 'name': 'Radish', 'freebase_id': '/m/015x5n'},
|
| 82 |
+
{'id': 76, 'name': 'Marine mammal', 'freebase_id': '/m/0gd2v'},
|
| 83 |
+
{'id': 77, 'name': 'Frying pan', 'freebase_id': '/m/04v6l4'},
|
| 84 |
+
{'id': 78, 'name': 'Tap', 'freebase_id': '/m/02jz0l'},
|
| 85 |
+
{'id': 79, 'name': 'Peach', 'freebase_id': '/m/0dj6p'},
|
| 86 |
+
{'id': 80, 'name': 'Knife', 'freebase_id': '/m/04ctx'},
|
| 87 |
+
{'id': 81, 'name': 'Handbag', 'freebase_id': '/m/080hkjn'},
|
| 88 |
+
{'id': 82, 'name': 'Laptop', 'freebase_id': '/m/01c648'},
|
| 89 |
+
{'id': 83, 'name': 'Tent', 'freebase_id': '/m/01j61q'},
|
| 90 |
+
{'id': 84, 'name': 'Ambulance', 'freebase_id': '/m/012n7d'},
|
| 91 |
+
{'id': 85, 'name': 'Christmas tree', 'freebase_id': '/m/025nd'},
|
| 92 |
+
{'id': 86, 'name': 'Eagle', 'freebase_id': '/m/09csl'},
|
| 93 |
+
{'id': 87, 'name': 'Limousine', 'freebase_id': '/m/01lcw4'},
|
| 94 |
+
{'id': 88, 'name': 'Kitchen & dining room table', 'freebase_id': '/m/0h8n5zk'},
|
| 95 |
+
{'id': 89, 'name': 'Polar bear', 'freebase_id': '/m/0633h'},
|
| 96 |
+
{'id': 90, 'name': 'Tower', 'freebase_id': '/m/01fdzj'},
|
| 97 |
+
{'id': 91, 'name': 'Football', 'freebase_id': '/m/01226z'},
|
| 98 |
+
{'id': 92, 'name': 'Willow', 'freebase_id': '/m/0mw_6'},
|
| 99 |
+
{'id': 93, 'name': 'Human head', 'freebase_id': '/m/04hgtk'},
|
| 100 |
+
{'id': 94, 'name': 'Stop sign', 'freebase_id': '/m/02pv19'},
|
| 101 |
+
{'id': 95, 'name': 'Banana', 'freebase_id': '/m/09qck'},
|
| 102 |
+
{'id': 96, 'name': 'Mixer', 'freebase_id': '/m/063rgb'},
|
| 103 |
+
{'id': 97, 'name': 'Binoculars', 'freebase_id': '/m/0lt4_'},
|
| 104 |
+
{'id': 98, 'name': 'Dessert', 'freebase_id': '/m/0270h'},
|
| 105 |
+
{'id': 99, 'name': 'Bee', 'freebase_id': '/m/01h3n'},
|
| 106 |
+
{'id': 100, 'name': 'Chair', 'freebase_id': '/m/01mzpv'},
|
| 107 |
+
{'id': 101, 'name': 'Wood-burning stove', 'freebase_id': '/m/04169hn'},
|
| 108 |
+
{'id': 102, 'name': 'Flowerpot', 'freebase_id': '/m/0fm3zh'},
|
| 109 |
+
{'id': 103, 'name': 'Beaker', 'freebase_id': '/m/0d20w4'},
|
| 110 |
+
{'id': 104, 'name': 'Oyster', 'freebase_id': '/m/0_cp5'},
|
| 111 |
+
{'id': 105, 'name': 'Woodpecker', 'freebase_id': '/m/01dy8n'},
|
| 112 |
+
{'id': 106, 'name': 'Harp', 'freebase_id': '/m/03m5k'},
|
| 113 |
+
{'id': 107, 'name': 'Bathtub', 'freebase_id': '/m/03dnzn'},
|
| 114 |
+
{'id': 108, 'name': 'Wall clock', 'freebase_id': '/m/0h8mzrc'},
|
| 115 |
+
{'id': 109, 'name': 'Sports uniform', 'freebase_id': '/m/0h8mhzd'},
|
| 116 |
+
{'id': 110, 'name': 'Rhinoceros', 'freebase_id': '/m/03d443'},
|
| 117 |
+
{'id': 111, 'name': 'Beehive', 'freebase_id': '/m/01gllr'},
|
| 118 |
+
{'id': 112, 'name': 'Cupboard', 'freebase_id': '/m/0642b4'},
|
| 119 |
+
{'id': 113, 'name': 'Chicken', 'freebase_id': '/m/09b5t'},
|
| 120 |
+
{'id': 114, 'name': 'Man', 'freebase_id': '/m/04yx4'},
|
| 121 |
+
{'id': 115, 'name': 'Blue jay', 'freebase_id': '/m/01f8m5'},
|
| 122 |
+
{'id': 116, 'name': 'Cucumber', 'freebase_id': '/m/015x4r'},
|
| 123 |
+
{'id': 117, 'name': 'Balloon', 'freebase_id': '/m/01j51'},
|
| 124 |
+
{'id': 118, 'name': 'Kite', 'freebase_id': '/m/02zt3'},
|
| 125 |
+
{'id': 119, 'name': 'Fireplace', 'freebase_id': '/m/03tw93'},
|
| 126 |
+
{'id': 120, 'name': 'Lantern', 'freebase_id': '/m/01jfsr'},
|
| 127 |
+
{'id': 121, 'name': 'Missile', 'freebase_id': '/m/04ylt'},
|
| 128 |
+
{'id': 122, 'name': 'Book', 'freebase_id': '/m/0bt_c3'},
|
| 129 |
+
{'id': 123, 'name': 'Spoon', 'freebase_id': '/m/0cmx8'},
|
| 130 |
+
{'id': 124, 'name': 'Grapefruit', 'freebase_id': '/m/0hqkz'},
|
| 131 |
+
{'id': 125, 'name': 'Squirrel', 'freebase_id': '/m/071qp'},
|
| 132 |
+
{'id': 126, 'name': 'Orange', 'freebase_id': '/m/0cyhj_'},
|
| 133 |
+
{'id': 127, 'name': 'Coat', 'freebase_id': '/m/01xygc'},
|
| 134 |
+
{'id': 128, 'name': 'Punching bag', 'freebase_id': '/m/0420v5'},
|
| 135 |
+
{'id': 129, 'name': 'Zebra', 'freebase_id': '/m/0898b'},
|
| 136 |
+
{'id': 130, 'name': 'Billboard', 'freebase_id': '/m/01knjb'},
|
| 137 |
+
{'id': 131, 'name': 'Bicycle', 'freebase_id': '/m/0199g'},
|
| 138 |
+
{'id': 132, 'name': 'Door handle', 'freebase_id': '/m/03c7gz'},
|
| 139 |
+
{'id': 133, 'name': 'Mechanical fan', 'freebase_id': '/m/02x984l'},
|
| 140 |
+
{'id': 134, 'name': 'Ring binder', 'freebase_id': '/m/04zwwv'},
|
| 141 |
+
{'id': 135, 'name': 'Table', 'freebase_id': '/m/04bcr3'},
|
| 142 |
+
{'id': 136, 'name': 'Parrot', 'freebase_id': '/m/0gv1x'},
|
| 143 |
+
{'id': 137, 'name': 'Sock', 'freebase_id': '/m/01nq26'},
|
| 144 |
+
{'id': 138, 'name': 'Vase', 'freebase_id': '/m/02s195'},
|
| 145 |
+
{'id': 139, 'name': 'Weapon', 'freebase_id': '/m/083kb'},
|
| 146 |
+
{'id': 140, 'name': 'Shotgun', 'freebase_id': '/m/06nrc'},
|
| 147 |
+
{'id': 141, 'name': 'Glasses', 'freebase_id': '/m/0jyfg'},
|
| 148 |
+
{'id': 142, 'name': 'Seahorse', 'freebase_id': '/m/0nybt'},
|
| 149 |
+
{'id': 143, 'name': 'Belt', 'freebase_id': '/m/0176mf'},
|
| 150 |
+
{'id': 144, 'name': 'Watercraft', 'freebase_id': '/m/01rzcn'},
|
| 151 |
+
{'id': 145, 'name': 'Window', 'freebase_id': '/m/0d4v4'},
|
| 152 |
+
{'id': 146, 'name': 'Giraffe', 'freebase_id': '/m/03bk1'},
|
| 153 |
+
{'id': 147, 'name': 'Lion', 'freebase_id': '/m/096mb'},
|
| 154 |
+
{'id': 148, 'name': 'Tire', 'freebase_id': '/m/0h9mv'},
|
| 155 |
+
{'id': 149, 'name': 'Vehicle', 'freebase_id': '/m/07yv9'},
|
| 156 |
+
{'id': 150, 'name': 'Canoe', 'freebase_id': '/m/0ph39'},
|
| 157 |
+
{'id': 151, 'name': 'Tie', 'freebase_id': '/m/01rkbr'},
|
| 158 |
+
{'id': 152, 'name': 'Shelf', 'freebase_id': '/m/0gjbg72'},
|
| 159 |
+
{'id': 153, 'name': 'Picture frame', 'freebase_id': '/m/06z37_'},
|
| 160 |
+
{'id': 154, 'name': 'Printer', 'freebase_id': '/m/01m4t'},
|
| 161 |
+
{'id': 155, 'name': 'Human leg', 'freebase_id': '/m/035r7c'},
|
| 162 |
+
{'id': 156, 'name': 'Boat', 'freebase_id': '/m/019jd'},
|
| 163 |
+
{'id': 157, 'name': 'Slow cooker', 'freebase_id': '/m/02tsc9'},
|
| 164 |
+
{'id': 158, 'name': 'Croissant', 'freebase_id': '/m/015wgc'},
|
| 165 |
+
{'id': 159, 'name': 'Candle', 'freebase_id': '/m/0c06p'},
|
| 166 |
+
{'id': 160, 'name': 'Pancake', 'freebase_id': '/m/01dwwc'},
|
| 167 |
+
{'id': 161, 'name': 'Pillow', 'freebase_id': '/m/034c16'},
|
| 168 |
+
{'id': 162, 'name': 'Coin', 'freebase_id': '/m/0242l'},
|
| 169 |
+
{'id': 163, 'name': 'Stretcher', 'freebase_id': '/m/02lbcq'},
|
| 170 |
+
{'id': 164, 'name': 'Sandal', 'freebase_id': '/m/03nfch'},
|
| 171 |
+
{'id': 165, 'name': 'Woman', 'freebase_id': '/m/03bt1vf'},
|
| 172 |
+
{'id': 166, 'name': 'Stairs', 'freebase_id': '/m/01lynh'},
|
| 173 |
+
{'id': 167, 'name': 'Harpsichord', 'freebase_id': '/m/03q5t'},
|
| 174 |
+
{'id': 168, 'name': 'Stool', 'freebase_id': '/m/0fqt361'},
|
| 175 |
+
{'id': 169, 'name': 'Bus', 'freebase_id': '/m/01bjv'},
|
| 176 |
+
{'id': 170, 'name': 'Suitcase', 'freebase_id': '/m/01s55n'},
|
| 177 |
+
{'id': 171, 'name': 'Human mouth', 'freebase_id': '/m/0283dt1'},
|
| 178 |
+
{'id': 172, 'name': 'Juice', 'freebase_id': '/m/01z1kdw'},
|
| 179 |
+
{'id': 173, 'name': 'Skull', 'freebase_id': '/m/016m2d'},
|
| 180 |
+
{'id': 174, 'name': 'Door', 'freebase_id': '/m/02dgv'},
|
| 181 |
+
{'id': 175, 'name': 'Violin', 'freebase_id': '/m/07y_7'},
|
| 182 |
+
{'id': 176, 'name': 'Chopsticks', 'freebase_id': '/m/01_5g'},
|
| 183 |
+
{'id': 177, 'name': 'Digital clock', 'freebase_id': '/m/06_72j'},
|
| 184 |
+
{'id': 178, 'name': 'Sunflower', 'freebase_id': '/m/0ftb8'},
|
| 185 |
+
{'id': 179, 'name': 'Leopard', 'freebase_id': '/m/0c29q'},
|
| 186 |
+
{'id': 180, 'name': 'Bell pepper', 'freebase_id': '/m/0jg57'},
|
| 187 |
+
{'id': 181, 'name': 'Harbor seal', 'freebase_id': '/m/02l8p9'},
|
| 188 |
+
{'id': 182, 'name': 'Snake', 'freebase_id': '/m/078jl'},
|
| 189 |
+
{'id': 183, 'name': 'Sewing machine', 'freebase_id': '/m/0llzx'},
|
| 190 |
+
{'id': 184, 'name': 'Goose', 'freebase_id': '/m/0dbvp'},
|
| 191 |
+
{'id': 185, 'name': 'Helicopter', 'freebase_id': '/m/09ct_'},
|
| 192 |
+
{'id': 186, 'name': 'Seat belt', 'freebase_id': '/m/0dkzw'},
|
| 193 |
+
{'id': 187, 'name': 'Coffee cup', 'freebase_id': '/m/02p5f1q'},
|
| 194 |
+
{'id': 188, 'name': 'Microwave oven', 'freebase_id': '/m/0fx9l'},
|
| 195 |
+
{'id': 189, 'name': 'Hot dog', 'freebase_id': '/m/01b9xk'},
|
| 196 |
+
{'id': 190, 'name': 'Countertop', 'freebase_id': '/m/0b3fp9'},
|
| 197 |
+
{'id': 191, 'name': 'Serving tray', 'freebase_id': '/m/0h8n27j'},
|
| 198 |
+
{'id': 192, 'name': 'Dog bed', 'freebase_id': '/m/0h8n6f9'},
|
| 199 |
+
{'id': 193, 'name': 'Beer', 'freebase_id': '/m/01599'},
|
| 200 |
+
{'id': 194, 'name': 'Sunglasses', 'freebase_id': '/m/017ftj'},
|
| 201 |
+
{'id': 195, 'name': 'Golf ball', 'freebase_id': '/m/044r5d'},
|
| 202 |
+
{'id': 196, 'name': 'Waffle', 'freebase_id': '/m/01dwsz'},
|
| 203 |
+
{'id': 197, 'name': 'Palm tree', 'freebase_id': '/m/0cdl1'},
|
| 204 |
+
{'id': 198, 'name': 'Trumpet', 'freebase_id': '/m/07gql'},
|
| 205 |
+
{'id': 199, 'name': 'Ruler', 'freebase_id': '/m/0hdln'},
|
| 206 |
+
{'id': 200, 'name': 'Helmet', 'freebase_id': '/m/0zvk5'},
|
| 207 |
+
{'id': 201, 'name': 'Ladder', 'freebase_id': '/m/012w5l'},
|
| 208 |
+
{'id': 202, 'name': 'Office building', 'freebase_id': '/m/021sj1'},
|
| 209 |
+
{'id': 203, 'name': 'Tablet computer', 'freebase_id': '/m/0bh9flk'},
|
| 210 |
+
{'id': 204, 'name': 'Toilet paper', 'freebase_id': '/m/09gtd'},
|
| 211 |
+
{'id': 205, 'name': 'Pomegranate', 'freebase_id': '/m/0jwn_'},
|
| 212 |
+
{'id': 206, 'name': 'Skirt', 'freebase_id': '/m/02wv6h6'},
|
| 213 |
+
{'id': 207, 'name': 'Gas stove', 'freebase_id': '/m/02wv84t'},
|
| 214 |
+
{'id': 208, 'name': 'Cookie', 'freebase_id': '/m/021mn'},
|
| 215 |
+
{'id': 209, 'name': 'Cart', 'freebase_id': '/m/018p4k'},
|
| 216 |
+
{'id': 210, 'name': 'Raven', 'freebase_id': '/m/06j2d'},
|
| 217 |
+
{'id': 211, 'name': 'Egg', 'freebase_id': '/m/033cnk'},
|
| 218 |
+
{'id': 212, 'name': 'Burrito', 'freebase_id': '/m/01j3zr'},
|
| 219 |
+
{'id': 213, 'name': 'Goat', 'freebase_id': '/m/03fwl'},
|
| 220 |
+
{'id': 214, 'name': 'Kitchen knife', 'freebase_id': '/m/058qzx'},
|
| 221 |
+
{'id': 215, 'name': 'Skateboard', 'freebase_id': '/m/06_fw'},
|
| 222 |
+
{'id': 216, 'name': 'Salt and pepper shakers', 'freebase_id': '/m/02x8cch'},
|
| 223 |
+
{'id': 217, 'name': 'Lynx', 'freebase_id': '/m/04g2r'},
|
| 224 |
+
{'id': 218, 'name': 'Boot', 'freebase_id': '/m/01b638'},
|
| 225 |
+
{'id': 219, 'name': 'Platter', 'freebase_id': '/m/099ssp'},
|
| 226 |
+
{'id': 220, 'name': 'Ski', 'freebase_id': '/m/071p9'},
|
| 227 |
+
{'id': 221, 'name': 'Swimwear', 'freebase_id': '/m/01gkx_'},
|
| 228 |
+
{'id': 222, 'name': 'Swimming pool', 'freebase_id': '/m/0b_rs'},
|
| 229 |
+
{'id': 223, 'name': 'Drinking straw', 'freebase_id': '/m/03v5tg'},
|
| 230 |
+
{'id': 224, 'name': 'Wrench', 'freebase_id': '/m/01j5ks'},
|
| 231 |
+
{'id': 225, 'name': 'Drum', 'freebase_id': '/m/026t6'},
|
| 232 |
+
{'id': 226, 'name': 'Ant', 'freebase_id': '/m/0_k2'},
|
| 233 |
+
{'id': 227, 'name': 'Human ear', 'freebase_id': '/m/039xj_'},
|
| 234 |
+
{'id': 228, 'name': 'Headphones', 'freebase_id': '/m/01b7fy'},
|
| 235 |
+
{'id': 229, 'name': 'Fountain', 'freebase_id': '/m/0220r2'},
|
| 236 |
+
{'id': 230, 'name': 'Bird', 'freebase_id': '/m/015p6'},
|
| 237 |
+
{'id': 231, 'name': 'Jeans', 'freebase_id': '/m/0fly7'},
|
| 238 |
+
{'id': 232, 'name': 'Television', 'freebase_id': '/m/07c52'},
|
| 239 |
+
{'id': 233, 'name': 'Crab', 'freebase_id': '/m/0n28_'},
|
| 240 |
+
{'id': 234, 'name': 'Microphone', 'freebase_id': '/m/0hg7b'},
|
| 241 |
+
{'id': 235, 'name': 'Home appliance', 'freebase_id': '/m/019dx1'},
|
| 242 |
+
{'id': 236, 'name': 'Snowplow', 'freebase_id': '/m/04vv5k'},
|
| 243 |
+
{'id': 237, 'name': 'Beetle', 'freebase_id': '/m/020jm'},
|
| 244 |
+
{'id': 238, 'name': 'Artichoke', 'freebase_id': '/m/047v4b'},
|
| 245 |
+
{'id': 239, 'name': 'Jet ski', 'freebase_id': '/m/01xs3r'},
|
| 246 |
+
{'id': 240, 'name': 'Stationary bicycle', 'freebase_id': '/m/03kt2w'},
|
| 247 |
+
{'id': 241, 'name': 'Human hair', 'freebase_id': '/m/03q69'},
|
| 248 |
+
{'id': 242, 'name': 'Brown bear', 'freebase_id': '/m/01dxs'},
|
| 249 |
+
{'id': 243, 'name': 'Starfish', 'freebase_id': '/m/01h8tj'},
|
| 250 |
+
{'id': 244, 'name': 'Fork', 'freebase_id': '/m/0dt3t'},
|
| 251 |
+
{'id': 245, 'name': 'Lobster', 'freebase_id': '/m/0cjq5'},
|
| 252 |
+
{'id': 246, 'name': 'Corded phone', 'freebase_id': '/m/0h8lkj8'},
|
| 253 |
+
{'id': 247, 'name': 'Drink', 'freebase_id': '/m/0271t'},
|
| 254 |
+
{'id': 248, 'name': 'Saucer', 'freebase_id': '/m/03q5c7'},
|
| 255 |
+
{'id': 249, 'name': 'Carrot', 'freebase_id': '/m/0fj52s'},
|
| 256 |
+
{'id': 250, 'name': 'Insect', 'freebase_id': '/m/03vt0'},
|
| 257 |
+
{'id': 251, 'name': 'Clock', 'freebase_id': '/m/01x3z'},
|
| 258 |
+
{'id': 252, 'name': 'Castle', 'freebase_id': '/m/0d5gx'},
|
| 259 |
+
{'id': 253, 'name': 'Tennis racket', 'freebase_id': '/m/0h8my_4'},
|
| 260 |
+
{'id': 254, 'name': 'Ceiling fan', 'freebase_id': '/m/03ldnb'},
|
| 261 |
+
{'id': 255, 'name': 'Asparagus', 'freebase_id': '/m/0cjs7'},
|
| 262 |
+
{'id': 256, 'name': 'Jaguar', 'freebase_id': '/m/0449p'},
|
| 263 |
+
{'id': 257, 'name': 'Musical instrument', 'freebase_id': '/m/04szw'},
|
| 264 |
+
{'id': 258, 'name': 'Train', 'freebase_id': '/m/07jdr'},
|
| 265 |
+
{'id': 259, 'name': 'Cat', 'freebase_id': '/m/01yrx'},
|
| 266 |
+
{'id': 260, 'name': 'Rifle', 'freebase_id': '/m/06c54'},
|
| 267 |
+
{'id': 261, 'name': 'Dumbbell', 'freebase_id': '/m/04h8sr'},
|
| 268 |
+
{'id': 262, 'name': 'Mobile phone', 'freebase_id': '/m/050k8'},
|
| 269 |
+
{'id': 263, 'name': 'Taxi', 'freebase_id': '/m/0pg52'},
|
| 270 |
+
{'id': 264, 'name': 'Shower', 'freebase_id': '/m/02f9f_'},
|
| 271 |
+
{'id': 265, 'name': 'Pitcher', 'freebase_id': '/m/054fyh'},
|
| 272 |
+
{'id': 266, 'name': 'Lemon', 'freebase_id': '/m/09k_b'},
|
| 273 |
+
{'id': 267, 'name': 'Invertebrate', 'freebase_id': '/m/03xxp'},
|
| 274 |
+
{'id': 268, 'name': 'Turkey', 'freebase_id': '/m/0jly1'},
|
| 275 |
+
{'id': 269, 'name': 'High heels', 'freebase_id': '/m/06k2mb'},
|
| 276 |
+
{'id': 270, 'name': 'Bust', 'freebase_id': '/m/04yqq2'},
|
| 277 |
+
{'id': 271, 'name': 'Elephant', 'freebase_id': '/m/0bwd_0j'},
|
| 278 |
+
{'id': 272, 'name': 'Scarf', 'freebase_id': '/m/02h19r'},
|
| 279 |
+
{'id': 273, 'name': 'Barrel', 'freebase_id': '/m/02zn6n'},
|
| 280 |
+
{'id': 274, 'name': 'Trombone', 'freebase_id': '/m/07c6l'},
|
| 281 |
+
{'id': 275, 'name': 'Pumpkin', 'freebase_id': '/m/05zsy'},
|
| 282 |
+
{'id': 276, 'name': 'Box', 'freebase_id': '/m/025dyy'},
|
| 283 |
+
{'id': 277, 'name': 'Tomato', 'freebase_id': '/m/07j87'},
|
| 284 |
+
{'id': 278, 'name': 'Frog', 'freebase_id': '/m/09ld4'},
|
| 285 |
+
{'id': 279, 'name': 'Bidet', 'freebase_id': '/m/01vbnl'},
|
| 286 |
+
{'id': 280, 'name': 'Human face', 'freebase_id': '/m/0dzct'},
|
| 287 |
+
{'id': 281, 'name': 'Houseplant', 'freebase_id': '/m/03fp41'},
|
| 288 |
+
{'id': 282, 'name': 'Van', 'freebase_id': '/m/0h2r6'},
|
| 289 |
+
{'id': 283, 'name': 'Shark', 'freebase_id': '/m/0by6g'},
|
| 290 |
+
{'id': 284, 'name': 'Ice cream', 'freebase_id': '/m/0cxn2'},
|
| 291 |
+
{'id': 285, 'name': 'Swim cap', 'freebase_id': '/m/04tn4x'},
|
| 292 |
+
{'id': 286, 'name': 'Falcon', 'freebase_id': '/m/0f6wt'},
|
| 293 |
+
{'id': 287, 'name': 'Ostrich', 'freebase_id': '/m/05n4y'},
|
| 294 |
+
{'id': 288, 'name': 'Handgun', 'freebase_id': '/m/0gxl3'},
|
| 295 |
+
{'id': 289, 'name': 'Whiteboard', 'freebase_id': '/m/02d9qx'},
|
| 296 |
+
{'id': 290, 'name': 'Lizard', 'freebase_id': '/m/04m9y'},
|
| 297 |
+
{'id': 291, 'name': 'Pasta', 'freebase_id': '/m/05z55'},
|
| 298 |
+
{'id': 292, 'name': 'Snowmobile', 'freebase_id': '/m/01x3jk'},
|
| 299 |
+
{'id': 293, 'name': 'Light bulb', 'freebase_id': '/m/0h8l4fh'},
|
| 300 |
+
{'id': 294, 'name': 'Window blind', 'freebase_id': '/m/031b6r'},
|
| 301 |
+
{'id': 295, 'name': 'Muffin', 'freebase_id': '/m/01tcjp'},
|
| 302 |
+
{'id': 296, 'name': 'Pretzel', 'freebase_id': '/m/01f91_'},
|
| 303 |
+
{'id': 297, 'name': 'Computer monitor', 'freebase_id': '/m/02522'},
|
| 304 |
+
{'id': 298, 'name': 'Horn', 'freebase_id': '/m/0319l'},
|
| 305 |
+
{'id': 299, 'name': 'Furniture', 'freebase_id': '/m/0c_jw'},
|
| 306 |
+
{'id': 300, 'name': 'Sandwich', 'freebase_id': '/m/0l515'},
|
| 307 |
+
{'id': 301, 'name': 'Fox', 'freebase_id': '/m/0306r'},
|
| 308 |
+
{'id': 302, 'name': 'Convenience store', 'freebase_id': '/m/0crjs'},
|
| 309 |
+
{'id': 303, 'name': 'Fish', 'freebase_id': '/m/0ch_cf'},
|
| 310 |
+
{'id': 304, 'name': 'Fruit', 'freebase_id': '/m/02xwb'},
|
| 311 |
+
{'id': 305, 'name': 'Earrings', 'freebase_id': '/m/01r546'},
|
| 312 |
+
{'id': 306, 'name': 'Curtain', 'freebase_id': '/m/03rszm'},
|
| 313 |
+
{'id': 307, 'name': 'Grape', 'freebase_id': '/m/0388q'},
|
| 314 |
+
{'id': 308, 'name': 'Sofa bed', 'freebase_id': '/m/03m3pdh'},
|
| 315 |
+
{'id': 309, 'name': 'Horse', 'freebase_id': '/m/03k3r'},
|
| 316 |
+
{'id': 310, 'name': 'Luggage and bags', 'freebase_id': '/m/0hf58v5'},
|
| 317 |
+
{'id': 311, 'name': 'Desk', 'freebase_id': '/m/01y9k5'},
|
| 318 |
+
{'id': 312, 'name': 'Crutch', 'freebase_id': '/m/05441v'},
|
| 319 |
+
{'id': 313, 'name': 'Bicycle helmet', 'freebase_id': '/m/03p3bw'},
|
| 320 |
+
{'id': 314, 'name': 'Tick', 'freebase_id': '/m/0175cv'},
|
| 321 |
+
{'id': 315, 'name': 'Airplane', 'freebase_id': '/m/0cmf2'},
|
| 322 |
+
{'id': 316, 'name': 'Canary', 'freebase_id': '/m/0ccs93'},
|
| 323 |
+
{'id': 317, 'name': 'Spatula', 'freebase_id': '/m/02d1br'},
|
| 324 |
+
{'id': 318, 'name': 'Watch', 'freebase_id': '/m/0gjkl'},
|
| 325 |
+
{'id': 319, 'name': 'Lily', 'freebase_id': '/m/0jqgx'},
|
| 326 |
+
{'id': 320, 'name': 'Kitchen appliance', 'freebase_id': '/m/0h99cwc'},
|
| 327 |
+
{'id': 321, 'name': 'Filing cabinet', 'freebase_id': '/m/047j0r'},
|
| 328 |
+
{'id': 322, 'name': 'Aircraft', 'freebase_id': '/m/0k5j'},
|
| 329 |
+
{'id': 323, 'name': 'Cake stand', 'freebase_id': '/m/0h8n6ft'},
|
| 330 |
+
{'id': 324, 'name': 'Candy', 'freebase_id': '/m/0gm28'},
|
| 331 |
+
{'id': 325, 'name': 'Sink', 'freebase_id': '/m/0130jx'},
|
| 332 |
+
{'id': 326, 'name': 'Mouse', 'freebase_id': '/m/04rmv'},
|
| 333 |
+
{'id': 327, 'name': 'Wine', 'freebase_id': '/m/081qc'},
|
| 334 |
+
{'id': 328, 'name': 'Wheelchair', 'freebase_id': '/m/0qmmr'},
|
| 335 |
+
{'id': 329, 'name': 'Goldfish', 'freebase_id': '/m/03fj2'},
|
| 336 |
+
{'id': 330, 'name': 'Refrigerator', 'freebase_id': '/m/040b_t'},
|
| 337 |
+
{'id': 331, 'name': 'French fries', 'freebase_id': '/m/02y6n'},
|
| 338 |
+
{'id': 332, 'name': 'Drawer', 'freebase_id': '/m/0fqfqc'},
|
| 339 |
+
{'id': 333, 'name': 'Treadmill', 'freebase_id': '/m/030610'},
|
| 340 |
+
{'id': 334, 'name': 'Picnic basket', 'freebase_id': '/m/07kng9'},
|
| 341 |
+
{'id': 335, 'name': 'Dice', 'freebase_id': '/m/029b3'},
|
| 342 |
+
{'id': 336, 'name': 'Cabbage', 'freebase_id': '/m/0fbw6'},
|
| 343 |
+
{'id': 337, 'name': 'Football helmet', 'freebase_id': '/m/07qxg_'},
|
| 344 |
+
{'id': 338, 'name': 'Pig', 'freebase_id': '/m/068zj'},
|
| 345 |
+
{'id': 339, 'name': 'Person', 'freebase_id': '/m/01g317'},
|
| 346 |
+
{'id': 340, 'name': 'Shorts', 'freebase_id': '/m/01bfm9'},
|
| 347 |
+
{'id': 341, 'name': 'Gondola', 'freebase_id': '/m/02068x'},
|
| 348 |
+
{'id': 342, 'name': 'Honeycomb', 'freebase_id': '/m/0fz0h'},
|
| 349 |
+
{'id': 343, 'name': 'Doughnut', 'freebase_id': '/m/0jy4k'},
|
| 350 |
+
{'id': 344, 'name': 'Chest of drawers', 'freebase_id': '/m/05kyg_'},
|
| 351 |
+
{'id': 345, 'name': 'Land vehicle', 'freebase_id': '/m/01prls'},
|
| 352 |
+
{'id': 346, 'name': 'Bat', 'freebase_id': '/m/01h44'},
|
| 353 |
+
{'id': 347, 'name': 'Monkey', 'freebase_id': '/m/08pbxl'},
|
| 354 |
+
{'id': 348, 'name': 'Dagger', 'freebase_id': '/m/02gzp'},
|
| 355 |
+
{'id': 349, 'name': 'Tableware', 'freebase_id': '/m/04brg2'},
|
| 356 |
+
{'id': 350, 'name': 'Human foot', 'freebase_id': '/m/031n1'},
|
| 357 |
+
{'id': 351, 'name': 'Mug', 'freebase_id': '/m/02jvh9'},
|
| 358 |
+
{'id': 352, 'name': 'Alarm clock', 'freebase_id': '/m/046dlr'},
|
| 359 |
+
{'id': 353, 'name': 'Pressure cooker', 'freebase_id': '/m/0h8ntjv'},
|
| 360 |
+
{'id': 354, 'name': 'Human hand', 'freebase_id': '/m/0k65p'},
|
| 361 |
+
{'id': 355, 'name': 'Tortoise', 'freebase_id': '/m/011k07'},
|
| 362 |
+
{'id': 356, 'name': 'Baseball glove', 'freebase_id': '/m/03grzl'},
|
| 363 |
+
{'id': 357, 'name': 'Sword', 'freebase_id': '/m/06y5r'},
|
| 364 |
+
{'id': 358, 'name': 'Pear', 'freebase_id': '/m/061_f'},
|
| 365 |
+
{'id': 359, 'name': 'Miniskirt', 'freebase_id': '/m/01cmb2'},
|
| 366 |
+
{'id': 360, 'name': 'Traffic sign', 'freebase_id': '/m/01mqdt'},
|
| 367 |
+
{'id': 361, 'name': 'Girl', 'freebase_id': '/m/05r655'},
|
| 368 |
+
{'id': 362, 'name': 'Roller skates', 'freebase_id': '/m/02p3w7d'},
|
| 369 |
+
{'id': 363, 'name': 'Dinosaur', 'freebase_id': '/m/029tx'},
|
| 370 |
+
{'id': 364, 'name': 'Porch', 'freebase_id': '/m/04m6gz'},
|
| 371 |
+
{'id': 365, 'name': 'Human beard', 'freebase_id': '/m/015h_t'},
|
| 372 |
+
{'id': 366, 'name': 'Submarine sandwich', 'freebase_id': '/m/06pcq'},
|
| 373 |
+
{'id': 367, 'name': 'Screwdriver', 'freebase_id': '/m/01bms0'},
|
| 374 |
+
{'id': 368, 'name': 'Strawberry', 'freebase_id': '/m/07fbm7'},
|
| 375 |
+
{'id': 369, 'name': 'Wine glass', 'freebase_id': '/m/09tvcd'},
|
| 376 |
+
{'id': 370, 'name': 'Seafood', 'freebase_id': '/m/06nwz'},
|
| 377 |
+
{'id': 371, 'name': 'Racket', 'freebase_id': '/m/0dv9c'},
|
| 378 |
+
{'id': 372, 'name': 'Wheel', 'freebase_id': '/m/083wq'},
|
| 379 |
+
{'id': 373, 'name': 'Sea lion', 'freebase_id': '/m/0gd36'},
|
| 380 |
+
{'id': 374, 'name': 'Toy', 'freebase_id': '/m/0138tl'},
|
| 381 |
+
{'id': 375, 'name': 'Tea', 'freebase_id': '/m/07clx'},
|
| 382 |
+
{'id': 376, 'name': 'Tennis ball', 'freebase_id': '/m/05ctyq'},
|
| 383 |
+
{'id': 377, 'name': 'Waste container', 'freebase_id': '/m/0bjyj5'},
|
| 384 |
+
{'id': 378, 'name': 'Mule', 'freebase_id': '/m/0dbzx'},
|
| 385 |
+
{'id': 379, 'name': 'Cricket ball', 'freebase_id': '/m/02ctlc'},
|
| 386 |
+
{'id': 380, 'name': 'Pineapple', 'freebase_id': '/m/0fp6w'},
|
| 387 |
+
{'id': 381, 'name': 'Coconut', 'freebase_id': '/m/0djtd'},
|
| 388 |
+
{'id': 382, 'name': 'Doll', 'freebase_id': '/m/0167gd'},
|
| 389 |
+
{'id': 383, 'name': 'Coffee table', 'freebase_id': '/m/078n6m'},
|
| 390 |
+
{'id': 384, 'name': 'Snowman', 'freebase_id': '/m/0152hh'},
|
| 391 |
+
{'id': 385, 'name': 'Lavender', 'freebase_id': '/m/04gth'},
|
| 392 |
+
{'id': 386, 'name': 'Shrimp', 'freebase_id': '/m/0ll1f78'},
|
| 393 |
+
{'id': 387, 'name': 'Maple', 'freebase_id': '/m/0cffdh'},
|
| 394 |
+
{'id': 388, 'name': 'Cowboy hat', 'freebase_id': '/m/025rp__'},
|
| 395 |
+
{'id': 389, 'name': 'Goggles', 'freebase_id': '/m/02_n6y'},
|
| 396 |
+
{'id': 390, 'name': 'Rugby ball', 'freebase_id': '/m/0wdt60w'},
|
| 397 |
+
{'id': 391, 'name': 'Caterpillar', 'freebase_id': '/m/0cydv'},
|
| 398 |
+
{'id': 392, 'name': 'Poster', 'freebase_id': '/m/01n5jq'},
|
| 399 |
+
{'id': 393, 'name': 'Rocket', 'freebase_id': '/m/09rvcxw'},
|
| 400 |
+
{'id': 394, 'name': 'Organ', 'freebase_id': '/m/013y1f'},
|
| 401 |
+
{'id': 395, 'name': 'Saxophone', 'freebase_id': '/m/06ncr'},
|
| 402 |
+
{'id': 396, 'name': 'Traffic light', 'freebase_id': '/m/015qff'},
|
| 403 |
+
{'id': 397, 'name': 'Cocktail', 'freebase_id': '/m/024g6'},
|
| 404 |
+
{'id': 398, 'name': 'Plastic bag', 'freebase_id': '/m/05gqfk'},
|
| 405 |
+
{'id': 399, 'name': 'Squash', 'freebase_id': '/m/0dv77'},
|
| 406 |
+
{'id': 400, 'name': 'Mushroom', 'freebase_id': '/m/052sf'},
|
| 407 |
+
{'id': 401, 'name': 'Hamburger', 'freebase_id': '/m/0cdn1'},
|
| 408 |
+
{'id': 402, 'name': 'Light switch', 'freebase_id': '/m/03jbxj'},
|
| 409 |
+
{'id': 403, 'name': 'Parachute', 'freebase_id': '/m/0cyfs'},
|
| 410 |
+
{'id': 404, 'name': 'Teddy bear', 'freebase_id': '/m/0kmg4'},
|
| 411 |
+
{'id': 405, 'name': 'Winter melon', 'freebase_id': '/m/02cvgx'},
|
| 412 |
+
{'id': 406, 'name': 'Deer', 'freebase_id': '/m/09kx5'},
|
| 413 |
+
{'id': 407, 'name': 'Musical keyboard', 'freebase_id': '/m/057cc'},
|
| 414 |
+
{'id': 408, 'name': 'Plumbing fixture', 'freebase_id': '/m/02pkr5'},
|
| 415 |
+
{'id': 409, 'name': 'Scoreboard', 'freebase_id': '/m/057p5t'},
|
| 416 |
+
{'id': 410, 'name': 'Baseball bat', 'freebase_id': '/m/03g8mr'},
|
| 417 |
+
{'id': 411, 'name': 'Envelope', 'freebase_id': '/m/0frqm'},
|
| 418 |
+
{'id': 412, 'name': 'Adhesive tape', 'freebase_id': '/m/03m3vtv'},
|
| 419 |
+
{'id': 413, 'name': 'Briefcase', 'freebase_id': '/m/0584n8'},
|
| 420 |
+
{'id': 414, 'name': 'Paddle', 'freebase_id': '/m/014y4n'},
|
| 421 |
+
{'id': 415, 'name': 'Bow and arrow', 'freebase_id': '/m/01g3x7'},
|
| 422 |
+
{'id': 416, 'name': 'Telephone', 'freebase_id': '/m/07cx4'},
|
| 423 |
+
{'id': 417, 'name': 'Sheep', 'freebase_id': '/m/07bgp'},
|
| 424 |
+
{'id': 418, 'name': 'Jacket', 'freebase_id': '/m/032b3c'},
|
| 425 |
+
{'id': 419, 'name': 'Boy', 'freebase_id': '/m/01bl7v'},
|
| 426 |
+
{'id': 420, 'name': 'Pizza', 'freebase_id': '/m/0663v'},
|
| 427 |
+
{'id': 421, 'name': 'Otter', 'freebase_id': '/m/0cn6p'},
|
| 428 |
+
{'id': 422, 'name': 'Office supplies', 'freebase_id': '/m/02rdsp'},
|
| 429 |
+
{'id': 423, 'name': 'Couch', 'freebase_id': '/m/02crq1'},
|
| 430 |
+
{'id': 424, 'name': 'Cello', 'freebase_id': '/m/01xqw'},
|
| 431 |
+
{'id': 425, 'name': 'Bull', 'freebase_id': '/m/0cnyhnx'},
|
| 432 |
+
{'id': 426, 'name': 'Camel', 'freebase_id': '/m/01x_v'},
|
| 433 |
+
{'id': 427, 'name': 'Ball', 'freebase_id': '/m/018xm'},
|
| 434 |
+
{'id': 428, 'name': 'Duck', 'freebase_id': '/m/09ddx'},
|
| 435 |
+
{'id': 429, 'name': 'Whale', 'freebase_id': '/m/084zz'},
|
| 436 |
+
{'id': 430, 'name': 'Shirt', 'freebase_id': '/m/01n4qj'},
|
| 437 |
+
{'id': 431, 'name': 'Tank', 'freebase_id': '/m/07cmd'},
|
| 438 |
+
{'id': 432, 'name': 'Motorcycle', 'freebase_id': '/m/04_sv'},
|
| 439 |
+
{'id': 433, 'name': 'Accordion', 'freebase_id': '/m/0mkg'},
|
| 440 |
+
{'id': 434, 'name': 'Owl', 'freebase_id': '/m/09d5_'},
|
| 441 |
+
{'id': 435, 'name': 'Porcupine', 'freebase_id': '/m/0c568'},
|
| 442 |
+
{'id': 436, 'name': 'Sun hat', 'freebase_id': '/m/02wbtzl'},
|
| 443 |
+
{'id': 437, 'name': 'Nail', 'freebase_id': '/m/05bm6'},
|
| 444 |
+
{'id': 438, 'name': 'Scissors', 'freebase_id': '/m/01lsmm'},
|
| 445 |
+
{'id': 439, 'name': 'Swan', 'freebase_id': '/m/0dftk'},
|
| 446 |
+
{'id': 440, 'name': 'Lamp', 'freebase_id': '/m/0dtln'},
|
| 447 |
+
{'id': 441, 'name': 'Crown', 'freebase_id': '/m/0nl46'},
|
| 448 |
+
{'id': 442, 'name': 'Piano', 'freebase_id': '/m/05r5c'},
|
| 449 |
+
{'id': 443, 'name': 'Sculpture', 'freebase_id': '/m/06msq'},
|
| 450 |
+
{'id': 444, 'name': 'Cheetah', 'freebase_id': '/m/0cd4d'},
|
| 451 |
+
{'id': 445, 'name': 'Oboe', 'freebase_id': '/m/05kms'},
|
| 452 |
+
{'id': 446, 'name': 'Tin can', 'freebase_id': '/m/02jnhm'},
|
| 453 |
+
{'id': 447, 'name': 'Mango', 'freebase_id': '/m/0fldg'},
|
| 454 |
+
{'id': 448, 'name': 'Tripod', 'freebase_id': '/m/073bxn'},
|
| 455 |
+
{'id': 449, 'name': 'Oven', 'freebase_id': '/m/029bxz'},
|
| 456 |
+
{'id': 450, 'name': 'Mouse', 'freebase_id': '/m/020lf'},
|
| 457 |
+
{'id': 451, 'name': 'Barge', 'freebase_id': '/m/01btn'},
|
| 458 |
+
{'id': 452, 'name': 'Coffee', 'freebase_id': '/m/02vqfm'},
|
| 459 |
+
{'id': 453, 'name': 'Snowboard', 'freebase_id': '/m/06__v'},
|
| 460 |
+
{'id': 454, 'name': 'Common fig', 'freebase_id': '/m/043nyj'},
|
| 461 |
+
{'id': 455, 'name': 'Salad', 'freebase_id': '/m/0grw1'},
|
| 462 |
+
{'id': 456, 'name': 'Marine invertebrates', 'freebase_id': '/m/03hl4l9'},
|
| 463 |
+
{'id': 457, 'name': 'Umbrella', 'freebase_id': '/m/0hnnb'},
|
| 464 |
+
{'id': 458, 'name': 'Kangaroo', 'freebase_id': '/m/04c0y'},
|
| 465 |
+
{'id': 459, 'name': 'Human arm', 'freebase_id': '/m/0dzf4'},
|
| 466 |
+
{'id': 460, 'name': 'Measuring cup', 'freebase_id': '/m/07v9_z'},
|
| 467 |
+
{'id': 461, 'name': 'Snail', 'freebase_id': '/m/0f9_l'},
|
| 468 |
+
{'id': 462, 'name': 'Loveseat', 'freebase_id': '/m/0703r8'},
|
| 469 |
+
{'id': 463, 'name': 'Suit', 'freebase_id': '/m/01xyhv'},
|
| 470 |
+
{'id': 464, 'name': 'Teapot', 'freebase_id': '/m/01fh4r'},
|
| 471 |
+
{'id': 465, 'name': 'Bottle', 'freebase_id': '/m/04dr76w'},
|
| 472 |
+
{'id': 466, 'name': 'Alpaca', 'freebase_id': '/m/0pcr'},
|
| 473 |
+
{'id': 467, 'name': 'Kettle', 'freebase_id': '/m/03s_tn'},
|
| 474 |
+
{'id': 468, 'name': 'Trousers', 'freebase_id': '/m/07mhn'},
|
| 475 |
+
{'id': 469, 'name': 'Popcorn', 'freebase_id': '/m/01hrv5'},
|
| 476 |
+
{'id': 470, 'name': 'Centipede', 'freebase_id': '/m/019h78'},
|
| 477 |
+
{'id': 471, 'name': 'Spider', 'freebase_id': '/m/09kmb'},
|
| 478 |
+
{'id': 472, 'name': 'Sparrow', 'freebase_id': '/m/0h23m'},
|
| 479 |
+
{'id': 473, 'name': 'Plate', 'freebase_id': '/m/050gv4'},
|
| 480 |
+
{'id': 474, 'name': 'Bagel', 'freebase_id': '/m/01fb_0'},
|
| 481 |
+
{'id': 475, 'name': 'Personal care', 'freebase_id': '/m/02w3_ws'},
|
| 482 |
+
{'id': 476, 'name': 'Apple', 'freebase_id': '/m/014j1m'},
|
| 483 |
+
{'id': 477, 'name': 'Brassiere', 'freebase_id': '/m/01gmv2'},
|
| 484 |
+
{'id': 478, 'name': 'Bathroom cabinet', 'freebase_id': '/m/04y4h8h'},
|
| 485 |
+
{'id': 479, 'name': 'studio couch', 'freebase_id': '/m/026qbn5'},
|
| 486 |
+
{'id': 480, 'name': 'Computer keyboard', 'freebase_id': '/m/01m2v'},
|
| 487 |
+
{'id': 481, 'name': 'Table tennis racket', 'freebase_id': '/m/05_5p_0'},
|
| 488 |
+
{'id': 482, 'name': 'Sushi', 'freebase_id': '/m/07030'},
|
| 489 |
+
{'id': 483, 'name': 'Cabinetry', 'freebase_id': '/m/01s105'},
|
| 490 |
+
{'id': 484, 'name': 'Street light', 'freebase_id': '/m/033rq4'},
|
| 491 |
+
{'id': 485, 'name': 'Towel', 'freebase_id': '/m/0162_1'},
|
| 492 |
+
{'id': 486, 'name': 'Nightstand', 'freebase_id': '/m/02z51p'},
|
| 493 |
+
{'id': 487, 'name': 'Rabbit', 'freebase_id': '/m/06mf6'},
|
| 494 |
+
{'id': 488, 'name': 'Dolphin', 'freebase_id': '/m/02hj4'},
|
| 495 |
+
{'id': 489, 'name': 'Dog', 'freebase_id': '/m/0bt9lr'},
|
| 496 |
+
{'id': 490, 'name': 'Jug', 'freebase_id': '/m/08hvt4'},
|
| 497 |
+
{'id': 491, 'name': 'Wok', 'freebase_id': '/m/084rd'},
|
| 498 |
+
{'id': 492, 'name': 'Fire hydrant', 'freebase_id': '/m/01pns0'},
|
| 499 |
+
{'id': 493, 'name': 'Human eye', 'freebase_id': '/m/014sv8'},
|
| 500 |
+
{'id': 494, 'name': 'Skyscraper', 'freebase_id': '/m/079cl'},
|
| 501 |
+
{'id': 495, 'name': 'Backpack', 'freebase_id': '/m/01940j'},
|
| 502 |
+
{'id': 496, 'name': 'Potato', 'freebase_id': '/m/05vtc'},
|
| 503 |
+
{'id': 497, 'name': 'Paper towel', 'freebase_id': '/m/02w3r3'},
|
| 504 |
+
{'id': 498, 'name': 'Lifejacket', 'freebase_id': '/m/054xkw'},
|
| 505 |
+
{'id': 499, 'name': 'Bicycle wheel', 'freebase_id': '/m/01bqk0'},
|
| 506 |
+
{'id': 500, 'name': 'Toilet', 'freebase_id': '/m/09g1w'},
|
| 507 |
+
]


def _get_builtin_metadata(cats):
    id_to_name = {x['id']: x['name'] for x in cats}
    thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(len(cats))}
    thing_classes = [x['name'] for x in sorted(cats, key=lambda x: x['id'])]
    return {
        "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
        "thing_classes": thing_classes}

_PREDEFINED_SPLITS_OID = {
    # cat threshold: 500, 1500: r 170, c 151, f 179
    "oid_train": ("oid/images/", "oid/annotations/oid_challenge_2019_train_bbox.json"),
    # "expanded" duplicates annotations to their parent classes based on the official
    # hierarchy. This is used in the official evaluation protocol.
    # https://storage.googleapis.com/openimages/web/evaluation.html
    "oid_val_expanded": ("oid/images/validation/", "oid/annotations/oid_challenge_2019_val_expanded.json"),
    "oid_val_expanded_rare": ("oid/images/validation/", "oid/annotations/oid_challenge_2019_val_expanded_rare.json"),
}


for key, (image_root, json_file) in _PREDEFINED_SPLITS_OID.items():
    register_oid_instances(
        key,
        _get_builtin_metadata(categories),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )
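A minimal sketch of pulling one of the registered splits back out of the detectron2 catalogs (assuming detectron2 is installed and the annotation/image paths above exist under datasets/); the split name comes from _PREDEFINED_SPLITS_OID:

from detectron2.data import DatasetCatalog, MetadataCatalog

# Metadata (class names, id mapping) is available without reading the annotation file.
meta = MetadataCatalog.get("oid_val_expanded")
print(len(meta.thing_classes))  # 500 OID challenge classes

# Getting the dicts triggers load_coco_json_mem_efficient via the lambda registered above.
dicts = DatasetCatalog.get("oid_val_expanded")
print(dicts[0]["file_name"], len(dicts[0]["annotations"]))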
proxydet/data/datasets/register_oid.py
ADDED
|
@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Xingyi Zhou from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/coco.py
import copy
import io
import logging
import contextlib
import os
import datetime
import json
import numpy as np

from PIL import Image

from fvcore.common.timer import Timer
from fvcore.common.file_io import PathManager, file_lock
from detectron2.structures import BoxMode, PolygonMasks, Boxes
from detectron2.data import DatasetCatalog, MetadataCatalog

logger = logging.getLogger(__name__)

"""
This file contains functions to register a COCO-format dataset to the DatasetCatalog.
"""

__all__ = ["register_oid_instances", "load_coco_json_mem_efficient"]


def register_oid_instances(name, metadata, json_file, image_root):
    """
    Register an Open Images (COCO-format) dataset split and attach its metadata.
    """
    # 1. register a function which returns dicts
    DatasetCatalog.register(name, lambda: load_coco_json_mem_efficient(
        json_file, image_root, name))

    # 2. Optionally, add metadata about this dataset,
    # since they might be useful in evaluation, visualization or logging
    MetadataCatalog.get(name).set(
        json_file=json_file, image_root=image_root, evaluator_type="oid", **metadata
    )


def load_coco_json_mem_efficient(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
    """
    Load a COCO-format json into detectron2 dataset dicts.
    (Despite the name, this is not particularly memory efficient.)
    """
    from pycocotools.coco import COCO

    timer = Timer()
    json_file = PathManager.get_local_path(json_file)
    with contextlib.redirect_stdout(io.StringIO()):
        coco_api = COCO(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))

    id_map = None
    if dataset_name is not None:
        meta = MetadataCatalog.get(dataset_name)
        cat_ids = sorted(coco_api.getCatIds())
        cats = coco_api.loadCats(cat_ids)
        # The categories in a custom json file may not be sorted.
        thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
        meta.thing_classes = thing_classes

        if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
            if "coco" not in dataset_name:
                logger.warning(
                    """
Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
"""
                )
        id_map = {v: i for i, v in enumerate(cat_ids)}
        meta.thing_dataset_id_to_contiguous_id = id_map

    # sort indices for reproducible results
    img_ids = sorted(coco_api.imgs.keys())
    imgs = coco_api.loadImgs(img_ids)
    logger.info("Loaded {} images in COCO format from {}".format(len(imgs), json_file))

    dataset_dicts = []

    num_instances_without_valid_segmentation = 0
    ann_keys = ["iscrowd", "bbox", "category_id"] + (extra_annotation_keys or [])

    for img_dict in imgs:
        record = {}
        record["file_name"] = os.path.join(image_root, img_dict["file_name"])
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        image_id = record["image_id"] = img_dict["id"]
        anno_dict_list = coco_api.imgToAnns[image_id]
        if 'neg_category_ids' in img_dict:
            record['neg_category_ids'] = \
                [id_map[x] for x in img_dict['neg_category_ids']]

        objs = []
        for anno in anno_dict_list:
            assert anno["image_id"] == image_id

            assert anno.get("ignore", 0) == 0

            obj = {key: anno[key] for key in ann_keys if key in anno}

            segm = anno.get("segmentation", None)
            if segm:  # either list[list[float]] or dict(RLE)
                if not isinstance(segm, dict):
                    # filter out invalid polygons (< 3 points)
                    segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
                    if len(segm) == 0:
                        num_instances_without_valid_segmentation += 1
                        continue  # ignore this instance
                obj["segmentation"] = segm

            obj["bbox_mode"] = BoxMode.XYWH_ABS

            if id_map:
                obj["category_id"] = id_map[obj["category_id"]]
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)

    del coco_api
    return dataset_dicts
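register_oid_instances can also be pointed at a custom split; a sketch (the split name and file paths below are hypothetical, and categories / _get_builtin_metadata are assumed to be imported from oid.py):

from detectron2.data import DatasetCatalog

register_oid_instances(
    "oid_mini_val",                              # hypothetical split name
    _get_builtin_metadata(categories),           # metadata helper from oid.py
    "datasets/oid/annotations/mini_val.json",    # hypothetical annotation file
    "datasets/oid/images/validation/",
)
dicts = DatasetCatalog.get("oid_mini_val")
print(len(dicts), "images,", sum(len(d["annotations"]) for d in dicts), "boxes")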
proxydet/data/tar_dataset.py
ADDED
|
@@ -0,0 +1,138 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
import os
|
| 4 |
+
import gzip
|
| 5 |
+
import numpy as np
|
| 6 |
+
import io
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from PIL import UnidentifiedImageError
|
| 12 |
+
|
| 13 |
+
unidentified_error_available = True
|
| 14 |
+
except ImportError:
|
| 15 |
+
# UnidentifiedImageError isn't available in older versions of PIL
|
| 16 |
+
unidentified_error_available = False
|
| 17 |
+
|
| 18 |
+
class DiskTarDataset(Dataset):
|
| 19 |
+
def __init__(self,
|
| 20 |
+
tarfile_path='dataset/imagenet/ImageNet-21k/metadata/tar_files.npy',
|
| 21 |
+
tar_index_dir='dataset/imagenet/ImageNet-21k/metadata/tarindex_npy',
|
| 22 |
+
preload=False,
|
| 23 |
+
num_synsets="all"):
|
| 24 |
+
"""
|
| 25 |
+
- preload (bool): Recommend to set preload to False when using
|
| 26 |
+
- num_synsets (integer or string "all"): set to small number for debugging
|
| 27 |
+
will load subset of dataset
|
| 28 |
+
"""
|
| 29 |
+
tar_files = np.load(tarfile_path)
|
| 30 |
+
|
| 31 |
+
chunk_datasets = []
|
| 32 |
+
dataset_lens = []
|
| 33 |
+
if isinstance(num_synsets, int):
|
| 34 |
+
assert num_synsets < len(tar_files)
|
| 35 |
+
tar_files = tar_files[:num_synsets]
|
| 36 |
+
for tar_file in tar_files:
|
| 37 |
+
dataset = _TarDataset(tar_file, tar_index_dir, preload=preload)
|
| 38 |
+
chunk_datasets.append(dataset)
|
| 39 |
+
dataset_lens.append(len(dataset))
|
| 40 |
+
|
| 41 |
+
self.chunk_datasets = chunk_datasets
|
| 42 |
+
self.dataset_lens = np.array(dataset_lens).astype(np.int32)
|
| 43 |
+
self.dataset_cumsums = np.cumsum(self.dataset_lens)
|
| 44 |
+
self.num_samples = sum(self.dataset_lens)
|
| 45 |
+
labels = np.zeros(self.dataset_lens.sum(), dtype=np.int64)
|
| 46 |
+
sI = 0
|
| 47 |
+
for k in range(len(self.dataset_lens)):
|
| 48 |
+
assert (sI+self.dataset_lens[k]) <= len(labels), f"{k} {sI+self.dataset_lens[k]} vs. {len(labels)}"
|
| 49 |
+
labels[sI:(sI+self.dataset_lens[k])] = k
|
| 50 |
+
sI += self.dataset_lens[k]
|
| 51 |
+
self.labels = labels
|
| 52 |
+
|
| 53 |
+
def __len__(self):
|
| 54 |
+
return self.num_samples
|
| 55 |
+
|
| 56 |
+
def __getitem__(self, index):
|
| 57 |
+
assert index >= 0 and index < len(self)
|
| 58 |
+
# find the dataset file we need to go to
|
| 59 |
+
d_index = np.searchsorted(self.dataset_cumsums, index)
|
| 60 |
+
|
| 61 |
+
# edge case, if index is at edge of chunks, move right
|
| 62 |
+
if index in self.dataset_cumsums:
|
| 63 |
+
d_index += 1
|
| 64 |
+
|
| 65 |
+
assert d_index == self.labels[index], f"{d_index} vs. {self.labels[index]} mismatch for {index}"
|
| 66 |
+
|
| 67 |
+
# change index to local dataset index
|
| 68 |
+
if d_index == 0:
|
| 69 |
+
local_index = index
|
| 70 |
+
else:
|
| 71 |
+
local_index = index - self.dataset_cumsums[d_index - 1]
|
| 72 |
+
data_bytes = self.chunk_datasets[d_index][local_index]
|
| 73 |
+
exception_to_catch = UnidentifiedImageError if unidentified_error_available else Exception
|
| 74 |
+
try:
|
| 75 |
+
image = Image.open(data_bytes).convert("RGB")
|
| 76 |
+
except exception_to_catch:
|
| 77 |
+
image = Image.fromarray(np.ones((224,224,3), dtype=np.uint8)*128)
|
| 78 |
+
d_index = -1
|
| 79 |
+
|
| 80 |
+
# label is the dataset (synset) we indexed into
|
| 81 |
+
return image, d_index, index
|
| 82 |
+
|
| 83 |
+
def __repr__(self):
|
| 84 |
+
st = f"DiskTarDataset(subdatasets={len(self.dataset_lens)},samples={self.num_samples})"
|
| 85 |
+
return st
|
| 86 |
+
|
| 87 |
+
class _TarDataset(object):
|
| 88 |
+
|
| 89 |
+
def __init__(self, filename, npy_index_dir, preload=False):
|
| 90 |
+
# translated from
|
| 91 |
+
# fbcode/experimental/deeplearning/matthijs/comp_descs/tardataset.lua
|
| 92 |
+
self.filename = filename
|
| 93 |
+
self.names = []
|
| 94 |
+
self.offsets = []
|
| 95 |
+
self.npy_index_dir = npy_index_dir
|
| 96 |
+
names, offsets = self.load_index()
|
| 97 |
+
|
| 98 |
+
self.num_samples = len(names)
|
| 99 |
+
if preload:
|
| 100 |
+
self.data = np.memmap(filename, mode='r', dtype='uint8')
|
| 101 |
+
self.offsets = offsets
|
| 102 |
+
else:
|
| 103 |
+
self.data = None
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def __len__(self):
|
| 107 |
+
return self.num_samples
|
| 108 |
+
|
| 109 |
+
def load_index(self):
|
| 110 |
+
basename = os.path.basename(self.filename)
|
| 111 |
+
basename = os.path.splitext(basename)[0]
|
| 112 |
+
names = np.load(os.path.join(self.npy_index_dir, f"{basename}_names.npy"))
|
| 113 |
+
offsets = np.load(os.path.join(self.npy_index_dir, f"{basename}_offsets.npy"))
|
| 114 |
+
return names, offsets
|
| 115 |
+
|
| 116 |
+
def __getitem__(self, idx):
|
| 117 |
+
if self.data is None:
|
| 118 |
+
self.data = np.memmap(self.filename, mode='r', dtype='uint8')
|
| 119 |
+
_, self.offsets = self.load_index()
|
| 120 |
+
|
| 121 |
+
ofs = self.offsets[idx] * 512
|
| 122 |
+
fsize = 512 * (self.offsets[idx + 1] - self.offsets[idx])
|
| 123 |
+
data = self.data[ofs:ofs + fsize]
|
| 124 |
+
|
| 125 |
+
if data[:13].tobytes() == b'././@LongLink':
|
| 126 |
+
data = data[3 * 512:]
|
| 127 |
+
else:
|
| 128 |
+
data = data[512:]
|
| 129 |
+
|
| 130 |
+
# just to make it more fun a few JPEGs are GZIP compressed...
|
| 131 |
+
# catch this case
|
| 132 |
+
if tuple(data[:2]) == (0x1f, 0x8b):
|
| 133 |
+
s = io.BytesIO(data.tobytes())
|
| 134 |
+
g = gzip.GzipFile(None, 'r', 0, s)
|
| 135 |
+
sdata = g.read()
|
| 136 |
+
else:
|
| 137 |
+
sdata = data.tobytes()
|
| 138 |
+
return io.BytesIO(sdata)
|
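A sketch of consuming DiskTarDataset with a PyTorch DataLoader (assumes torchvision is available and the default metadata/index .npy files exist on disk); each item is a (PIL image, synset index, global index) tuple, so a small collate function converts to tensors:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

to_tensor = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

def collate(batch):
    images = torch.stack([to_tensor(img) for img, _, _ in batch])
    labels = torch.tensor([label for _, label, _ in batch])
    return images, labels

dataset = DiskTarDataset(num_synsets=2)  # small subset, handy for debugging
loader = DataLoader(dataset, batch_size=8, num_workers=2, collate_fn=collate)
images, labels = next(iter(loader))
print(images.shape, labels.shape)  # [8, 3, 224, 224], [8]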
proxydet/data/transforms/custom_augmentation_impl.py
ADDED
|
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
# Modified by Xingyi Zhou
# The original code is under Apache-2.0 License
import numpy as np
import sys
from fvcore.transforms.transform import (
    BlendTransform,
    CropTransform,
    HFlipTransform,
    NoOpTransform,
    Transform,
    VFlipTransform,
)
from PIL import Image

from detectron2.data.transforms.augmentation import Augmentation
from .custom_transform import EfficientDetResizeCropTransform

__all__ = [
    "EfficientDetResizeCrop",
]

class EfficientDetResizeCrop(Augmentation):
    """
    Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
    If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
    """

    def __init__(
        self, size, scale, interp=Image.BILINEAR
    ):
        """
        Args:
            size (int): target output size; the crop window is size x size.
            scale (tuple[float, float]): range from which the random scale factor is drawn.
            interp: PIL interpolation method, defaults to bilinear.
        """
        super().__init__()
        self.target_size = (size, size)
        self.scale = scale
        self.interp = interp

    def get_transform(self, img):
        # Select a random scale factor.
        scale_factor = np.random.uniform(*self.scale)
        scaled_target_height = scale_factor * self.target_size[0]
        scaled_target_width = scale_factor * self.target_size[1]
        # Recompute the accurate scale_factor using rounded scaled image size.
        width, height = img.shape[1], img.shape[0]
        img_scale_y = scaled_target_height / height
        img_scale_x = scaled_target_width / width
        img_scale = min(img_scale_y, img_scale_x)

        # Select non-zero random offset (x, y) if scaled image is larger than target size
        scaled_h = int(height * img_scale)
        scaled_w = int(width * img_scale)
        offset_y = scaled_h - self.target_size[0]
        offset_x = scaled_w - self.target_size[1]
        offset_y = int(max(0.0, float(offset_y)) * np.random.uniform(0, 1))
        offset_x = int(max(0.0, float(offset_x)) * np.random.uniform(0, 1))
        return EfficientDetResizeCropTransform(
            scaled_h, scaled_w, offset_y, offset_x, img_scale, self.target_size, self.interp)
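A sketch of running the augmentation on a dummy image through detectron2's augmentation API (assumes a detectron2 version that provides T.AugInput; the size/scale values are illustrative):

import numpy as np
from detectron2.data import transforms as T

aug = EfficientDetResizeCrop(size=640, scale=(0.5, 1.5))
image = (np.random.rand(480, 800, 3) * 255).astype(np.uint8)
boxes = np.array([[100., 50., 300., 200.]])

aug_input = T.AugInput(image, boxes=boxes)
tfm = aug(aug_input)          # applies in place, returns the EfficientDetResizeCropTransform
print(aug_input.image.shape)  # resized and cropped to at most 640x640
print(aug_input.boxes)        # boxes mapped by the same transform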
proxydet/data/transforms/custom_transform.py
ADDED
|
@@ -0,0 +1,114 @@
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 3 |
+
# Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
|
| 4 |
+
# Modified by Xingyi Zhou
|
| 5 |
+
# The original code is under Apache-2.0 License
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from fvcore.transforms.transform import (
|
| 10 |
+
CropTransform,
|
| 11 |
+
HFlipTransform,
|
| 12 |
+
NoOpTransform,
|
| 13 |
+
Transform,
|
| 14 |
+
TransformList,
|
| 15 |
+
)
|
| 16 |
+
from PIL import Image
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import cv2 # noqa
|
| 20 |
+
except ImportError:
|
| 21 |
+
# OpenCV is an optional dependency at the moment
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
__all__ = [
|
| 25 |
+
"EfficientDetResizeCropTransform",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
class EfficientDetResizeCropTransform(Transform):
|
| 29 |
+
"""
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale, \
|
| 33 |
+
target_size, interp=None):
|
| 34 |
+
"""
|
| 35 |
+
Args:
|
| 36 |
+
h, w (int): original image size
|
| 37 |
+
new_h, new_w (int): new image size
|
| 38 |
+
interp: PIL interpolation methods, defaults to bilinear.
|
| 39 |
+
"""
|
| 40 |
+
# TODO decide on PIL vs opencv
|
| 41 |
+
super().__init__()
|
| 42 |
+
if interp is None:
|
| 43 |
+
interp = Image.BILINEAR
|
| 44 |
+
self._set_attributes(locals())
|
| 45 |
+
|
| 46 |
+
def apply_image(self, img, interp=None):
|
| 47 |
+
assert len(img.shape) <= 4
|
| 48 |
+
|
| 49 |
+
if img.dtype == np.uint8:
|
| 50 |
+
pil_image = Image.fromarray(img)
|
| 51 |
+
interp_method = interp if interp is not None else self.interp
|
| 52 |
+
pil_image = pil_image.resize((self.scaled_w, self.scaled_h), interp_method)
|
| 53 |
+
ret = np.asarray(pil_image)
|
| 54 |
+
right = min(self.scaled_w, self.offset_x + self.target_size[1])
|
| 55 |
+
lower = min(self.scaled_h, self.offset_y + self.target_size[0])
|
| 56 |
+
if len(ret.shape) <= 3:
|
| 57 |
+
ret = ret[self.offset_y: lower, self.offset_x: right]
|
| 58 |
+
else:
|
| 59 |
+
ret = ret[..., self.offset_y: lower, self.offset_x: right, :]
|
| 60 |
+
else:
|
| 61 |
+
# PIL only supports uint8
|
| 62 |
+
img = torch.from_numpy(img)
|
| 63 |
+
shape = list(img.shape)
|
| 64 |
+
shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
|
| 65 |
+
img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw
|
| 66 |
+
_PIL_RESIZE_TO_INTERPOLATE_MODE = {Image.BILINEAR: "bilinear", Image.BICUBIC: "bicubic"}
|
| 67 |
+
mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[self.interp]
|
| 68 |
+
img = F.interpolate(img, (self.scaled_h, self.scaled_w), mode=mode, align_corners=False)
|
| 69 |
+
shape[:2] = (self.scaled_h, self.scaled_w)
|
| 70 |
+
ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c)
|
| 71 |
+
right = min(self.scaled_w, self.offset_x + self.target_size[1])
|
| 72 |
+
lower = min(self.scaled_h, self.offset_y + self.target_size[0])
|
| 73 |
+
if len(ret.shape) <= 3:
|
| 74 |
+
ret = ret[self.offset_y: lower, self.offset_x: right]
|
| 75 |
+
else:
|
| 76 |
+
ret = ret[..., self.offset_y: lower, self.offset_x: right, :]
|
| 77 |
+
return ret
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def apply_coords(self, coords):
|
| 81 |
+
coords[:, 0] = coords[:, 0] * self.img_scale
|
| 82 |
+
coords[:, 1] = coords[:, 1] * self.img_scale
|
| 83 |
+
coords[:, 0] -= self.offset_x
|
| 84 |
+
coords[:, 1] -= self.offset_y
|
| 85 |
+
return coords
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def apply_segmentation(self, segmentation):
|
| 89 |
+
segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
|
| 90 |
+
return segmentation
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def inverse(self):
|
| 94 |
+
raise NotImplementedError
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def inverse_apply_coords(self, coords):
|
| 98 |
+
coords[:, 0] += self.offset_x
|
| 99 |
+
coords[:, 1] += self.offset_y
|
| 100 |
+
coords[:, 0] = coords[:, 0] / self.img_scale
|
| 101 |
+
coords[:, 1] = coords[:, 1] / self.img_scale
|
| 102 |
+
return coords
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def inverse_apply_box(self, box: np.ndarray) -> np.ndarray:
|
| 106 |
+
"""
|
| 107 |
+
"""
|
| 108 |
+
idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
|
| 109 |
+
coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2)
|
| 110 |
+
coords = self.inverse_apply_coords(coords).reshape((-1, 4, 2))
|
| 111 |
+
minxy = coords.min(axis=1)
|
| 112 |
+
maxxy = coords.max(axis=1)
|
| 113 |
+
trans_boxes = np.concatenate((minxy, maxxy), axis=1)
|
| 114 |
+
return trans_boxes
|
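Since inverse() is not implemented, callers map predicted boxes back to the original image with inverse_apply_box; a round-trip sketch with made-up but consistent parameters (img_scale 0.8, no crop offset):

import numpy as np

tfm = EfficientDetResizeCropTransform(
    scaled_h=480, scaled_w=640, offset_y=0, offset_x=0,
    img_scale=0.8, target_size=(640, 640))

boxes = np.array([[50., 40., 200., 160.]])
scaled = tfm.apply_box(boxes)             # forward: scale by 0.8, shift by the crop offsets
restored = tfm.inverse_apply_box(scaled)  # inverse: undo the shift, divide by 0.8
assert np.allclose(restored, boxes)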
proxydet/evaluation/custom_coco_eval.py
ADDED
|
@@ -0,0 +1,124 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import contextlib
|
| 3 |
+
import copy
|
| 4 |
+
import io
|
| 5 |
+
import itertools
|
| 6 |
+
import json
|
| 7 |
+
import logging
|
| 8 |
+
import numpy as np
|
| 9 |
+
import os
|
| 10 |
+
import pickle
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
import pycocotools.mask as mask_util
|
| 13 |
+
import torch
|
| 14 |
+
from pycocotools.coco import COCO
|
| 15 |
+
from pycocotools.cocoeval import COCOeval
|
| 16 |
+
from tabulate import tabulate
|
| 17 |
+
|
| 18 |
+
import detectron2.utils.comm as comm
|
| 19 |
+
from detectron2.config import CfgNode
|
| 20 |
+
from detectron2.data import MetadataCatalog
|
| 21 |
+
from detectron2.data.datasets.coco import convert_to_coco_json
|
| 22 |
+
from detectron2.evaluation.coco_evaluation import COCOEvaluator
|
| 23 |
+
from detectron2.structures import Boxes, BoxMode, pairwise_iou
|
| 24 |
+
from detectron2.utils.file_io import PathManager
|
| 25 |
+
from detectron2.utils.logger import create_small_table
|
| 26 |
+
from ..data.datasets.coco_zeroshot import categories_seen, categories_unseen
|
| 27 |
+
|
| 28 |
+
class CustomCOCOEvaluator(COCOEvaluator):
|
| 29 |
+
def _derive_coco_results(self, coco_eval, iou_type, class_names=None):
|
| 30 |
+
"""
|
| 31 |
+
Additionally plot mAP for 'seen classes' and 'unseen classes'
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
metrics = {
|
| 35 |
+
"bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
|
| 36 |
+
"segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
|
| 37 |
+
"keypoints": ["AP", "AP50", "AP75", "APm", "APl"],
|
| 38 |
+
}[iou_type]
|
| 39 |
+
|
| 40 |
+
if coco_eval is None:
|
| 41 |
+
self._logger.warn("No predictions from the model!")
|
| 42 |
+
return {metric: float("nan") for metric in metrics}
|
| 43 |
+
|
| 44 |
+
# the standard metrics
|
| 45 |
+
results = {
|
| 46 |
+
metric: float(coco_eval.stats[idx] * 100 if coco_eval.stats[idx] >= 0 else "nan")
|
| 47 |
+
for idx, metric in enumerate(metrics)
|
| 48 |
+
}
|
| 49 |
+
self._logger.info(
|
| 50 |
+
"Evaluation results for {}: \n".format(iou_type) + create_small_table(results)
|
| 51 |
+
)
|
| 52 |
+
if not np.isfinite(sum(results.values())):
|
| 53 |
+
self._logger.info("Some metrics cannot be computed and is shown as NaN.")
|
| 54 |
+
|
| 55 |
+
if class_names is None or len(class_names) <= 1:
|
| 56 |
+
return results
|
| 57 |
+
# Compute per-category AP
|
| 58 |
+
# from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 # noqa
|
| 59 |
+
precisions = coco_eval.eval["precision"]
|
| 60 |
+
# precision has dims (iou, recall, cls, area range, max dets)
|
| 61 |
+
assert len(class_names) == precisions.shape[2]
|
| 62 |
+
|
| 63 |
+
seen_names = set([x['name'] for x in categories_seen])
|
| 64 |
+
unseen_names = set([x['name'] for x in categories_unseen])
|
| 65 |
+
results_per_category = []
|
| 66 |
+
results_per_category50 = []
|
| 67 |
+
results_per_category50_seen = []
|
| 68 |
+
results_per_category50_unseen = []
|
| 69 |
+
for idx, name in enumerate(class_names):
|
| 70 |
+
# area range index 0: all area ranges
|
| 71 |
+
# max dets index -1: typically 100 per image
|
| 72 |
+
precision = precisions[:, :, idx, 0, -1]
|
| 73 |
+
precision = precision[precision > -1]
|
| 74 |
+
ap = np.mean(precision) if precision.size else float("nan")
|
| 75 |
+
results_per_category.append(("{}".format(name), float(ap * 100)))
|
| 76 |
+
precision50 = precisions[0, :, idx, 0, -1]
|
| 77 |
+
precision50 = precision50[precision50 > -1]
|
| 78 |
+
ap50 = np.mean(precision50) if precision50.size else float("nan")
|
| 79 |
+
results_per_category50.append(("{}".format(name), float(ap50 * 100)))
|
| 80 |
+
if name in seen_names:
|
| 81 |
+
results_per_category50_seen.append(float(ap50 * 100))
|
| 82 |
+
if name in unseen_names:
|
| 83 |
+
results_per_category50_unseen.append(float(ap50 * 100))
|
| 84 |
+
|
| 85 |
+
# tabulate it
|
| 86 |
+
N_COLS = min(6, len(results_per_category) * 2)
|
| 87 |
+
results_flatten = list(itertools.chain(*results_per_category))
|
| 88 |
+
results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)])
|
| 89 |
+
table = tabulate(
|
| 90 |
+
results_2d,
|
| 91 |
+
tablefmt="pipe",
|
| 92 |
+
floatfmt=".3f",
|
| 93 |
+
headers=["category", "AP"] * (N_COLS // 2),
|
| 94 |
+
numalign="left",
|
| 95 |
+
)
|
| 96 |
+
self._logger.info("Per-category {} AP: \n".format(iou_type) + table)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
N_COLS = min(6, len(results_per_category50) * 2)
|
| 100 |
+
results_flatten = list(itertools.chain(*results_per_category50))
|
| 101 |
+
results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)])
|
| 102 |
+
table = tabulate(
|
| 103 |
+
results_2d,
|
| 104 |
+
tablefmt="pipe",
|
| 105 |
+
floatfmt=".3f",
|
| 106 |
+
headers=["category", "AP50"] * (N_COLS // 2),
|
| 107 |
+
numalign="left",
|
| 108 |
+
)
|
| 109 |
+
self._logger.info("Per-category {} AP50: \n".format(iou_type) + table)
|
| 110 |
+
self._logger.info(
|
| 111 |
+
"Seen {} AP50: {}".format(
|
| 112 |
+
iou_type,
|
| 113 |
+
sum(results_per_category50_seen) / len(results_per_category50_seen),
|
| 114 |
+
))
|
| 115 |
+
self._logger.info(
|
| 116 |
+
"Unseen {} AP50: {}".format(
|
| 117 |
+
iou_type,
|
| 118 |
+
sum(results_per_category50_unseen) / len(results_per_category50_unseen),
|
| 119 |
+
))
|
| 120 |
+
|
| 121 |
+
results.update({"AP-" + name: ap for name, ap in results_per_category})
|
| 122 |
+
results["AP50-seen"] = sum(results_per_category50_seen) / len(results_per_category50_seen)
|
| 123 |
+
results["AP50-unseen"] = sum(results_per_category50_unseen) / len(results_per_category50_unseen)
|
| 124 |
+
return results
|
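A sketch of plugging the evaluator into a standard detectron2 evaluation loop (the split name is hypothetical, and cfg / model are assumed to be built elsewhere, e.g. in a training script):

from detectron2.data import build_detection_test_loader
from detectron2.evaluation import inference_on_dataset

evaluator = CustomCOCOEvaluator("coco_zeroshot_val", output_dir="./output")
loader = build_detection_test_loader(cfg, "coco_zeroshot_val")
results = inference_on_dataset(model, loader, evaluator)
print(results["bbox"]["AP50-seen"], results["bbox"]["AP50-unseen"])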
proxydet/evaluation/oideval.py
ADDED
|
@@ -0,0 +1,699 @@
| 1 |
+
# Part of the code is from https://github.com/tensorflow/models/blob/master/research/object_detection/metrics/oid_challenge_evaluation.py
|
| 2 |
+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
| 3 |
+
# The original code is under Apache License, Version 2.0 (the "License");
|
| 4 |
+
# Part of the code is from https://github.com/lvis-dataset/lvis-api/blob/master/lvis/eval.py
|
| 5 |
+
# Copyright (c) 2019, Agrim Gupta and Ross Girshick
|
| 6 |
+
# Modified by Xingyi Zhou
|
| 7 |
+
# This script re-implement OpenImages evaluation in detectron2
|
| 8 |
+
# The code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/evaluation/oideval.py
|
| 9 |
+
# The original code is under Apache-2.0 License
|
| 10 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 11 |
+
import os
|
| 12 |
+
import datetime
|
| 13 |
+
import logging
|
| 14 |
+
import itertools
|
| 15 |
+
from collections import OrderedDict
|
| 16 |
+
from collections import defaultdict
|
| 17 |
+
import copy
|
| 18 |
+
import json
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
from tabulate import tabulate
|
| 22 |
+
|
| 23 |
+
from lvis.lvis import LVIS
|
| 24 |
+
from lvis.results import LVISResults
|
| 25 |
+
|
| 26 |
+
import pycocotools.mask as mask_utils
|
| 27 |
+
|
| 28 |
+
from fvcore.common.file_io import PathManager
|
| 29 |
+
import detectron2.utils.comm as comm
|
| 30 |
+
from detectron2.data import MetadataCatalog
|
| 31 |
+
from detectron2.evaluation.coco_evaluation import instances_to_coco_json
|
| 32 |
+
from detectron2.utils.logger import create_small_table
|
| 33 |
+
from detectron2.evaluation import DatasetEvaluator
|
| 34 |
+
|
| 35 |
+
def compute_average_precision(precision, recall):
|
| 36 |
+
"""Compute Average Precision according to the definition in VOCdevkit.
|
| 37 |
+
Precision is modified to ensure that it does not decrease as recall
|
| 38 |
+
decreases.
|
| 39 |
+
Args:
|
| 40 |
+
precision: A float [N, 1] numpy array of precisions
|
| 41 |
+
recall: A float [N, 1] numpy array of recalls
|
| 42 |
+
Raises:
|
| 43 |
+
ValueError: if the input is not of the correct format
|
| 44 |
+
Returns:
|
| 45 |
+
average_precision: The area under the precision recall curve. NaN if
|
| 46 |
+
precision and recall are None.
|
| 47 |
+
"""
|
| 48 |
+
if precision is None:
|
| 49 |
+
if recall is not None:
|
| 50 |
+
raise ValueError("If precision is None, recall must also be None")
|
| 51 |
+
return np.nan
|
| 52 |
+
|
| 53 |
+
if not isinstance(precision, np.ndarray) or not isinstance(
|
| 54 |
+
recall, np.ndarray):
|
| 55 |
+
raise ValueError("precision and recall must be numpy array")
|
| 56 |
+
if precision.dtype != np.float64 or recall.dtype != np.float64:
|
| 57 |
+
raise ValueError("input must be float numpy array.")
|
| 58 |
+
if len(precision) != len(recall):
|
| 59 |
+
raise ValueError("precision and recall must be of the same size.")
|
| 60 |
+
if not precision.size:
|
| 61 |
+
return 0.0
|
| 62 |
+
if np.amin(precision) < 0 or np.amax(precision) > 1:
|
| 63 |
+
raise ValueError("Precision must be in the range of [0, 1].")
|
| 64 |
+
if np.amin(recall) < 0 or np.amax(recall) > 1:
|
| 65 |
+
raise ValueError("recall must be in the range of [0, 1].")
|
| 66 |
+
if not all(recall[i] <= recall[i + 1] for i in range(len(recall) - 1)):
|
| 67 |
+
raise ValueError("recall must be a non-decreasing array")
|
| 68 |
+
|
| 69 |
+
recall = np.concatenate([[0], recall, [1]])
|
| 70 |
+
precision = np.concatenate([[0], precision, [0]])
|
| 71 |
+
|
| 72 |
+
for i in range(len(precision) - 2, -1, -1):
|
| 73 |
+
precision[i] = np.maximum(precision[i], precision[i + 1])
|
| 74 |
+
indices = np.where(recall[1:] != recall[:-1])[0] + 1
|
| 75 |
+
average_precision = np.sum(
|
| 76 |
+
(recall[indices] - recall[indices - 1]) * precision[indices])
|
| 77 |
+
return average_precision
|
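A tiny worked example of the interpolation performed above (values made up): the monotone precision envelope at the three recall points is [1.0, 0.67, 0.67], so the area is 0.2*1.0 + 0.2*0.67 + 0.2*0.67 ≈ 0.468:

import numpy as np

precision = np.array([1.0, 0.5, 0.67])
recall = np.array([0.2, 0.4, 0.6])
print(compute_average_precision(precision, recall))  # ~0.468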
| 78 |
+
|
| 79 |
+
class OIDEval:
|
| 80 |
+
def __init__(
|
| 81 |
+
self, lvis_gt, lvis_dt, iou_type="bbox", expand_pred_label=False,
|
| 82 |
+
oid_hierarchy_path='./datasets/oid/annotations/challenge-2019-label500-hierarchy.json'):
|
| 83 |
+
"""Constructor for OIDEval.
|
| 84 |
+
Args:
|
| 85 |
+
lvis_gt (LVIS class instance, or str containing path of annotation file)
|
| 86 |
+
lvis_dt (LVISResult class instance, or str containing path of result file,
|
| 87 |
+
or list of dict)
|
| 88 |
+
iou_type (str): segm or bbox evaluation
|
| 89 |
+
"""
|
| 90 |
+
self.logger = logging.getLogger(__name__)
|
| 91 |
+
|
| 92 |
+
if iou_type not in ["bbox", "segm"]:
|
| 93 |
+
raise ValueError("iou_type: {} is not supported.".format(iou_type))
|
| 94 |
+
|
| 95 |
+
if isinstance(lvis_gt, LVIS):
|
| 96 |
+
self.lvis_gt = lvis_gt
|
| 97 |
+
elif isinstance(lvis_gt, str):
|
| 98 |
+
self.lvis_gt = LVIS(lvis_gt)
|
| 99 |
+
else:
|
| 100 |
+
raise TypeError("Unsupported type {} of lvis_gt.".format(lvis_gt))
|
| 101 |
+
|
| 102 |
+
if isinstance(lvis_dt, LVISResults):
|
| 103 |
+
self.lvis_dt = lvis_dt
|
| 104 |
+
elif isinstance(lvis_dt, (str, list)):
|
| 105 |
+
# self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt, max_dets=-1)
|
| 106 |
+
self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt)
|
| 107 |
+
else:
|
| 108 |
+
raise TypeError("Unsupported type {} of lvis_dt.".format(lvis_dt))
|
| 109 |
+
|
| 110 |
+
if expand_pred_label:
|
| 111 |
+
oid_hierarchy = json.load(open(oid_hierarchy_path, 'r'))
|
| 112 |
+
cat_info = self.lvis_gt.dataset['categories']
|
| 113 |
+
freebase2id = {x['freebase_id']: x['id'] for x in cat_info}
|
| 114 |
+
id2freebase = {x['id']: x['freebase_id'] for x in cat_info}
|
| 115 |
+
id2name = {x['id']: x['name'] for x in cat_info}
|
| 116 |
+
|
| 117 |
+
fas = defaultdict(set)
|
| 118 |
+
def dfs(hierarchy, cur_id):
|
| 119 |
+
all_childs = set()
|
| 120 |
+
all_keyed_child = {}
|
| 121 |
+
if 'Subcategory' in hierarchy:
|
| 122 |
+
for x in hierarchy['Subcategory']:
|
| 123 |
+
childs = dfs(x, freebase2id[x['LabelName']])
|
| 124 |
+
all_childs.update(childs)
|
| 125 |
+
if cur_id != -1:
|
| 126 |
+
for c in all_childs:
|
| 127 |
+
fas[c].add(cur_id)
|
| 128 |
+
all_childs.add(cur_id)
|
| 129 |
+
return all_childs
|
| 130 |
+
dfs(oid_hierarchy, -1)
|
| 131 |
+
|
| 132 |
+
expanded_pred = []
|
| 133 |
+
id_count = 0
|
| 134 |
+
for d in self.lvis_dt.dataset['annotations']:
|
| 135 |
+
cur_id = d['category_id']
|
| 136 |
+
ids = [cur_id] + [x for x in fas[cur_id]]
|
| 137 |
+
for cat_id in ids:
|
| 138 |
+
new_box = copy.deepcopy(d)
|
| 139 |
+
id_count = id_count + 1
|
| 140 |
+
new_box['id'] = id_count
|
| 141 |
+
new_box['category_id'] = cat_id
|
| 142 |
+
expanded_pred.append(new_box)
|
| 143 |
+
|
| 144 |
+
print('Expanding original {} preds to {} preds'.format(
|
| 145 |
+
len(self.lvis_dt.dataset['annotations']),
|
| 146 |
+
len(expanded_pred)
|
| 147 |
+
))
|
| 148 |
+
self.lvis_dt.dataset['annotations'] = expanded_pred
|
| 149 |
+
self.lvis_dt._create_index()
|
| 150 |
+
|
| 151 |
+
# per-image per-category evaluation results
|
| 152 |
+
self.eval_imgs = defaultdict(list)
|
| 153 |
+
self.eval = {} # accumulated evaluation results
|
| 154 |
+
self._gts = defaultdict(list) # gt for evaluation
|
| 155 |
+
self._dts = defaultdict(list) # dt for evaluation
|
| 156 |
+
self.params = Params(iou_type=iou_type) # parameters
|
| 157 |
+
self.results = OrderedDict()
|
| 158 |
+
self.ious = {} # ious between all gts and dts
|
| 159 |
+
|
| 160 |
+
self.params.img_ids = sorted(self.lvis_gt.get_img_ids())
|
| 161 |
+
self.params.cat_ids = sorted(self.lvis_gt.get_cat_ids())
|
| 162 |
+
|
| 163 |
+
def _to_mask(self, anns, lvis):
|
| 164 |
+
for ann in anns:
|
| 165 |
+
rle = lvis.ann_to_rle(ann)
|
| 166 |
+
ann["segmentation"] = rle
|
| 167 |
+
|
| 168 |
+
def _prepare(self):
|
| 169 |
+
"""Prepare self._gts and self._dts for evaluation based on params."""
|
| 170 |
+
|
| 171 |
+
cat_ids = self.params.cat_ids if self.params.cat_ids else None
|
| 172 |
+
|
| 173 |
+
gts = self.lvis_gt.load_anns(
|
| 174 |
+
self.lvis_gt.get_ann_ids(img_ids=self.params.img_ids, cat_ids=cat_ids)
|
| 175 |
+
)
|
| 176 |
+
dts = self.lvis_dt.load_anns(
|
| 177 |
+
self.lvis_dt.get_ann_ids(img_ids=self.params.img_ids, cat_ids=cat_ids)
|
| 178 |
+
)
|
| 179 |
+
# convert ground truth to mask if iou_type == 'segm'
|
| 180 |
+
if self.params.iou_type == "segm":
|
| 181 |
+
self._to_mask(gts, self.lvis_gt)
|
| 182 |
+
self._to_mask(dts, self.lvis_dt)
|
| 183 |
+
|
| 184 |
+
for gt in gts:
|
| 185 |
+
self._gts[gt["image_id"], gt["category_id"]].append(gt)
|
| 186 |
+
|
| 187 |
+
# For federated dataset evaluation we will filter out all dt for an
|
| 188 |
+
# image which belong to categories not present in gt and not present in
|
| 189 |
+
# the negative list for an image. In other words detector is not penalized
|
| 190 |
+
# for categories about which we don't have gt information about their
|
| 191 |
+
# presence or absence in an image.
|
| 192 |
+
img_data = self.lvis_gt.load_imgs(ids=self.params.img_ids)
|
| 193 |
+
# per image map of categories not present in image
|
| 194 |
+
img_nl = {d["id"]: d["neg_category_ids"] for d in img_data}
|
| 195 |
+
# per image list of categories present in image
|
| 196 |
+
img_pl = {d["id"]: d["pos_category_ids"] for d in img_data}
|
| 197 |
+
# img_pl = defaultdict(set)
|
| 198 |
+
for ann in gts:
|
| 199 |
+
# img_pl[ann["image_id"]].add(ann["category_id"])
|
| 200 |
+
assert ann["category_id"] in img_pl[ann["image_id"]]
|
| 201 |
+
# print('check pos ids OK.')
|
| 202 |
+
|
| 203 |
+
for dt in dts:
|
| 204 |
+
img_id, cat_id = dt["image_id"], dt["category_id"]
|
| 205 |
+
if cat_id not in img_nl[img_id] and cat_id not in img_pl[img_id]:
|
| 206 |
+
continue
|
| 207 |
+
self._dts[img_id, cat_id].append(dt)
|
| 208 |
+
|
| 209 |
+
def evaluate(self):
|
| 210 |
+
"""
|
| 211 |
+
Run per image evaluation on given images and store results
|
| 212 |
+
(a list of dict) in self.eval_imgs.
|
| 213 |
+
"""
|
| 214 |
+
self.logger.info("Running per image evaluation.")
|
| 215 |
+
self.logger.info("Evaluate annotation type *{}*".format(self.params.iou_type))
|
| 216 |
+
|
| 217 |
+
self.params.img_ids = list(np.unique(self.params.img_ids))
|
| 218 |
+
|
| 219 |
+
if self.params.use_cats:
|
| 220 |
+
cat_ids = self.params.cat_ids
|
| 221 |
+
else:
|
| 222 |
+
cat_ids = [-1]
|
| 223 |
+
|
| 224 |
+
self._prepare()
|
| 225 |
+
|
| 226 |
+
self.ious = {
|
| 227 |
+
(img_id, cat_id): self.compute_iou(img_id, cat_id)
|
| 228 |
+
for img_id in self.params.img_ids
|
| 229 |
+
for cat_id in cat_ids
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
# loop through images, area range, max detection number
|
| 233 |
+
print('Evaluating ...')
|
| 234 |
+
self.eval_imgs = [
|
| 235 |
+
self.evaluate_img_google(img_id, cat_id, area_rng)
|
| 236 |
+
for cat_id in cat_ids
|
| 237 |
+
for area_rng in self.params.area_rng
|
| 238 |
+
for img_id in self.params.img_ids
|
| 239 |
+
]
|
| 240 |
+
|
| 241 |
+
def _get_gt_dt(self, img_id, cat_id):
|
| 242 |
+
"""Create gt, dt which are list of anns/dets. If use_cats is true
|
| 243 |
+
only anns/dets corresponding to tuple (img_id, cat_id) will be
|
| 244 |
+
used. Else, all anns/dets in image are used and cat_id is not used.
|
| 245 |
+
"""
|
| 246 |
+
if self.params.use_cats:
|
| 247 |
+
gt = self._gts[img_id, cat_id]
|
| 248 |
+
dt = self._dts[img_id, cat_id]
|
| 249 |
+
else:
|
| 250 |
+
gt = [
|
| 251 |
+
_ann
|
| 252 |
+
for _cat_id in self.params.cat_ids
|
| 253 |
+
for _ann in self._gts[img_id, cat_id]
|
| 254 |
+
]
|
| 255 |
+
dt = [
|
| 256 |
+
_ann
|
| 257 |
+
for _cat_id in self.params.cat_ids
|
| 258 |
+
for _ann in self._dts[img_id, cat_id]
|
| 259 |
+
]
|
| 260 |
+
return gt, dt
|
| 261 |
+
|
| 262 |
+
def compute_iou(self, img_id, cat_id):
|
| 263 |
+
gt, dt = self._get_gt_dt(img_id, cat_id)
|
| 264 |
+
|
| 265 |
+
if len(gt) == 0 and len(dt) == 0:
|
| 266 |
+
return []
|
| 267 |
+
|
| 268 |
+
# Sort detections in decreasing order of score.
|
| 269 |
+
idx = np.argsort([-d["score"] for d in dt], kind="mergesort")
|
| 270 |
+
dt = [dt[i] for i in idx]
|
| 271 |
+
|
| 272 |
+
# iscrowd = [int(False)] * len(gt)
|
| 273 |
+
iscrowd = [int('iscrowd' in g and g['iscrowd'] > 0) for g in gt]
|
| 274 |
+
|
| 275 |
+
if self.params.iou_type == "segm":
|
| 276 |
+
ann_type = "segmentation"
|
| 277 |
+
elif self.params.iou_type == "bbox":
|
| 278 |
+
ann_type = "bbox"
|
| 279 |
+
else:
|
| 280 |
+
raise ValueError("Unknown iou_type for iou computation.")
|
| 281 |
+
gt = [g[ann_type] for g in gt]
|
| 282 |
+
dt = [d[ann_type] for d in dt]
|
| 283 |
+
|
| 284 |
+
# compute iou between each dt and gt region
|
| 285 |
+
# will return array of shape len(dt), len(gt)
|
| 286 |
+
ious = mask_utils.iou(dt, gt, iscrowd)
|
| 287 |
+
return ious
|
| 288 |
+
|
| 289 |
+
def evaluate_img_google(self, img_id, cat_id, area_rng):
|
| 290 |
+
gt, dt = self._get_gt_dt(img_id, cat_id)
|
| 291 |
+
if len(gt) == 0 and len(dt) == 0:
|
| 292 |
+
return None
|
| 293 |
+
|
| 294 |
+
if len(dt) == 0:
|
| 295 |
+
return {
|
| 296 |
+
"image_id": img_id,
|
| 297 |
+
"category_id": cat_id,
|
| 298 |
+
"area_rng": area_rng,
|
| 299 |
+
"dt_ids": [],
|
| 300 |
+
"dt_matches": np.array([], dtype=np.int32).reshape(1, -1),
|
| 301 |
+
"dt_scores": [],
|
| 302 |
+
"dt_ignore": np.array([], dtype=np.int32).reshape(1, -1),
|
| 303 |
+
'num_gt': len(gt)
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
no_crowd_inds = [i for i, g in enumerate(gt) \
|
| 307 |
+
if ('iscrowd' not in g) or g['iscrowd'] == 0]
|
| 308 |
+
crowd_inds = [i for i, g in enumerate(gt) \
|
| 309 |
+
if 'iscrowd' in g and g['iscrowd'] == 1]
|
| 310 |
+
dt_idx = np.argsort([-d["score"] for d in dt], kind="mergesort")
|
| 311 |
+
|
| 312 |
+
if len(self.ious[img_id, cat_id]) > 0:
|
| 313 |
+
ious = self.ious[img_id, cat_id]
|
| 314 |
+
iou = ious[:, no_crowd_inds]
|
| 315 |
+
iou = iou[dt_idx]
|
| 316 |
+
ioa = ious[:, crowd_inds]
|
| 317 |
+
ioa = ioa[dt_idx]
|
| 318 |
+
else:
|
| 319 |
+
iou = np.zeros((len(dt_idx), 0))
|
| 320 |
+
ioa = np.zeros((len(dt_idx), 0))
|
| 321 |
+
scores = np.array([dt[i]['score'] for i in dt_idx])
|
| 322 |
+
|
| 323 |
+
num_detected_boxes = len(dt)
|
| 324 |
+
tp_fp_labels = np.zeros(num_detected_boxes, dtype=bool)
|
| 325 |
+
is_matched_to_group_of = np.zeros(num_detected_boxes, dtype=bool)
|
| 326 |
+
|
| 327 |
+
def compute_match_iou(iou):
|
| 328 |
+
max_overlap_gt_ids = np.argmax(iou, axis=1)
|
| 329 |
+
is_gt_detected = np.zeros(iou.shape[1], dtype=bool)
|
| 330 |
+
for i in range(num_detected_boxes):
|
| 331 |
+
gt_id = max_overlap_gt_ids[i]
|
| 332 |
+
is_evaluatable = (not tp_fp_labels[i] and
|
| 333 |
+
iou[i, gt_id] >= 0.5 and
|
| 334 |
+
not is_matched_to_group_of[i])
|
| 335 |
+
if is_evaluatable:
|
| 336 |
+
if not is_gt_detected[gt_id]:
|
| 337 |
+
tp_fp_labels[i] = True
|
| 338 |
+
is_gt_detected[gt_id] = True
|
| 339 |
+
|
| 340 |
+
def compute_match_ioa(ioa):
|
| 341 |
+
scores_group_of = np.zeros(ioa.shape[1], dtype=float)
|
| 342 |
+
tp_fp_labels_group_of = np.ones(
|
| 343 |
+
ioa.shape[1], dtype=float)
|
| 344 |
+
max_overlap_group_of_gt_ids = np.argmax(ioa, axis=1)
|
| 345 |
+
for i in range(num_detected_boxes):
|
| 346 |
+
gt_id = max_overlap_group_of_gt_ids[i]
|
| 347 |
+
is_evaluatable = (not tp_fp_labels[i] and
|
| 348 |
+
ioa[i, gt_id] >= 0.5 and
|
| 349 |
+
not is_matched_to_group_of[i])
|
| 350 |
+
if is_evaluatable:
|
| 351 |
+
is_matched_to_group_of[i] = True
|
| 352 |
+
scores_group_of[gt_id] = max(scores_group_of[gt_id], scores[i])
|
| 353 |
+
selector = np.where((scores_group_of > 0) & (tp_fp_labels_group_of > 0))
|
| 354 |
+
scores_group_of = scores_group_of[selector]
|
| 355 |
+
tp_fp_labels_group_of = tp_fp_labels_group_of[selector]
|
| 356 |
+
|
| 357 |
+
return scores_group_of, tp_fp_labels_group_of
|
| 358 |
+
|
| 359 |
+
if iou.shape[1] > 0:
|
| 360 |
+
compute_match_iou(iou)
|
| 361 |
+
|
| 362 |
+
scores_box_group_of = np.ndarray([0], dtype=float)
|
| 363 |
+
tp_fp_labels_box_group_of = np.ndarray([0], dtype=float)
|
| 364 |
+
|
| 365 |
+
if ioa.shape[1] > 0:
|
| 366 |
+
scores_box_group_of, tp_fp_labels_box_group_of = compute_match_ioa(ioa)
|
| 367 |
+
|
| 368 |
+
valid_entries = (~is_matched_to_group_of)
|
| 369 |
+
|
| 370 |
+
scores = np.concatenate(
|
| 371 |
+
(scores[valid_entries], scores_box_group_of))
|
| 372 |
+
tp_fps = np.concatenate(
|
| 373 |
+
(tp_fp_labels[valid_entries].astype(float),
|
| 374 |
+
tp_fp_labels_box_group_of))
|
| 375 |
+
|
| 376 |
+
return {
|
| 377 |
+
"image_id": img_id,
|
| 378 |
+
"category_id": cat_id,
|
| 379 |
+
"area_rng": area_rng,
|
| 380 |
+
"dt_matches": np.array([1 if x > 0 else 0 for x in tp_fps], dtype=np.int32).reshape(1, -1),
|
| 381 |
+
"dt_scores": [x for x in scores],
|
| 382 |
+
"dt_ignore": np.array([0 for x in scores], dtype=np.int32).reshape(1, -1),
|
| 383 |
+
'num_gt': len(gt)
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
def accumulate(self):
|
| 387 |
+
"""Accumulate per image evaluation results and store the result in
|
| 388 |
+
self.eval.
|
| 389 |
+
"""
|
| 390 |
+
self.logger.info("Accumulating evaluation results.")
|
| 391 |
+
|
| 392 |
+
if not self.eval_imgs:
|
| 393 |
+
self.logger.warn("Please run evaluate first.")
|
| 394 |
+
|
| 395 |
+
if self.params.use_cats:
|
| 396 |
+
cat_ids = self.params.cat_ids
|
| 397 |
+
else:
|
| 398 |
+
cat_ids = [-1]
|
| 399 |
+
|
| 400 |
+
num_thrs = 1
|
| 401 |
+
num_recalls = 1
|
| 402 |
+
|
| 403 |
+
num_cats = len(cat_ids)
|
| 404 |
+
num_area_rngs = 1
|
| 405 |
+
num_imgs = len(self.params.img_ids)
|
| 406 |
+
|
| 407 |
+
# -1 for absent categories
|
| 408 |
+
precision = -np.ones(
|
| 409 |
+
(num_thrs, num_recalls, num_cats, num_area_rngs)
|
| 410 |
+
)
|
| 411 |
+
recall = -np.ones((num_thrs, num_cats, num_area_rngs))
|
| 412 |
+
|
| 413 |
+
# Initialize dt_pointers
|
| 414 |
+
dt_pointers = {}
|
| 415 |
+
for cat_idx in range(num_cats):
|
| 416 |
+
dt_pointers[cat_idx] = {}
|
| 417 |
+
for area_idx in range(num_area_rngs):
|
| 418 |
+
dt_pointers[cat_idx][area_idx] = {}
|
| 419 |
+
|
| 420 |
+
# Per category evaluation
|
| 421 |
+
for cat_idx in range(num_cats):
|
| 422 |
+
Nk = cat_idx * num_area_rngs * num_imgs
|
| 423 |
+
for area_idx in range(num_area_rngs):
|
| 424 |
+
Na = area_idx * num_imgs
|
| 425 |
+
E = [
|
| 426 |
+
self.eval_imgs[Nk + Na + img_idx]
|
| 427 |
+
for img_idx in range(num_imgs)
|
| 428 |
+
]
|
| 429 |
+
# Remove elements which are None
|
| 430 |
+
E = [e for e in E if not e is None]
|
| 431 |
+
if len(E) == 0:
|
| 432 |
+
continue
|
| 433 |
+
|
| 434 |
+
dt_scores = np.concatenate([e["dt_scores"] for e in E], axis=0)
|
| 435 |
+
dt_idx = np.argsort(-dt_scores, kind="mergesort")
|
| 436 |
+
dt_scores = dt_scores[dt_idx]
|
| 437 |
+
dt_m = np.concatenate([e["dt_matches"] for e in E], axis=1)[:, dt_idx]
|
| 438 |
+
dt_ig = np.concatenate([e["dt_ignore"] for e in E], axis=1)[:, dt_idx]
|
| 439 |
+
|
| 440 |
+
num_gt = sum([e['num_gt'] for e in E])
|
| 441 |
+
if num_gt == 0:
|
| 442 |
+
continue
|
| 443 |
+
|
| 444 |
+
tps = np.logical_and(dt_m, np.logical_not(dt_ig))
|
| 445 |
+
fps = np.logical_and(np.logical_not(dt_m), np.logical_not(dt_ig))
|
| 446 |
+
tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float)
|
| 447 |
+
fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float)
|
| 448 |
+
|
| 449 |
+
dt_pointers[cat_idx][area_idx] = {
|
| 450 |
+
"tps": tps,
|
| 451 |
+
"fps": fps,
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
for iou_thr_idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
|
| 455 |
+
tp = np.array(tp)
|
| 456 |
+
fp = np.array(fp)
|
| 457 |
+
num_tp = len(tp)
|
| 458 |
+
rc = tp / num_gt
|
| 459 |
+
|
| 460 |
+
if num_tp:
|
| 461 |
+
recall[iou_thr_idx, cat_idx, area_idx] = rc[
|
| 462 |
+
-1
|
| 463 |
+
]
|
| 464 |
+
else:
|
| 465 |
+
recall[iou_thr_idx, cat_idx, area_idx] = 0
|
| 466 |
+
|
| 467 |
+
# np.spacing(1) ~= eps
|
| 468 |
+
pr = tp / (fp + tp + np.spacing(1))
|
| 469 |
+
pr = pr.tolist()
|
| 470 |
+
|
| 471 |
+
for i in range(num_tp - 1, 0, -1):
|
| 472 |
+
if pr[i] > pr[i - 1]:
|
| 473 |
+
pr[i - 1] = pr[i]
|
| 474 |
+
|
| 475 |
+
mAP = compute_average_precision(
|
| 476 |
+
np.array(pr, np.float).reshape(-1),
|
| 477 |
+
np.array(rc, np.float).reshape(-1))
|
| 478 |
+
precision[iou_thr_idx, :, cat_idx, area_idx] = mAP
|
| 479 |
+
|
| 480 |
+
self.eval = {
|
| 481 |
+
"params": self.params,
|
| 482 |
+
"counts": [num_thrs, num_recalls, num_cats, num_area_rngs],
|
| 483 |
+
"date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
| 484 |
+
"precision": precision,
|
| 485 |
+
"recall": recall,
|
| 486 |
+
"dt_pointers": dt_pointers,
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
def _summarize(self, summary_type):
|
| 490 |
+
s = self.eval["precision"]
|
| 491 |
+
if len(s[s > -1]) == 0:
|
| 492 |
+
mean_s = -1
|
| 493 |
+
else:
|
| 494 |
+
mean_s = np.mean(s[s > -1])
|
| 495 |
+
# print(s.reshape(1, 1, -1, 1))
|
| 496 |
+
return mean_s
|
| 497 |
+
|
| 498 |
+
def summarize(self):
|
| 499 |
+
"""Compute and display summary metrics for evaluation results."""
|
| 500 |
+
if not self.eval:
|
| 501 |
+
raise RuntimeError("Please run accumulate() first.")
|
| 502 |
+
|
| 503 |
+
max_dets = self.params.max_dets
|
| 504 |
+
self.results["AP50"] = self._summarize('ap')
|
| 505 |
+
|
| 506 |
+
def run(self):
|
| 507 |
+
"""Wrapper function which calculates the results."""
|
| 508 |
+
self.evaluate()
|
| 509 |
+
self.accumulate()
|
| 510 |
+
self.summarize()
|
| 511 |
+
|
| 512 |
+
def print_results(self):
|
| 513 |
+
template = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} catIds={:>3s}] = {:0.3f}"
|
| 514 |
+
|
| 515 |
+
for key, value in self.results.items():
|
| 516 |
+
max_dets = self.params.max_dets
|
| 517 |
+
if "AP" in key:
|
| 518 |
+
title = "Average Precision"
|
| 519 |
+
_type = "(AP)"
|
| 520 |
+
else:
|
| 521 |
+
title = "Average Recall"
|
| 522 |
+
_type = "(AR)"
|
| 523 |
+
|
| 524 |
+
if len(key) > 2 and key[2].isdigit():
|
| 525 |
+
iou_thr = (float(key[2:]) / 100)
|
| 526 |
+
iou = "{:0.2f}".format(iou_thr)
|
| 527 |
+
else:
|
| 528 |
+
iou = "{:0.2f}:{:0.2f}".format(
|
| 529 |
+
self.params.iou_thrs[0], self.params.iou_thrs[-1]
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
cat_group_name = "all"
|
| 533 |
+
area_rng = "all"
|
| 534 |
+
|
| 535 |
+
print(template.format(title, _type, iou, area_rng, max_dets, cat_group_name, value))
|
| 536 |
+
|
| 537 |
+
def get_results(self):
|
| 538 |
+
if not self.results:
|
| 539 |
+
self.logger.warn("results is empty. Call run().")
|
| 540 |
+
return self.results
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
class Params:
|
| 544 |
+
def __init__(self, iou_type):
|
| 545 |
+
self.img_ids = []
|
| 546 |
+
self.cat_ids = []
|
| 547 |
+
# np.arange causes trouble. the data point on arange is slightly
|
| 548 |
+
# larger than the true value
|
| 549 |
+
self.iou_thrs = np.linspace(
|
| 550 |
+
0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True
|
| 551 |
+
)
|
| 552 |
+
self.google_style = True
|
| 553 |
+
# print('Using google style PR curve')
|
| 554 |
+
self.iou_thrs = self.iou_thrs[:1]
|
| 555 |
+
self.max_dets = 1000
|
| 556 |
+
|
| 557 |
+
self.area_rng = [
|
| 558 |
+
[0 ** 2, 1e5 ** 2],
|
| 559 |
+
]
|
| 560 |
+
self.area_rng_lbl = ["all"]
|
| 561 |
+
self.use_cats = 1
|
| 562 |
+
self.iou_type = iou_type
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
class OIDEvaluator(DatasetEvaluator):
|
| 566 |
+
def __init__(self, dataset_name, cfg, distributed, output_dir=None):
|
| 567 |
+
self._distributed = distributed
|
| 568 |
+
self._output_dir = output_dir
|
| 569 |
+
|
| 570 |
+
self._cpu_device = torch.device("cpu")
|
| 571 |
+
self._logger = logging.getLogger(__name__)
|
| 572 |
+
|
| 573 |
+
self._metadata = MetadataCatalog.get(dataset_name)
|
| 574 |
+
json_file = PathManager.get_local_path(self._metadata.json_file)
|
| 575 |
+
self._oid_api = LVIS(json_file)
|
| 576 |
+
# Test set json files do not contain annotations (evaluation must be
|
| 577 |
+
# performed using the LVIS evaluation server).
|
| 578 |
+
self._do_evaluation = len(self._oid_api.get_ann_ids()) > 0
|
| 579 |
+
self._mask_on = cfg.MODEL.MASK_ON
|
| 580 |
+
|
| 581 |
+
def reset(self):
|
| 582 |
+
self._predictions = []
|
| 583 |
+
self._oid_results = []
|
| 584 |
+
|
| 585 |
+
def process(self, inputs, outputs):
|
| 586 |
+
for input, output in zip(inputs, outputs):
|
| 587 |
+
prediction = {"image_id": input["image_id"]}
|
| 588 |
+
instances = output["instances"].to(self._cpu_device)
|
| 589 |
+
prediction["instances"] = instances_to_coco_json(
|
| 590 |
+
instances, input["image_id"])
|
| 591 |
+
self._predictions.append(prediction)
|
| 592 |
+
|
| 593 |
+
def evaluate(self):
|
| 594 |
+
if self._distributed:
|
| 595 |
+
comm.synchronize()
|
| 596 |
+
self._predictions = comm.gather(self._predictions, dst=0)
|
| 597 |
+
self._predictions = list(itertools.chain(*self._predictions))
|
| 598 |
+
|
| 599 |
+
if not comm.is_main_process():
|
| 600 |
+
return
|
| 601 |
+
|
| 602 |
+
if len(self._predictions) == 0:
|
| 603 |
+
self._logger.warning("[LVISEvaluator] Did not receive valid predictions.")
|
| 604 |
+
return {}
|
| 605 |
+
|
| 606 |
+
self._logger.info("Preparing results in the OID format ...")
|
| 607 |
+
self._oid_results = list(
|
| 608 |
+
itertools.chain(*[x["instances"] for x in self._predictions]))
|
| 609 |
+
|
| 610 |
+
# unmap the category ids for LVIS (from 0-indexed to 1-indexed)
|
| 611 |
+
for result in self._oid_results:
|
| 612 |
+
result["category_id"] += 1
|
| 613 |
+
|
| 614 |
+
PathManager.mkdirs(self._output_dir)
|
| 615 |
+
file_path = os.path.join(
|
| 616 |
+
self._output_dir, "oid_instances_results.json")
|
| 617 |
+
self._logger.info("Saving results to {}".format(file_path))
|
| 618 |
+
with PathManager.open(file_path, "w") as f:
|
| 619 |
+
f.write(json.dumps(self._oid_results))
|
| 620 |
+
f.flush()
|
| 621 |
+
|
| 622 |
+
if not self._do_evaluation:
|
| 623 |
+
self._logger.info("Annotations are not available for evaluation.")
|
| 624 |
+
return
|
| 625 |
+
|
| 626 |
+
self._logger.info("Evaluating predictions ...")
|
| 627 |
+
self._results = OrderedDict()
|
| 628 |
+
res, mAP = _evaluate_predictions_on_oid(
|
| 629 |
+
self._oid_api,
|
| 630 |
+
file_path,
|
| 631 |
+
eval_seg=self._mask_on,
|
| 632 |
+
class_names=self._metadata.get("thing_classes"),
|
| 633 |
+
)
|
| 634 |
+
self._results['bbox'] = res
|
| 635 |
+
mAP_out_path = os.path.join(self._output_dir, "oid_mAP.npy")
|
| 636 |
+
self._logger.info('Saving mAP to' + mAP_out_path)
|
| 637 |
+
np.save(mAP_out_path, mAP)
|
| 638 |
+
return copy.deepcopy(self._results)
|
| 639 |
+
|
| 640 |
+
def _evaluate_predictions_on_oid(
|
| 641 |
+
oid_gt, oid_results_path, eval_seg=False,
|
| 642 |
+
class_names=None):
|
| 643 |
+
logger = logging.getLogger(__name__)
|
| 644 |
+
metrics = ["AP50", "AP50_expand"]
|
| 645 |
+
|
| 646 |
+
results = {}
|
| 647 |
+
oid_eval = OIDEval(oid_gt, oid_results_path, 'bbox', expand_pred_label=False)
|
| 648 |
+
oid_eval.run()
|
| 649 |
+
oid_eval.print_results()
|
| 650 |
+
results["AP50"] = oid_eval.get_results()["AP50"]
|
| 651 |
+
|
| 652 |
+
if eval_seg:
|
| 653 |
+
oid_eval = OIDEval(oid_gt, oid_results_path, 'segm', expand_pred_label=False)
|
| 654 |
+
oid_eval.run()
|
| 655 |
+
oid_eval.print_results()
|
| 656 |
+
results["AP50_segm"] = oid_eval.get_results()["AP50"]
|
| 657 |
+
else:
|
| 658 |
+
oid_eval = OIDEval(oid_gt, oid_results_path, 'bbox', expand_pred_label=True)
|
| 659 |
+
oid_eval.run()
|
| 660 |
+
oid_eval.print_results()
|
| 661 |
+
results["AP50_expand"] = oid_eval.get_results()["AP50"]
|
| 662 |
+
|
| 663 |
+
mAP = np.zeros(len(class_names)) - 1
|
| 664 |
+
precisions = oid_eval.eval['precision']
|
| 665 |
+
assert len(class_names) == precisions.shape[2]
|
| 666 |
+
results_per_category = []
|
| 667 |
+
id2apiid = sorted(oid_gt.get_cat_ids())
|
| 668 |
+
inst_aware_ap, inst_count = 0, 0
|
| 669 |
+
for idx, name in enumerate(class_names):
|
| 670 |
+
precision = precisions[:, :, idx, 0]
|
| 671 |
+
precision = precision[precision > -1]
|
| 672 |
+
ap = np.mean(precision) if precision.size else float("nan")
|
| 673 |
+
inst_num = len(oid_gt.get_ann_ids(cat_ids=[id2apiid[idx]]))
|
| 674 |
+
if inst_num > 0:
|
| 675 |
+
results_per_category.append(("{} {}".format(
|
| 676 |
+
name.replace(' ', '_'),
|
| 677 |
+
inst_num if inst_num < 1000 else '{:.1f}k'.format(inst_num / 1000)),
|
| 678 |
+
float(ap * 100)))
|
| 679 |
+
inst_aware_ap += inst_num * ap
|
| 680 |
+
inst_count += inst_num
|
| 681 |
+
mAP[idx] = ap
|
| 682 |
+
# logger.info("{} {} {:.2f}".format(name, inst_num, ap * 100))
|
| 683 |
+
inst_aware_ap = inst_aware_ap * 100 / inst_count
|
| 684 |
+
N_COLS = min(6, len(results_per_category) * 2)
|
| 685 |
+
results_flatten = list(itertools.chain(*results_per_category))
|
| 686 |
+
results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)])
|
| 687 |
+
table = tabulate(
|
| 688 |
+
results_2d,
|
| 689 |
+
tablefmt="pipe",
|
| 690 |
+
floatfmt=".3f",
|
| 691 |
+
headers=["category", "AP"] * (N_COLS // 2),
|
| 692 |
+
numalign="left",
|
| 693 |
+
)
|
| 694 |
+
logger.info("Per-category {} AP: \n".format('bbox') + table)
|
| 695 |
+
logger.info("Instance-aware {} AP: {:.4f}".format('bbox', inst_aware_ap))
|
| 696 |
+
|
| 697 |
+
logger.info("Evaluation results for bbox: \n" + \
|
| 698 |
+
create_small_table(results))
|
| 699 |
+
return results, mAP
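
A minimal usage sketch for the evaluator above (not part of this commit; the annotation and result paths below are placeholders): OIDEval accepts either LVIS objects or JSON paths, and run() chains evaluate, accumulate, and summarize before AP50 can be read back.

gt_path = "datasets/oid/annotations/challenge-2019-validation.json"  # placeholder path
dt_path = "output/oid_instances_results.json"                        # placeholder path
oid_eval = OIDEval(gt_path, dt_path, iou_type="bbox")
oid_eval.run()             # evaluate -> accumulate -> summarize
oid_eval.print_results()   # prints AP50 under the Open Images protocol
print(oid_eval.get_results()["AP50"])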
|
proxydet/modeling/backbone/swintransformer.py
ADDED
|
@@ -0,0 +1,750 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu, Yutong Lin, Yixuan Wei
# --------------------------------------------------------

# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Xingyi Zhou from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from detectron2.layers import ShapeSpec
from detectron2.modeling.backbone.backbone import Backbone
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
from detectron2.modeling.backbone.fpn import FPN

from centernet.modeling.backbone.fpn_p5 import LastLevelP6P7_P5
from centernet.modeling.backbone.bifpn import BiFPN
# from .checkpoint import load_checkpoint

class Mlp(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    """ Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchMerging(nn.Module):
    """ Patch Merging Layer
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x


class SwinTransformer(Backbone):
    """ Swin Transformer backbone.
    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
        https://arxiv.org/pdf/2103.14030
    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute postion embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]

            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()
        self._out_features = ['swin{}'.format(i) for i in self.out_indices]
        self._out_feature_channels = {
            'swin{}'.format(i): self.embed_dim * 2 ** i for i in self.out_indices
        }
        self._out_feature_strides = {
            'swin{}'.format(i): 2 ** (i + 2) for i in self.out_indices
        }
        self._size_devisibility = 32


    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            # load_checkpoint(self, pretrained, strict=False)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        # outs = []
        outs = {}
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                # outs.append(out)
                outs['swin{}'.format(i)] = out

        return outs

    def train(self, mode=True):
        """Convert the model into training mode while keep layers freezed."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()

size2config = {
    'T': {
        'window_size': 7,
        'embed_dim': 96,
        'depth': [2, 2, 6, 2],
        'num_heads': [3, 6, 12, 24],
        'drop_path_rate': 0.2,
        'pretrained': 'models/swin_tiny_patch4_window7_224.pth'
    },
    'S': {
        'window_size': 7,
        'embed_dim': 96,
        'depth': [2, 2, 18, 2],
        'num_heads': [3, 6, 12, 24],
        'drop_path_rate': 0.2,
        'pretrained': 'models/swin_small_patch4_window7_224.pth'
    },
    'B': {
        'window_size': 7,
        'embed_dim': 128,
        'depth': [2, 2, 18, 2],
        'num_heads': [4, 8, 16, 32],
        'drop_path_rate': 0.3,
        'pretrained': 'models/swin_base_patch4_window7_224.pth'
    },
    'B-22k': {
        'window_size': 7,
        'embed_dim': 128,
        'depth': [2, 2, 18, 2],
        'num_heads': [4, 8, 16, 32],
        'drop_path_rate': 0.3,
        'pretrained': 'models/swin_base_patch4_window7_224_22k.pth'
    },
    'B-22k-384': {
        'window_size': 12,
        'embed_dim': 128,
        'depth': [2, 2, 18, 2],
        'num_heads': [4, 8, 16, 32],
        'drop_path_rate': 0.3,
        'pretrained': 'models/swin_base_patch4_window12_384_22k.pth'
    },
    'L-22k': {
        'window_size': 7,
        'embed_dim': 192,
        'depth': [2, 2, 18, 2],
        'num_heads': [6, 12, 24, 48],
        'drop_path_rate': 0.3,  # TODO (xingyi): this is unclear
        'pretrained': 'models/swin_large_patch4_window7_224_22k.pth'
    },
    'L-22k-384': {
        'window_size': 12,
        'embed_dim': 192,
        'depth': [2, 2, 18, 2],
        'num_heads': [6, 12, 24, 48],
        'drop_path_rate': 0.3,  # TODO (xingyi): this is unclear
        'pretrained': 'models/swin_large_patch4_window12_384_22k.pth'
    }
}

@BACKBONE_REGISTRY.register()
def build_swintransformer_backbone(cfg, input_shape):
    """
    """
    config = size2config[cfg.MODEL.SWIN.SIZE]
    out_indices = cfg.MODEL.SWIN.OUT_FEATURES
    model = SwinTransformer(
        embed_dim=config['embed_dim'],
        window_size=config['window_size'],
        depths=config['depth'],
        num_heads=config['num_heads'],
        drop_path_rate=config['drop_path_rate'],
        out_indices=out_indices,
        frozen_stages=-1,
        use_checkpoint=cfg.MODEL.SWIN.USE_CHECKPOINT
    )
    # print('Initializing', config['pretrained'])
    model.init_weights(config['pretrained'])
    return model


@BACKBONE_REGISTRY.register()
def build_swintransformer_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    """
    bottom_up = build_swintransformer_backbone(cfg, input_shape)
    in_features = cfg.MODEL.FPN.IN_FEATURES
    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=cfg.MODEL.FPN.NORM,
        top_block=LastLevelP6P7_P5(out_channels, out_channels),
        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
    )
    return backbone


@BACKBONE_REGISTRY.register()
def build_swintransformer_bifpn_backbone(cfg, input_shape: ShapeSpec):
    """
    """
    bottom_up = build_swintransformer_backbone(cfg, input_shape)
    in_features = cfg.MODEL.FPN.IN_FEATURES
    backbone = BiFPN(
        cfg=cfg,
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=cfg.MODEL.BIFPN.OUT_CHANNELS,
        norm=cfg.MODEL.BIFPN.NORM,
        num_levels=cfg.MODEL.BIFPN.NUM_LEVELS,
        num_bifpn=cfg.MODEL.BIFPN.NUM_BIFPN,
        separable_conv=cfg.MODEL.BIFPN.SEPARABLE_CONV,
    )
    return backbone
|
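For reference, a minimal standalone sketch (not part of this commit) of how the size2config lookup above expands a single Swin size key into SwinTransformer constructor arguments. The 'B-22k' entry mirrors the dictionary in swintransformer.py; the helper name and the example out_indices are illustrative only.

# Sketch: mirror build_swintransformer_backbone's size-key -> kwargs mapping.
size2config_demo = {
    'B-22k': {
        'window_size': 7,
        'embed_dim': 128,
        'depth': [2, 2, 18, 2],
        'num_heads': [4, 8, 16, 32],
        'drop_path_rate': 0.3,
        'pretrained': 'models/swin_base_patch4_window7_224_22k.pth',
    },
}

def swin_kwargs_from_size(size_key, out_indices=(1, 2, 3), use_checkpoint=False):
    """Map a size key to the kwargs the builder passes to SwinTransformer."""
    config = size2config_demo[size_key]
    return dict(
        embed_dim=config['embed_dim'],
        window_size=config['window_size'],
        depths=config['depth'],
        num_heads=config['num_heads'],
        drop_path_rate=config['drop_path_rate'],
        out_indices=list(out_indices),
        frozen_stages=-1,
        use_checkpoint=use_checkpoint,
    )

if __name__ == '__main__':
    print(swin_kwargs_from_size('B-22k'))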
proxydet/modeling/backbone/timm.py
ADDED
@@ -0,0 +1,221 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import math
from os.path import join
import numpy as np
import copy
from functools import partial

import torch
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
import fvcore.nn.weight_init as weight_init

from detectron2.modeling.backbone import FPN
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
from detectron2.layers.batch_norm import get_norm, FrozenBatchNorm2d
from detectron2.modeling.backbone import Backbone

from timm import create_model
from timm.models.helpers import build_model_with_cfg
from timm.models.registry import register_model
from timm.models.resnet import ResNet, Bottleneck
from timm.models.resnet import default_cfgs as default_cfgs_resnet
from timm.models.convnext import ConvNeXt, default_cfgs, checkpoint_filter_fn


@register_model
def convnext_tiny_21k(pretrained=False, **kwargs):
    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
    cfg = default_cfgs['convnext_tiny']
    cfg['url'] = 'https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth'
    model = build_model_with_cfg(
        ConvNeXt, 'convnext_tiny', pretrained,
        default_cfg=cfg,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
        **model_args)
    return model


class CustomResNet(ResNet):
    def __init__(self, **kwargs):
        self.out_indices = kwargs.pop('out_indices')
        super().__init__(**kwargs)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)
        ret = [x]
        x = self.layer1(x)
        ret.append(x)
        x = self.layer2(x)
        ret.append(x)
        x = self.layer3(x)
        ret.append(x)
        x = self.layer4(x)
        ret.append(x)
        return [ret[i] for i in self.out_indices]

    def load_pretrained(self, cached_file):
        data = torch.load(cached_file, map_location='cpu')
        if 'state_dict' in data:
            self.load_state_dict(data['state_dict'])
        else:
            self.load_state_dict(data)


model_params = {
    'resnet50_in21k': dict(block=Bottleneck, layers=[3, 4, 6, 3]),
}


def create_timm_resnet(variant, out_indices, pretrained=False, **kwargs):
    params = model_params[variant]
    default_cfgs_resnet['resnet50_in21k'] = \
        copy.deepcopy(default_cfgs_resnet['resnet50'])
    default_cfgs_resnet['resnet50_in21k']['url'] = \
        'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth'
    default_cfgs_resnet['resnet50_in21k']['num_classes'] = 11221

    return build_model_with_cfg(
        CustomResNet, variant, pretrained,
        default_cfg=default_cfgs_resnet[variant],
        out_indices=out_indices,
        pretrained_custom_load=True,
        **params,
        **kwargs)


class LastLevelP6P7_P5(nn.Module):
    """
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.num_levels = 2
        self.in_feature = "p5"
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for module in [self.p6, self.p7]:
            weight_init.c2_xavier_fill(module)

    def forward(self, c5):
        p6 = self.p6(c5)
        p7 = self.p7(F.relu(p6))
        return [p6, p7]


def freeze_module(x):
    """
    """
    for p in x.parameters():
        p.requires_grad = False
    FrozenBatchNorm2d.convert_frozen_batchnorm(x)
    return x


class TIMM(Backbone):
    def __init__(self, base_name, out_levels, freeze_at=0, norm='FrozenBN', pretrained=False):
        super().__init__()
        out_indices = [x - 1 for x in out_levels]
        if base_name in model_params:
            self.base = create_timm_resnet(
                base_name, out_indices=out_indices,
                pretrained=False)
        elif 'eff' in base_name or 'resnet' in base_name or 'regnet' in base_name:
            self.base = create_model(
                base_name, features_only=True,
                out_indices=out_indices, pretrained=pretrained)
        elif 'convnext' in base_name:
            drop_path_rate = 0.2 \
                if ('tiny' in base_name or 'small' in base_name) else 0.3
            self.base = create_model(
                base_name, features_only=True,
                out_indices=out_indices, pretrained=pretrained,
                drop_path_rate=drop_path_rate)
        else:
            assert 0, base_name
        feature_info = [dict(num_chs=f['num_chs'], reduction=f['reduction']) \
            for i, f in enumerate(self.base.feature_info)]
        self._out_features = ['layer{}'.format(x) for x in out_levels]
        self._out_feature_channels = {
            'layer{}'.format(l): feature_info[l - 1]['num_chs'] for l in out_levels}
        self._out_feature_strides = {
            'layer{}'.format(l): feature_info[l - 1]['reduction'] for l in out_levels}
        self._size_divisibility = max(self._out_feature_strides.values())
        if 'resnet' in base_name:
            self.freeze(freeze_at)
        if norm == 'FrozenBN':
            self = FrozenBatchNorm2d.convert_frozen_batchnorm(self)

    def freeze(self, freeze_at=0):
        """
        """
        if freeze_at >= 1:
            print('Freezing', self.base.conv1)
            self.base.conv1 = freeze_module(self.base.conv1)
        if freeze_at >= 2:
            print('Freezing', self.base.layer1)
            self.base.layer1 = freeze_module(self.base.layer1)

    def forward(self, x):
        features = self.base(x)
        ret = {k: v for k, v in zip(self._out_features, features)}
        return ret

    @property
    def size_divisibility(self):
        return self._size_divisibility


@BACKBONE_REGISTRY.register()
def build_timm_backbone(cfg, input_shape):
    model = TIMM(
        cfg.MODEL.TIMM.BASE_NAME,
        cfg.MODEL.TIMM.OUT_LEVELS,
        freeze_at=cfg.MODEL.TIMM.FREEZE_AT,
        norm=cfg.MODEL.TIMM.NORM,
        pretrained=cfg.MODEL.TIMM.PRETRAINED,
    )
    return model


@BACKBONE_REGISTRY.register()
def build_p67_timm_fpn_backbone(cfg, input_shape):
    """
    """
    bottom_up = build_timm_backbone(cfg, input_shape)
    in_features = cfg.MODEL.FPN.IN_FEATURES
    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=cfg.MODEL.FPN.NORM,
        top_block=LastLevelP6P7_P5(out_channels, out_channels),
        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
    )
    return backbone


@BACKBONE_REGISTRY.register()
def build_p35_timm_fpn_backbone(cfg, input_shape):
    """
    """
    bottom_up = build_timm_backbone(cfg, input_shape)

    in_features = cfg.MODEL.FPN.IN_FEATURES
    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=cfg.MODEL.FPN.NORM,
        top_block=None,
        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
    )
    return backbone
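For reference, a minimal sketch (not part of this commit; assumes the timm package is installed and uses 'resnet50' purely as an example) of the feature metadata the TIMM wrapper above reads when it fills _out_feature_channels and _out_feature_strides from feature_info.

import timm
import torch

def inspect_timm_features(base_name='resnet50', out_levels=(3, 4, 5)):
    out_indices = [x - 1 for x in out_levels]          # same index shift as TIMM.__init__
    base = timm.create_model(base_name, features_only=True,
                             out_indices=out_indices, pretrained=False)
    names = ['layer{}'.format(l) for l in out_levels]
    channels = dict(zip(names, base.feature_info.channels()))
    strides = dict(zip(names, base.feature_info.reduction()))
    return base, channels, strides

if __name__ == '__main__':
    base, channels, strides = inspect_timm_features()
    print(channels)   # e.g. {'layer3': 512, 'layer4': 1024, 'layer5': 2048}
    print(strides)    # e.g. {'layer3': 8, 'layer4': 16, 'layer5': 32}
    feats = base(torch.zeros(1, 3, 224, 224))
    print([f.shape for f in feats])                     # one feature map per out level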
proxydet/modeling/debug.py
ADDED
@@ -0,0 +1,334 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import os

COLORS = ((np.random.rand(1300, 3) * 0.4 + 0.6) * 255).astype(
    np.uint8).reshape(1300, 1, 1, 3)

def _get_color_image(heatmap):
    heatmap = heatmap.reshape(
        heatmap.shape[0], heatmap.shape[1], heatmap.shape[2], 1)
    if heatmap.shape[0] == 1:
        color_map = (heatmap * np.ones((1, 1, 1, 3), np.uint8) * 255).max(
            axis=0).astype(np.uint8)  # H, W, 3
    else:
        color_map = (heatmap * COLORS[:heatmap.shape[0]]).max(axis=0).astype(np.uint8)  # H, W, 3

    return color_map

def _blend_image(image, color_map, a=0.7):
    color_map = cv2.resize(color_map, (image.shape[1], image.shape[0]))
    ret = np.clip(image * (1 - a) + color_map * a, 0, 255).astype(np.uint8)
    return ret

def _blend_image_heatmaps(image, color_maps, a=0.7):
    merges = np.zeros((image.shape[0], image.shape[1], 3), np.float32)
    for color_map in color_maps:
        color_map = cv2.resize(color_map, (image.shape[1], image.shape[0]))
        merges = np.maximum(merges, color_map)
    ret = np.clip(image * (1 - a) + merges * a, 0, 255).astype(np.uint8)
    return ret

def _decompose_level(x, shapes_per_level, N):
    '''
    x: LNHiWi x C
    '''
    x = x.view(x.shape[0], -1)
    ret = []
    st = 0
    for l in range(len(shapes_per_level)):
        ret.append([])
        h = shapes_per_level[l][0].int().item()
        w = shapes_per_level[l][1].int().item()
        for i in range(N):
            ret[l].append(x[st + h * w * i:st + h * w * (i + 1)].view(
                h, w, -1).permute(2, 0, 1))
        st += h * w * N
    return ret

def _imagelist_to_tensor(images):
    images = [x for x in images]
    image_sizes = [x.shape[-2:] for x in images]
    h = max([size[0] for size in image_sizes])
    w = max([size[1] for size in image_sizes])
    S = 32
    h, w = ((h - 1) // S + 1) * S, ((w - 1) // S + 1) * S
    images = [F.pad(x, (0, w - x.shape[2], 0, h - x.shape[1], 0, 0)) \
        for x in images]
    images = torch.stack(images)
    return images


def _ind2il(ind, shapes_per_level, N):
    r = ind
    l = 0
    S = 0
    while r - S >= N * shapes_per_level[l][0] * shapes_per_level[l][1]:
        S += N * shapes_per_level[l][0] * shapes_per_level[l][1]
        l += 1
    i = (r - S) // (shapes_per_level[l][0] * shapes_per_level[l][1])
    return i, l

def debug_train(
    images, gt_instances, flattened_hms, reg_targets, labels, pos_inds,
    shapes_per_level, locations, strides):
    '''
    images: N x 3 x H x W
    flattened_hms: LNHiWi x C
    shapes_per_level: L x 2 [(H_i, W_i)]
    locations: LNHiWi x 2
    '''
    reg_inds = torch.nonzero(
        reg_targets.max(dim=1)[0] > 0).squeeze(1)
    N = len(images)
    images = _imagelist_to_tensor(images)
    repeated_locations = [torch.cat([loc] * N, dim=0) \
        for loc in locations]
    locations = torch.cat(repeated_locations, dim=0)
    gt_hms = _decompose_level(flattened_hms, shapes_per_level, N)
    masks = flattened_hms.new_zeros((flattened_hms.shape[0], 1))
    masks[pos_inds] = 1
    masks = _decompose_level(masks, shapes_per_level, N)
    for i in range(len(images)):
        image = images[i].detach().cpu().numpy().transpose(1, 2, 0)
        color_maps = []
        for l in range(len(gt_hms)):
            color_map = _get_color_image(
                gt_hms[l][i].detach().cpu().numpy())
            color_maps.append(color_map)
            cv2.imshow('gthm_{}'.format(l), color_map)
        blend = _blend_image_heatmaps(image.copy(), color_maps)
        if gt_instances is not None:
            bboxes = gt_instances[i].gt_boxes.tensor
            for j in range(len(bboxes)):
                bbox = bboxes[j]
                cv2.rectangle(
                    blend,
                    (int(bbox[0]), int(bbox[1])),
                    (int(bbox[2]), int(bbox[3])),
                    (0, 0, 255), 3, cv2.LINE_AA)

        for j in range(len(pos_inds)):
            image_id, l = _ind2il(pos_inds[j], shapes_per_level, N)
            if image_id != i:
                continue
            loc = locations[pos_inds[j]]
            cv2.drawMarker(
                blend, (int(loc[0]), int(loc[1])), (0, 255, 255),
                markerSize=(l + 1) * 16)

        for j in range(len(reg_inds)):
            image_id, l = _ind2il(reg_inds[j], shapes_per_level, N)
            if image_id != i:
                continue
            ltrb = reg_targets[reg_inds[j]]
            ltrb *= strides[l]
            loc = locations[reg_inds[j]]
            bbox = [(loc[0] - ltrb[0]), (loc[1] - ltrb[1]),
                (loc[0] + ltrb[2]), (loc[1] + ltrb[3])]
            cv2.rectangle(
                blend,
                (int(bbox[0]), int(bbox[1])),
                (int(bbox[2]), int(bbox[3])),
                (255, 0, 0), 1, cv2.LINE_AA)
            cv2.circle(blend, (int(loc[0]), int(loc[1])), 2, (255, 0, 0), -1)

        cv2.imshow('blend', blend)
        cv2.waitKey()


def debug_test(
    images, logits_pred, reg_pred, agn_hm_pred=[], preds=[],
    vis_thresh=0.3, debug_show_name=False, mult_agn=False):
    '''
    images: N x 3 x H x W
    class_target: LNHiWi x C
    cat_agn_heatmap: LNHiWi
    shapes_per_level: L x 2 [(H_i, W_i)]
    '''
    N = len(images)
    for i in range(len(images)):
        image = images[i].detach().cpu().numpy().transpose(1, 2, 0)
        result = image.copy().astype(np.uint8)
        pred_image = image.copy().astype(np.uint8)
        color_maps = []
        L = len(logits_pred)
        for l in range(L):
            if logits_pred[0] is not None:
                stride = min(image.shape[0], image.shape[1]) / min(
                    logits_pred[l][i].shape[1], logits_pred[l][i].shape[2])
            else:
                stride = min(image.shape[0], image.shape[1]) / min(
                    agn_hm_pred[l][i].shape[1], agn_hm_pred[l][i].shape[2])
            stride = stride if stride < 60 else 64 if stride < 100 else 128
            if logits_pred[0] is not None:
                if mult_agn:
                    logits_pred[l][i] = logits_pred[l][i] * agn_hm_pred[l][i]
                color_map = _get_color_image(
                    logits_pred[l][i].detach().cpu().numpy())
                color_maps.append(color_map)
                cv2.imshow('predhm_{}'.format(l), color_map)

            if debug_show_name:
                from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES
                cat2name = [x['name'] for x in LVIS_CATEGORIES]
            for j in range(len(preds[i].scores) if preds is not None else 0):
                if preds[i].scores[j] > vis_thresh:
                    bbox = preds[i].proposal_boxes[j] \
                        if preds[i].has('proposal_boxes') else \
                        preds[i].pred_boxes[j]
                    bbox = bbox.tensor[0].detach().cpu().numpy().astype(np.int32)
                    cat = int(preds[i].pred_classes[j]) \
                        if preds[i].has('pred_classes') else 0
                    cl = COLORS[cat, 0, 0]
                    cv2.rectangle(
                        pred_image, (int(bbox[0]), int(bbox[1])),
                        (int(bbox[2]), int(bbox[3])),
                        (int(cl[0]), int(cl[1]), int(cl[2])), 2, cv2.LINE_AA)
                    if debug_show_name:
                        txt = '{}{:.1f}'.format(
                            cat2name[cat] if cat > 0 else '',
                            preds[i].scores[j])
                        font = cv2.FONT_HERSHEY_SIMPLEX
                        cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
                        cv2.rectangle(
                            pred_image,
                            (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)),
                            (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)),
                            (int(cl[0]), int(cl[1]), int(cl[2])), -1)
                        cv2.putText(
                            pred_image, txt, (int(bbox[0]), int(bbox[1] - 2)),
                            font, 0.5, (0, 0, 0), thickness=1, lineType=cv2.LINE_AA)

            if agn_hm_pred[l] is not None:
                agn_hm_ = agn_hm_pred[l][i, 0, :, :, None].detach().cpu().numpy()
                agn_hm_ = (agn_hm_ * np.array([255, 255, 255]).reshape(
                    1, 1, 3)).astype(np.uint8)
                cv2.imshow('agn_hm_{}'.format(l), agn_hm_)
        blend = _blend_image_heatmaps(image.copy(), color_maps)
        cv2.imshow('blend', blend)
        cv2.imshow('preds', pred_image)
        cv2.waitKey()

global cnt
cnt = 0

def debug_second_stage(images, instances, proposals=None, vis_thresh=0.3,
    save_debug=False, debug_show_name=False, image_labels=[],
    save_debug_path='output/save_debug/',
    bgr=False):
    images = _imagelist_to_tensor(images)
    if 'COCO' in save_debug_path:
        from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
        cat2name = [x['name'] for x in COCO_CATEGORIES]
    else:
        from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES
        cat2name = ['({}){}'.format(x['frequency'], x['name']) \
            for x in LVIS_CATEGORIES]
    for i in range(len(images)):
        image = images[i].detach().cpu().numpy().transpose(1, 2, 0).astype(np.uint8).copy()
        if bgr:
            image = image[:, :, ::-1].copy()
        if instances[i].has('gt_boxes'):
            bboxes = instances[i].gt_boxes.tensor.cpu().numpy()
            scores = np.ones(bboxes.shape[0])
            cats = instances[i].gt_classes.cpu().numpy()
        else:
            bboxes = instances[i].pred_boxes.tensor.cpu().numpy()
            scores = instances[i].scores.cpu().numpy()
            cats = instances[i].pred_classes.cpu().numpy()
        for j in range(len(bboxes)):
            if scores[j] > vis_thresh:
                bbox = bboxes[j]
                cl = COLORS[cats[j], 0, 0]
                cl = (int(cl[0]), int(cl[1]), int(cl[2]))
                cv2.rectangle(
                    image,
                    (int(bbox[0]), int(bbox[1])),
                    (int(bbox[2]), int(bbox[3])),
                    cl, 2, cv2.LINE_AA)
                if debug_show_name:
                    cat = cats[j]
                    txt = '{}{:.1f}'.format(
                        cat2name[cat] if cat > 0 else '',
                        scores[j])
                    font = cv2.FONT_HERSHEY_SIMPLEX
                    cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
                    cv2.rectangle(
                        image,
                        (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)),
                        (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)),
                        (int(cl[0]), int(cl[1]), int(cl[2])), -1)
                    cv2.putText(
                        image, txt, (int(bbox[0]), int(bbox[1] - 2)),
                        font, 0.5, (0, 0, 0), thickness=1, lineType=cv2.LINE_AA)
        if proposals is not None:
            proposal_image = images[i].detach().cpu().numpy().transpose(1, 2, 0).astype(np.uint8).copy()
            if bgr:
                proposal_image = proposal_image.copy()
            else:
                proposal_image = proposal_image[:, :, ::-1].copy()
            bboxes = proposals[i].proposal_boxes.tensor.cpu().numpy()
            if proposals[i].has('scores'):
                scores = proposals[i].scores.detach().cpu().numpy()
            else:
                scores = proposals[i].objectness_logits.detach().cpu().numpy()
            # selected = -1
            # if proposals[i].has('image_loss'):
            #     selected = proposals[i].image_loss.argmin()
            if proposals[i].has('selected'):
                selected = proposals[i].selected
            else:
                selected = [-1 for _ in range(len(bboxes))]
            for j in range(len(bboxes)):
                if scores[j] > vis_thresh or selected[j] >= 0:
                    bbox = bboxes[j]
                    cl = (209, 159, 83)
                    th = 2
                    if selected[j] >= 0:
                        cl = (0, 0, 0xa4)
                        th = 4
                    cv2.rectangle(
                        proposal_image,
                        (int(bbox[0]), int(bbox[1])),
                        (int(bbox[2]), int(bbox[3])),
                        cl, th, cv2.LINE_AA)
                    if selected[j] >= 0 and debug_show_name:
                        cat = selected[j].item()
                        txt = '{}'.format(cat2name[cat])
                        font = cv2.FONT_HERSHEY_SIMPLEX
                        cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
                        cv2.rectangle(
                            proposal_image,
                            (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)),
                            (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)),
                            (int(cl[0]), int(cl[1]), int(cl[2])), -1)
                        cv2.putText(
                            proposal_image, txt,
                            (int(bbox[0]), int(bbox[1] - 2)),
                            font, 0.5, (0, 0, 0), thickness=1,
                            lineType=cv2.LINE_AA)

        if save_debug:
            global cnt
            cnt = (cnt + 1) % 5000
            if not os.path.exists(save_debug_path):
                os.mkdir(save_debug_path)
            save_name = '{}/{:05d}.jpg'.format(save_debug_path, cnt)
            if i < len(image_labels):
                image_label = image_labels[i]
                save_name = '{}/{:05d}'.format(save_debug_path, cnt)
                for x in image_label:
                    class_name = cat2name[x]
                    save_name = save_name + '|{}'.format(class_name)
                save_name = save_name + '.jpg'
            cv2.imwrite(save_name, proposal_image)
        else:
            cv2.imshow('image', image)
            if proposals is not None:
                cv2.imshow('proposals', proposal_image)
            cv2.waitKey()
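For reference, a small self-contained sketch (NumPy only, not part of this commit; array sizes are arbitrary) of the alpha-blending rule used by _blend_image_heatmaps above: take the per-pixel maximum over the per-level color maps, then overlay the result on the image with weight a.

import numpy as np

def blend_heatmaps(image, color_maps, a=0.7):
    merged = np.zeros_like(image, dtype=np.float32)
    for cm in color_maps:                      # color maps already resized to the image
        merged = np.maximum(merged, cm)
    return np.clip(image * (1 - a) + merged * a, 0, 255).astype(np.uint8)

if __name__ == '__main__':
    img = np.full((64, 64, 3), 128, np.uint8)
    cms = [np.random.rand(64, 64, 3).astype(np.float32) * 255 for _ in range(3)]
    print(blend_heatmaps(img, cms).shape)      # (64, 64, 3)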
proxydet/modeling/meta_arch/custom_rcnn.py
ADDED
@@ -0,0 +1,232 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
import json
from detectron2.utils.events import get_event_storage
from detectron2.config import configurable
from detectron2.structures import ImageList, Instances, Boxes
import detectron2.utils.comm as comm

from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.utils.visualizer import Visualizer, _create_text_labels
from detectron2.data.detection_utils import convert_image_to_rgb

from torch.cuda.amp import autocast
from ..text.text_encoder import build_text_encoder
from ..utils import load_class_freq, get_fed_loss_inds

@META_ARCH_REGISTRY.register()
class CustomRCNN(GeneralizedRCNN):
    '''
    Add image labels
    '''
    @configurable
    def __init__(
        self,
        with_image_labels = False,
        dataset_loss_weight = [],
        fp16 = False,
        sync_caption_batch = False,
        roi_head_name = '',
        cap_batch_ratio = 4,
        with_caption = False,
        dynamic_classifier = False,
        **kwargs):
        """
        """
        self.with_image_labels = with_image_labels
        self.dataset_loss_weight = dataset_loss_weight
        self.fp16 = fp16
        self.with_caption = with_caption
        self.sync_caption_batch = sync_caption_batch
        self.roi_head_name = roi_head_name
        self.cap_batch_ratio = cap_batch_ratio
        self.dynamic_classifier = dynamic_classifier
        self.return_proposal = False
        if self.dynamic_classifier:
            self.freq_weight = kwargs.pop('freq_weight')
            self.num_classes = kwargs.pop('num_classes')
            self.num_sample_cats = kwargs.pop('num_sample_cats')
        super().__init__(**kwargs)
        assert self.proposal_generator is not None
        if self.with_caption:
            assert not self.dynamic_classifier
            self.text_encoder = build_text_encoder(pretrain=True)
            for v in self.text_encoder.parameters():
                v.requires_grad = False

    @classmethod
    def from_config(cls, cfg):
        ret = super().from_config(cfg)
        ret.update({
            'with_image_labels': cfg.WITH_IMAGE_LABELS,
            'dataset_loss_weight': cfg.MODEL.DATASET_LOSS_WEIGHT,
            'fp16': cfg.FP16,
            'with_caption': cfg.MODEL.WITH_CAPTION,
            'sync_caption_batch': cfg.MODEL.SYNC_CAPTION_BATCH,
            'dynamic_classifier': cfg.MODEL.DYNAMIC_CLASSIFIER,
            'roi_head_name': cfg.MODEL.ROI_HEADS.NAME,
            'cap_batch_ratio': cfg.MODEL.CAP_BATCH_RATIO,
        })
        if ret['dynamic_classifier']:
            ret['freq_weight'] = load_class_freq(
                cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH,
                cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT)
            ret['num_classes'] = cfg.MODEL.ROI_HEADS.NUM_CLASSES
            ret['num_sample_cats'] = cfg.MODEL.NUM_SAMPLE_CATS
        return ret

    def inference(
        self,
        batched_inputs: Tuple[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        assert not self.training
        assert detected_instances is None

        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        proposals, _ = self.proposal_generator(images, features, None)
        results, _ = self.roi_heads(images, features, proposals)
        if do_postprocess:
            assert not torch.jit.is_scripting(), \
                "Scripting is not supported for postprocess."
            return CustomRCNN._postprocess(
                results, batched_inputs, images.image_sizes)
        else:
            return results

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Add ann_type
        Ignore proposal loss when training with image labels
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)

        ann_type = 'box'
        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        if self.with_image_labels:
            for inst, x in zip(gt_instances, batched_inputs):
                inst._ann_type = x['ann_type']
                inst._pos_category_ids = x['pos_category_ids']
            ann_types = [x['ann_type'] for x in batched_inputs]
            assert len(set(ann_types)) == 1
            ann_type = ann_types[0]
            if ann_type in ['prop', 'proptag']:
                for t in gt_instances:
                    t.gt_classes *= 0

        if self.fp16:  # TODO (zhouxy): improve
            with autocast():
                features = self.backbone(images.tensor.half())
            features = {k: v.float() for k, v in features.items()}
        else:
            features = self.backbone(images.tensor)

        cls_features, cls_inds, caption_features = None, None, None

        if self.with_caption and 'caption' in ann_type:
            inds = [torch.randint(len(x['captions']), (1,))[0].item() \
                for x in batched_inputs]
            caps = [x['captions'][ind] for ind, x in zip(inds, batched_inputs)]
            caption_features = self.text_encoder(caps).float()
        if self.sync_caption_batch:
            caption_features = self._sync_caption_features(
                caption_features, ann_type, len(batched_inputs))

        if self.dynamic_classifier and ann_type != 'caption':
            cls_inds = self._sample_cls_inds(gt_instances, ann_type)  # inds, inv_inds
            ind_with_bg = cls_inds[0].tolist() + [-1]
            cls_features = self.roi_heads.box_predictor[
                0].cls_score.zs_weight[:, ind_with_bg].permute(1, 0).contiguous()

        classifier_info = cls_features, cls_inds, caption_features
        proposals, proposal_losses = self.proposal_generator(
            images, features, gt_instances)

        if self.roi_head_name in ['StandardROIHeads', 'CascadeROIHeads']:
            proposals, detector_losses = self.roi_heads(
                images, features, proposals, gt_instances)
        else:
            proposals, detector_losses = self.roi_heads(
                images, features, proposals, gt_instances,
                ann_type=ann_type, classifier_info=classifier_info)

        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        if self.with_image_labels:
            if ann_type in ['box', 'prop', 'proptag']:
                losses.update(proposal_losses)
            else:  # ignore proposal loss for non-bbox data
                losses.update({k: v * 0 for k, v in proposal_losses.items()})
        else:
            losses.update(proposal_losses)
        if len(self.dataset_loss_weight) > 0:
            dataset_sources = [x['dataset_source'] for x in batched_inputs]
            assert len(set(dataset_sources)) == 1
            dataset_source = dataset_sources[0]
            for k in losses:
                losses[k] *= self.dataset_loss_weight[dataset_source]

        if self.return_proposal:
            return proposals, losses
        else:
            return losses

    def _sync_caption_features(self, caption_features, ann_type, BS):
        has_caption_feature = (caption_features is not None)
        BS = (BS * self.cap_batch_ratio) if (ann_type == 'box') else BS
        rank = torch.full(
            (BS, 1), comm.get_rank(), dtype=torch.float32,
            device=self.device)
        if not has_caption_feature:
            caption_features = rank.new_zeros((BS, 512))
        caption_features = torch.cat([caption_features, rank], dim=1)
        global_caption_features = comm.all_gather(caption_features)
        caption_features = torch.cat(
            [x.to(self.device) for x in global_caption_features], dim=0) \
                if has_caption_feature else None  # (NB) x (D + 1)
        return caption_features

    def _sample_cls_inds(self, gt_instances, ann_type='box'):
        if ann_type == 'box':
            gt_classes = torch.cat(
                [x.gt_classes for x in gt_instances])
            C = len(self.freq_weight)
            freq_weight = self.freq_weight
        else:
            gt_classes = torch.cat(
                [torch.tensor(
                    x._pos_category_ids,
                    dtype=torch.long, device=x.gt_classes.device) \
                    for x in gt_instances])
            C = self.num_classes
            freq_weight = None
        assert gt_classes.max() < C, '{} {}'.format(gt_classes.max(), C)
        inds = get_fed_loss_inds(
            gt_classes, self.num_sample_cats, C,
            weight=freq_weight)
        cls_id_map = gt_classes.new_full(
            (self.num_classes + 1,), len(inds))
        cls_id_map[inds] = torch.arange(len(inds), device=cls_id_map.device)
        return inds, cls_id_map
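For reference, an illustrative self-contained sketch (plain tensors, no Detectron2; the helper name is hypothetical) of the loss bookkeeping in CustomRCNN.forward above: proposal losses are multiplied by zero for image-labelled batches, and all losses are rescaled by a per-dataset weight when MODEL.DATASET_LOSS_WEIGHT is set.

import torch

def combine_losses(detector_losses, proposal_losses, ann_type,
                   with_image_labels=True, dataset_loss_weight=None, dataset_source=0):
    losses = dict(detector_losses)
    if with_image_labels and ann_type not in ['box', 'prop', 'proptag']:
        # keep the keys so the optimizer sees the same loss dict, but drop the signal
        losses.update({k: v * 0 for k, v in proposal_losses.items()})
    else:
        losses.update(proposal_losses)
    if dataset_loss_weight:
        losses = {k: v * dataset_loss_weight[dataset_source] for k, v in losses.items()}
    return losses

if __name__ == '__main__':
    det = {'loss_cls': torch.tensor(1.0)}
    prop = {'loss_rpn_cls': torch.tensor(0.5)}
    print(combine_losses(det, prop, ann_type='image'))            # RPN loss zeroed
    print(combine_losses(det, prop, ann_type='box',
                         dataset_loss_weight=[0.5, 1.0]))         # all losses halved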
proxydet/modeling/meta_arch/d2_deformable_detr.py
ADDED
@@ -0,0 +1,308 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
import torch.nn.functional as F
from torch import nn
import math

from detectron2.modeling import META_ARCH_REGISTRY, build_backbone
from detectron2.structures import Boxes, Instances
from ..utils import load_class_freq, get_fed_loss_inds

from models.backbone import Joiner
from models.deformable_detr import DeformableDETR, SetCriterion, MLP
from models.deformable_detr import _get_clones
from models.matcher import HungarianMatcher
from models.position_encoding import PositionEmbeddingSine
from models.deformable_transformer import DeformableTransformer
from models.segmentation import sigmoid_focal_loss
from util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
from util.misc import NestedTensor, accuracy


__all__ = ["DeformableDetr"]

class CustomSetCriterion(SetCriterion):
    def __init__(self, num_classes, matcher, weight_dict, losses, \
        focal_alpha=0.25, use_fed_loss=False):
        super().__init__(num_classes, matcher, weight_dict, losses, focal_alpha)
        self.use_fed_loss = use_fed_loss
        if self.use_fed_loss:
            self.register_buffer(
                'fed_loss_weight', load_class_freq(freq_weight=0.5))

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        target_classes_onehot = torch.zeros(
            [src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
            dtype=src_logits.dtype, layout=src_logits.layout,
            device=src_logits.device)
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

        target_classes_onehot = target_classes_onehot[:,:,:-1]  # B x N x C
        if self.use_fed_loss:
            inds = get_fed_loss_inds(
                gt_classes=target_classes_o,
                num_sample_cats=50,
                weight=self.fed_loss_weight,
                C=target_classes_onehot.shape[2])
            loss_ce = sigmoid_focal_loss(
                src_logits[:, :, inds],
                target_classes_onehot[:, :, inds],
                num_boxes,
                alpha=self.focal_alpha,
                gamma=2) * src_logits.shape[1]
        else:
            loss_ce = sigmoid_focal_loss(
                src_logits, target_classes_onehot, num_boxes,
                alpha=self.focal_alpha,
                gamma=2) * src_logits.shape[1]
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses


class MaskedBackbone(nn.Module):
    """ This is a thin wrapper around D2's backbone to provide padding masking"""

    def __init__(self, cfg):
        super().__init__()
        self.backbone = build_backbone(cfg)
        backbone_shape = self.backbone.output_shape()
        self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()]
        self.strides = [backbone_shape[f].stride for f in backbone_shape.keys()]
        self.num_channels = [backbone_shape[x].channels for x in backbone_shape.keys()]

    def forward(self, tensor_list: NestedTensor):
        xs = self.backbone(tensor_list.tensors)
        out = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out

@META_ARCH_REGISTRY.register()
class DeformableDetr(nn.Module):
    """
    Implement Deformable Detr
    """

    def __init__(self, cfg):
        super().__init__()
        self.with_image_labels = cfg.WITH_IMAGE_LABELS
        self.weak_weight = cfg.MODEL.DETR.WEAK_WEIGHT

        self.device = torch.device(cfg.MODEL.DEVICE)
        self.test_topk = cfg.TEST.DETECTIONS_PER_IMAGE
        self.num_classes = cfg.MODEL.DETR.NUM_CLASSES
        self.mask_on = cfg.MODEL.MASK_ON
        hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM
        num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES

        # Transformer parameters:
        nheads = cfg.MODEL.DETR.NHEADS
        dropout = cfg.MODEL.DETR.DROPOUT
        dim_feedforward = cfg.MODEL.DETR.DIM_FEEDFORWARD
        enc_layers = cfg.MODEL.DETR.ENC_LAYERS
        dec_layers = cfg.MODEL.DETR.DEC_LAYERS
        num_feature_levels = cfg.MODEL.DETR.NUM_FEATURE_LEVELS
        two_stage = cfg.MODEL.DETR.TWO_STAGE
        with_box_refine = cfg.MODEL.DETR.WITH_BOX_REFINE

        # Loss parameters:
        giou_weight = cfg.MODEL.DETR.GIOU_WEIGHT
        l1_weight = cfg.MODEL.DETR.L1_WEIGHT
        deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION
        cls_weight = cfg.MODEL.DETR.CLS_WEIGHT
        focal_alpha = cfg.MODEL.DETR.FOCAL_ALPHA

        N_steps = hidden_dim // 2
        d2_backbone = MaskedBackbone(cfg)
        backbone = Joiner(d2_backbone, PositionEmbeddingSine(N_steps, normalize=True))

        transformer = DeformableTransformer(
            d_model=hidden_dim,
            nhead=nheads,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="relu",
            return_intermediate_dec=True,
            num_feature_levels=num_feature_levels,
            dec_n_points=4,
            enc_n_points=4,
            two_stage=two_stage,
            two_stage_num_proposals=num_queries)

        self.detr = DeformableDETR(
            backbone, transformer, num_classes=self.num_classes,
            num_queries=num_queries,
            num_feature_levels=num_feature_levels,
            aux_loss=deep_supervision,
            with_box_refine=with_box_refine,
            two_stage=two_stage,
        )

        if self.mask_on:
            assert 0, 'Mask is not supported yet :('

        matcher = HungarianMatcher(
            cost_class=cls_weight, cost_bbox=l1_weight, cost_giou=giou_weight)
        weight_dict = {"loss_ce": cls_weight, "loss_bbox": l1_weight}
        weight_dict["loss_giou"] = giou_weight
        if deep_supervision:
            aux_weight_dict = {}
            for i in range(dec_layers - 1):
                aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
            weight_dict.update(aux_weight_dict)
        print('weight_dict', weight_dict)
        losses = ["labels", "boxes", "cardinality"]
        if self.mask_on:
            losses += ["masks"]
        self.criterion = CustomSetCriterion(
            self.num_classes, matcher=matcher, weight_dict=weight_dict,
            focal_alpha=focal_alpha,
            losses=losses,
            use_fed_loss=cfg.MODEL.DETR.USE_FED_LOSS
        )
        pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
        pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
        self.normalizer = lambda x: (x - pixel_mean) / pixel_std

    def forward(self, batched_inputs):
        """
        Args:
        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a tensor storing the loss. Used during training only.
        """
        images = self.preprocess_image(batched_inputs)
        output = self.detr(images)
        if self.training:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
            targets = self.prepare_targets(gt_instances)
            loss_dict = self.criterion(output, targets)
            weight_dict = self.criterion.weight_dict
            for k in loss_dict.keys():
                if k in weight_dict:
                    loss_dict[k] *= weight_dict[k]
            if self.with_image_labels:
                if batched_inputs[0]['ann_type'] in ['image', 'captiontag']:
                    loss_dict['loss_image'] = self.weak_weight * self._weak_loss(
                        output, batched_inputs)
                else:
                    loss_dict['loss_image'] = images[0].new_zeros(
                        [1], dtype=torch.float32)[0]
            # import pdb; pdb.set_trace()
            return loss_dict
        else:
            image_sizes = output["pred_boxes"].new_tensor(
                [(t["height"], t["width"]) for t in batched_inputs])
            results = self.post_process(output, image_sizes)
            return results

    def prepare_targets(self, targets):
        new_targets = []
        for targets_per_image in targets:
            h, w = targets_per_image.image_size
            image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
            gt_classes = targets_per_image.gt_classes
            gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
            gt_boxes = box_xyxy_to_cxcywh(gt_boxes)
            new_targets.append({"labels": gt_classes, "boxes": gt_boxes})
            if self.mask_on and hasattr(targets_per_image, 'gt_masks'):
                assert 0, 'Mask is not supported yet :('
                gt_masks = targets_per_image.gt_masks
                gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
                new_targets[-1].update({'masks': gt_masks})
        return new_targets

    def post_process(self, outputs, target_sizes):
        """
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = out_logits.sigmoid()
        topk_values, topk_indexes = torch.topk(
            prob.view(out_logits.shape[0], -1), self.test_topk, dim=1)
        scores = topk_values
        topk_boxes = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        boxes = box_cxcywh_to_xyxy(out_bbox)
        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        results = []
        for s, l, b, size in zip(scores, labels, boxes, target_sizes):
            r = Instances((size[0], size[1]))
            r.pred_boxes = Boxes(b)
            r.scores = s
            r.pred_classes = l
            results.append({'instances': r})
        return results

    def preprocess_image(self, batched_inputs):
        """
        Normalize, pad and batch the input images.
        """
        images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs]
        return images

    def _weak_loss(self, outputs, batched_inputs):
        loss = 0
        for b, x in enumerate(batched_inputs):
            labels = x['pos_category_ids']
            pred_logits = [outputs['pred_logits'][b]]
            pred_boxes = [outputs['pred_boxes'][b]]
            for xx in outputs['aux_outputs']:
                pred_logits.append(xx['pred_logits'][b])
                pred_boxes.append(xx['pred_boxes'][b])
            pred_logits = torch.stack(pred_logits, dim=0)  # L x N x C
            pred_boxes = torch.stack(pred_boxes, dim=0)  # L x N x 4
            for label in labels:
                loss += self._max_size_loss(
                    pred_logits, pred_boxes, label) / len(labels)
        loss = loss / len(batched_inputs)
        return loss

    def _max_size_loss(self, logits, boxes, label):
        '''
        Inputs:
          logits: L x N x C
          boxes: L x N x 4
        '''
        target = logits.new_zeros((logits.shape[0], logits.shape[2]))
        target[:, label] = 1.
        sizes = boxes[..., 2] * boxes[..., 3]  # L x N
        ind = sizes.argmax(dim=1)  # L
        loss = F.binary_cross_entropy_with_logits(
            logits[range(len(ind)), ind], target, reduction='sum')
        return loss
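For reference, a self-contained sketch (random tensors, single image, not part of this commit) of the decoding done in DeformableDetr.post_process above: flatten the (query, class) score matrix, take the top-k entries, recover query and class indices, then rescale cxcywh boxes to absolute xyxy pixel coordinates.

import torch

def decode_topk(pred_logits, pred_boxes, image_hw, topk=5):
    prob = pred_logits.sigmoid()                                  # N x C
    scores, flat_idx = torch.topk(prob.flatten(), topk)
    query_idx = flat_idx // pred_logits.shape[1]                  # which query
    labels = flat_idx % pred_logits.shape[1]                      # which class
    cx, cy, w, h = pred_boxes[query_idx].unbind(-1)               # cxcywh in [0, 1]
    boxes = torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)
    H, W = image_hw
    boxes = boxes * boxes.new_tensor([W, H, W, H])                # to absolute pixels
    return scores, labels, boxes

if __name__ == '__main__':
    s, l, b = decode_topk(torch.randn(300, 80), torch.rand(300, 4), (480, 640))
    print(s.shape, l.shape, b.shape)   # torch.Size([5]) torch.Size([5]) torch.Size([5, 4])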
proxydet/modeling/roi_heads/proxydet_fast_rcnn.py
ADDED
|
@@ -0,0 +1,618 @@
# Copyright (c) Facebook, Inc. and its affiliates.
'''
Modifications Copyright (c) 2024-present NAVER Corp, Apache License v2.0
original source: https://github.com/facebookresearch/Detic/blob/main/detic/modeling/roi_heads/detic_fast_rcnn.py
'''
import logging
import math
import json
import numpy as np
from typing import Dict, Union
import torch
from fvcore.nn import giou_loss, smooth_l1_loss
from torch import nn
from torch.nn import functional as F
import fvcore.nn.weight_init as weight_init
import detectron2.utils.comm as comm
from detectron2.config import configurable
from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple
from detectron2.structures import Boxes, Instances
from detectron2.utils.events import get_event_storage
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
from detectron2.modeling.roi_heads.fast_rcnn import _log_classification_stats

from torch.cuda.amp import autocast
from ..utils import load_class_freq, get_fed_loss_inds
from .zero_shot_classifier import ZeroShotClassifier

__all__ = ["ProxydetFastRCNNOutputLayers"]


class ProxydetFastRCNNOutputLayers(FastRCNNOutputLayers):
    @configurable
    def __init__(
        self,
        input_shape: ShapeSpec,
        *,
        mult_proposal_score=False,
        cls_score=None,
        sync_caption_batch = False,
        use_sigmoid_ce = False,
        use_fed_loss = False,
        ignore_zero_cats = False,
        fed_loss_num_cat = 50,
        dynamic_classifier = False,
        image_label_loss = '',
        use_zeroshot_cls = False,
        image_loss_weight = 0.1,
        with_softmax_prop = False,
        caption_weight = 1.0,
        neg_cap_weight = 1.0,
        add_image_box = False,
        debug = False,
        prior_prob = 0.01,
        cat_freq_path = '',
        fed_loss_freq_weight = 0.5,
        softmax_weak_loss = False,
        use_regional_embedding=False,
        base_cat_mask: str = None,
        **kwargs,
    ):
        super().__init__(
            input_shape=input_shape,
            **kwargs,
        )
        self.mult_proposal_score = mult_proposal_score
        self.sync_caption_batch = sync_caption_batch
        self.use_sigmoid_ce = use_sigmoid_ce
        self.use_fed_loss = use_fed_loss
        self.ignore_zero_cats = ignore_zero_cats
        self.fed_loss_num_cat = fed_loss_num_cat
        self.dynamic_classifier = dynamic_classifier
        self.image_label_loss = image_label_loss
        self.use_zeroshot_cls = use_zeroshot_cls
        self.image_loss_weight = image_loss_weight
        self.with_softmax_prop = with_softmax_prop
        self.caption_weight = caption_weight
        self.neg_cap_weight = neg_cap_weight
        self.add_image_box = add_image_box
        self.softmax_weak_loss = softmax_weak_loss
        self.debug = debug
        self.use_regional_embedding = use_regional_embedding
        self.base_cat_mask = torch.tensor(np.load(base_cat_mask).nonzero()[0])

        if softmax_weak_loss:
            assert image_label_loss in ['max_size']

        if self.use_sigmoid_ce:
            bias_value = -math.log((1 - prior_prob) / prior_prob)
            nn.init.constant_(self.cls_score.bias, bias_value)

        if self.use_fed_loss or self.ignore_zero_cats:
            freq_weight = load_class_freq(cat_freq_path, fed_loss_freq_weight)
            self.register_buffer('freq_weight', freq_weight)
        else:
            self.freq_weight = None

        if self.use_fed_loss and len(self.freq_weight) < self.num_classes:
            # assert self.num_classes == 11493
            print('Extending federated loss weight')
            self.freq_weight = torch.cat(
                [self.freq_weight,
                 self.freq_weight.new_zeros(
                    self.num_classes - len(self.freq_weight))]
            )

        assert (not self.dynamic_classifier) or (not self.use_fed_loss)
        input_size = input_shape.channels * \
            (input_shape.width or 1) * (input_shape.height or 1)

        if self.use_zeroshot_cls:
            del self.cls_score
            del self.bbox_pred
            assert cls_score is not None
            self.cls_score = cls_score
            self.bbox_pred = nn.Sequential(
                nn.Linear(input_size, input_size),
                nn.ReLU(inplace=True),
                nn.Linear(input_size, 4)
            )
            weight_init.c2_xavier_fill(self.bbox_pred[0])
            nn.init.normal_(self.bbox_pred[-1].weight, std=0.001)
            nn.init.constant_(self.bbox_pred[-1].bias, 0)

        if self.with_softmax_prop:
            self.prop_score = nn.Sequential(
                nn.Linear(input_size, input_size),
                nn.ReLU(inplace=True),
                nn.Linear(input_size, self.num_classes + 1),
            )
            weight_init.c2_xavier_fill(self.prop_score[0])
            nn.init.normal_(self.prop_score[-1].weight, mean=0, std=0.001)
            nn.init.constant_(self.prop_score[-1].bias, 0)


    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = super().from_config(cfg, input_shape)
        ret.update({
            'mult_proposal_score': cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE,
            'sync_caption_batch': cfg.MODEL.SYNC_CAPTION_BATCH,
            'use_sigmoid_ce': cfg.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE,
            'use_fed_loss': cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS,
            'ignore_zero_cats': cfg.MODEL.ROI_BOX_HEAD.IGNORE_ZERO_CATS,
            'fed_loss_num_cat': cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT,
            'dynamic_classifier': cfg.MODEL.DYNAMIC_CLASSIFIER,
            'image_label_loss': cfg.MODEL.ROI_BOX_HEAD.IMAGE_LABEL_LOSS,
            'use_zeroshot_cls': cfg.MODEL.ROI_BOX_HEAD.USE_ZEROSHOT_CLS,
            'image_loss_weight': cfg.MODEL.ROI_BOX_HEAD.IMAGE_LOSS_WEIGHT,
            'with_softmax_prop': cfg.MODEL.ROI_BOX_HEAD.WITH_SOFTMAX_PROP,
            'caption_weight': cfg.MODEL.ROI_BOX_HEAD.CAPTION_WEIGHT,
            'neg_cap_weight': cfg.MODEL.ROI_BOX_HEAD.NEG_CAP_WEIGHT,
            'add_image_box': cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX,
            'debug': cfg.DEBUG or cfg.SAVE_DEBUG or cfg.IS_DEBUG,
            'prior_prob': cfg.MODEL.ROI_BOX_HEAD.PRIOR_PROB,
            'cat_freq_path': cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH,
            'fed_loss_freq_weight': cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT,
            'softmax_weak_loss': cfg.MODEL.ROI_BOX_HEAD.SOFTMAX_WEAK_LOSS,
            "use_regional_embedding": cfg.MODEL.ROI_BOX_HEAD.USE_REGIONAL_EMBEDDING,
            "base_cat_mask": cfg.MODEL.ROI_HEADS.BASE_CAT_MASK,
        })
        if ret['use_zeroshot_cls']:
            ret['cls_score'] = ZeroShotClassifier(cfg, input_shape)
        return ret

    def losses(self, predictions, proposals, \
        use_advanced_loss=True,
        classifier_info=(None,None,None)):
        """
        enable advanced loss
        """
        scores, proposal_deltas = predictions
        gt_classes = (
            cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0)
        )
        num_classes = self.num_classes
        if self.dynamic_classifier:
            _, cls_id_map = classifier_info[1]
            gt_classes = cls_id_map[gt_classes]
            num_classes = scores.shape[1] - 1
            assert cls_id_map[self.num_classes] == num_classes
        _log_classification_stats(scores, gt_classes)

        if len(proposals):
            proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)  # Nx4
            assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"
            gt_boxes = cat(
                [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals],
                dim=0,
            )
        else:
            proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)

        if self.use_sigmoid_ce:
            loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes)
        else:
            loss_cls = self.softmax_cross_entropy_loss(scores, gt_classes)
        return {
            "loss_cls": loss_cls,
            "loss_box_reg": self.box_reg_loss(
                proposal_boxes, gt_boxes, proposal_deltas, gt_classes,
                num_classes=num_classes)
        }


    def sigmoid_cross_entropy_loss(self, pred_class_logits, gt_classes):
        if pred_class_logits.numel() == 0:
            return pred_class_logits.new_zeros([1])[0]  # This is more robust than .sum() * 0.

        B = pred_class_logits.shape[0]
        C = pred_class_logits.shape[1] - 1

        target = pred_class_logits.new_zeros(B, C + 1)
        target[range(len(gt_classes)), gt_classes] = 1  # B x (C + 1)
        target = target[:, :C]  # B x C

        weight = 1

        if self.use_fed_loss and (self.freq_weight is not None):  # fedloss
            appeared = get_fed_loss_inds(
                gt_classes,
                num_sample_cats=self.fed_loss_num_cat,
                C=C,
                weight=self.freq_weight)
            appeared_mask = appeared.new_zeros(C + 1)
            appeared_mask[appeared] = 1  # C + 1
            appeared_mask = appeared_mask[:C]
            fed_w = appeared_mask.view(1, C).expand(B, C)
            weight = weight * fed_w.float()
        if self.ignore_zero_cats and (self.freq_weight is not None):
            w = (self.freq_weight.view(-1) > 1e-4).float()
            weight = weight * w.view(1, C).expand(B, C)
            # import pdb; pdb.set_trace()

        cls_loss = F.binary_cross_entropy_with_logits(
            pred_class_logits[:, :-1], target, reduction="none"
        )  # B x C
        loss = torch.sum(cls_loss * weight) / B
        return loss


    def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes):
        """
        change _no_instance handling
        """
        if pred_class_logits.numel() == 0:
            return pred_class_logits.new_zeros([1])[0]

        if self.ignore_zero_cats and (self.freq_weight is not None):
            zero_weight = torch.cat([
                (self.freq_weight.view(-1) > 1e-4).float(),
                self.freq_weight.new_ones(1)])  # C + 1
            loss = F.cross_entropy(
                pred_class_logits, gt_classes,
                weight=zero_weight, reduction="mean")
        elif self.use_fed_loss and (self.freq_weight is not None):  # fedloss
            C = pred_class_logits.shape[1] - 1
            appeared = get_fed_loss_inds(
                gt_classes,
                num_sample_cats=self.fed_loss_num_cat,
                C=C,
                weight=self.freq_weight)
            appeared_mask = appeared.new_zeros(C + 1).float()
            appeared_mask[appeared] = 1.  # C + 1
            appeared_mask[C] = 1.
            loss = F.cross_entropy(
                pred_class_logits, gt_classes,
                weight=appeared_mask, reduction="mean")
        else:
            loss = F.cross_entropy(
                pred_class_logits, gt_classes, reduction="mean")
        return loss


    def box_reg_loss(
        self, proposal_boxes, gt_boxes, pred_deltas, gt_classes,
        num_classes=-1):
        """
        Allow custom background index
        """
        num_classes = num_classes if num_classes > 0 else self.num_classes
        box_dim = proposal_boxes.shape[1]  # 4 or 5
        fg_inds = nonzero_tuple((gt_classes >= 0) & (gt_classes < num_classes))[0]
        if pred_deltas.shape[1] == box_dim:  # cls-agnostic regression
            fg_pred_deltas = pred_deltas[fg_inds]
        else:
            fg_pred_deltas = pred_deltas.view(-1, self.num_classes, box_dim)[
                fg_inds, gt_classes[fg_inds]
            ]

        if self.box_reg_loss_type == "smooth_l1":
            gt_pred_deltas = self.box2box_transform.get_deltas(
                proposal_boxes[fg_inds],
                gt_boxes[fg_inds],
            )
            loss_box_reg = smooth_l1_loss(
                fg_pred_deltas, gt_pred_deltas, self.smooth_l1_beta, reduction="sum"
            )
        elif self.box_reg_loss_type == "giou":
            fg_pred_boxes = self.box2box_transform.apply_deltas(
                fg_pred_deltas, proposal_boxes[fg_inds]
            )
            loss_box_reg = giou_loss(fg_pred_boxes, gt_boxes[fg_inds], reduction="sum")
        else:
            raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")
        return loss_box_reg / max(gt_classes.numel(), 1.0)

    def inference(self, predictions, proposals):
        """
        enable use proposal boxes
        """
        predictions = (predictions[0], predictions[1])
        boxes = self.predict_boxes(predictions, proposals)
        scores = self.predict_probs(predictions, proposals)
        if self.mult_proposal_score:
            proposal_scores = [p.get('objectness_logits') for p in proposals]
            scores = [(s * ps[:, None]) ** 0.5 \
                for s, ps in zip(scores, proposal_scores)]
        image_shapes = [x.image_size for x in proposals]
        return fast_rcnn_inference(
            boxes,
            scores,
            image_shapes,
            self.test_score_thresh,
            self.test_nms_thresh,
            self.test_topk_per_image,
        )


    def predict_probs(self, predictions, proposals):
        """
        support sigmoid
        """
        # scores, _ = predictions
        scores = predictions[0]
        num_inst_per_image = [len(p) for p in proposals]
        if self.use_sigmoid_ce:
            probs = scores.sigmoid()
        else:
            probs = F.softmax(scores, dim=-1)
        return probs.split(num_inst_per_image, dim=0)


    def image_label_losses(self, predictions, proposals, image_labels, \
        classifier_info=(None,None,None), ann_type='image'):
        '''
        Inputs:
            scores: N x (C + 1)
            image_labels B x 1
        '''
        num_inst_per_image = [len(p) for p in proposals]
        scores = predictions[0]
        scores = scores.split(num_inst_per_image, dim=0)  # B x n x (C + 1)
        if self.with_softmax_prop:
            prop_scores = predictions[2].split(num_inst_per_image, dim=0)
        else:
            prop_scores = [None for _ in num_inst_per_image]
        B = len(scores)
        img_box_count = 0
        select_size_count = 0
        select_x_count = 0
        select_y_count = 0
        max_score_count = 0
        storage = get_event_storage()
        loss = scores[0].new_zeros([1])[0]
        caption_loss = scores[0].new_zeros([1])[0]
        for idx, (score, labels, prop_score, p) in enumerate(zip(
            scores, image_labels, prop_scores, proposals)):
            if score.shape[0] == 0:
                loss += score.new_zeros([1])[0]
                continue
            if 'caption' in ann_type:
                score, caption_loss_img = self._caption_loss(
                    score, classifier_info, idx, B)
                caption_loss += self.caption_weight * caption_loss_img
                if ann_type == 'caption':
                    continue

            if self.debug:
                p.selected = score.new_zeros(
                    (len(p),), dtype=torch.long) - 1
            for i_l, label in enumerate(labels):
                if self.dynamic_classifier:
                    if idx == 0 and i_l == 0 and comm.is_main_process():
                        storage.put_scalar('stats_label', label)
                    label = classifier_info[1][1][label]
                    assert label < score.shape[1]
                if self.image_label_loss in ['wsod', 'wsddn']:
                    loss_i, ind = self._wsddn_loss(score, prop_score, label)
                elif self.image_label_loss == 'max_score':
                    loss_i, ind = self._max_score_loss(score, label)
                elif self.image_label_loss == 'max_size':
                    loss_i, ind = self._max_size_loss(score, label, p)
                elif self.image_label_loss == 'first':
                    loss_i, ind = self._first_loss(score, label)
                elif self.image_label_loss == 'image':
                    loss_i, ind = self._image_loss(score, label)
                elif self.image_label_loss == 'min_loss':
                    loss_i, ind = self._min_loss_loss(score, label)
                else:
                    assert 0
                loss += loss_i / len(labels)
                if type(ind) == type([]):
                    img_box_count = sum(ind) / len(ind)
                    if self.debug:
                        for ind_i in ind:
                            p.selected[ind_i] = label
                else:
                    img_box_count = ind
                    select_size_count = p[ind].proposal_boxes.area() / \
                        (p.image_size[0] * p.image_size[1])
                    max_score_count = score[ind, label].sigmoid()
                    select_x_count = (p.proposal_boxes.tensor[ind, 0] + \
                        p.proposal_boxes.tensor[ind, 2]) / 2 / p.image_size[1]
                    select_y_count = (p.proposal_boxes.tensor[ind, 1] + \
                        p.proposal_boxes.tensor[ind, 3]) / 2 / p.image_size[0]
                    if self.debug:
                        p.selected[ind] = label

        loss = loss / B
        storage.put_scalar('stats_l_image', loss.item())
        if 'caption' in ann_type:
            caption_loss = caption_loss / B
            loss = loss + caption_loss
            storage.put_scalar('stats_l_caption', caption_loss.item())
        if comm.is_main_process():
            storage.put_scalar('pool_stats', img_box_count)
            storage.put_scalar('stats_select_size', select_size_count)
            storage.put_scalar('stats_select_x', select_x_count)
            storage.put_scalar('stats_select_y', select_y_count)
            storage.put_scalar('stats_max_label_score', max_score_count)

        return {
            'image_loss': loss * self.image_loss_weight,
            'loss_cls': score.new_zeros([1])[0],
            'loss_box_reg': score.new_zeros([1])[0]}


    def forward(self, x, classifier_info=(None,None,None)):
        """
        enable classifier_info
        """
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        scores = []

        if classifier_info[0] is not None:
            classifier_out = self.cls_score(x, classifier=classifier_info[0])
        else:
            classifier_out = self.cls_score(x)
        if self.use_regional_embedding:
            cls_scores, regional_embeddings = classifier_out
        else:
            cls_scores = classifier_out
        scores.append(cls_scores)

        if classifier_info[2] is not None:
            classifier_out = classifier_info[2]
            if self.use_regional_embedding:
                cap_cls, regional_embeddings = classifier_out
            else:
                cap_cls = classifier_out
            if self.sync_caption_batch:
                caption_scores = self.cls_score(x, classifier=cap_cls[:, :-1])
            else:
                caption_scores = self.cls_score(x, classifier=cap_cls)
            scores.append(caption_scores)
        scores = torch.cat(scores, dim=1)  # B x C' or B x N or B x (C'+N)

        proposal_deltas = self.bbox_pred(x)
        if self.with_softmax_prop:
            prop_score = self.prop_score(x)
            return scores, proposal_deltas, prop_score
        elif self.use_regional_embedding:
            # NOTE: scores: [B * # proposals for each image, 1204 (1203 + 1 bg)]
            # NOTE: proposal_deltas: [B * # proposals for each image, 4]
            # NOTE: regional_embeddings: [B * # proposals for each image, 512]
            return scores, proposal_deltas, regional_embeddings
        else:
            return scores, proposal_deltas


    def _caption_loss(self, score, classifier_info, idx, B):
        assert (classifier_info[2] is not None)
        assert self.add_image_box
        cls_and_cap_num = score.shape[1]
        cap_num = classifier_info[2].shape[0]
        score, caption_score = score.split(
            [cls_and_cap_num - cap_num, cap_num], dim=1)
        # n x (C + 1), n x B
        caption_score = caption_score[-1:]  # 1 x B # -1: image level box
        caption_target = caption_score.new_zeros(
            caption_score.shape)  # 1 x B or 1 x MB, M: num machines
        if self.sync_caption_batch:
            # caption_target: 1 x MB
            rank = comm.get_rank()
            global_idx = B * rank + idx
            assert (classifier_info[2][
                global_idx, -1] - rank) ** 2 < 1e-8, \
                '{} {} {} {} {}'.format(
                    rank, global_idx,
                    classifier_info[2][global_idx, -1],
                    classifier_info[2].shape,
                    classifier_info[2][:, -1])
            caption_target[:, global_idx] = 1.
        else:
            assert caption_score.shape[1] == B
            caption_target[:, idx] = 1.
        caption_loss_img = F.binary_cross_entropy_with_logits(
            caption_score, caption_target, reduction='none')
        if self.sync_caption_batch:
            fg_mask = (caption_target > 0.5).float()
            assert (fg_mask.sum().item() - 1.) ** 2 < 1e-8, '{} {}'.format(
                fg_mask.shape, fg_mask)
            pos_loss = (caption_loss_img * fg_mask).sum()
            neg_loss = (caption_loss_img * (1. - fg_mask)).sum()
            caption_loss_img = pos_loss + self.neg_cap_weight * neg_loss
        else:
            caption_loss_img = caption_loss_img.sum()
        return score, caption_loss_img


    def _wsddn_loss(self, score, prop_score, label):
        assert prop_score is not None
        loss = 0
        final_score = score.sigmoid() * \
            F.softmax(prop_score, dim=0)  # B x (C + 1)
        img_score = torch.clamp(
            torch.sum(final_score, dim=0),
            min=1e-10, max=1-1e-10)  # (C + 1)
        target = img_score.new_zeros(img_score.shape)  # (C + 1)
        target[label] = 1.
        loss += F.binary_cross_entropy(img_score, target)
        ind = final_score[:, label].argmax()
        return loss, ind


    def _max_score_loss(self, score, label):
        loss = 0
        target = score.new_zeros(score.shape[1])
        target[label] = 1.
        ind = score[:, label].argmax().item()
        loss += F.binary_cross_entropy_with_logits(
            score[ind], target, reduction='sum')
        return loss, ind


    def _min_loss_loss(self, score, label):
        loss = 0
        target = score.new_zeros(score.shape)
        target[:, label] = 1.
        with torch.no_grad():
            x = F.binary_cross_entropy_with_logits(
                score, target, reduction='none').sum(dim=1)  # n
        ind = x.argmin().item()
        loss += F.binary_cross_entropy_with_logits(
            score[ind], target[0], reduction='sum')
        return loss, ind


    def _first_loss(self, score, label):
        loss = 0
        target = score.new_zeros(score.shape[1])
        target[label] = 1.
        ind = 0
        loss += F.binary_cross_entropy_with_logits(
            score[ind], target, reduction='sum')
        return loss, ind


    def _image_loss(self, score, label):
        assert self.add_image_box
        target = score.new_zeros(score.shape[1])
        target[label] = 1.
        ind = score.shape[0] - 1
        loss = F.binary_cross_entropy_with_logits(
            score[ind], target, reduction='sum')
        return loss, ind


    def _max_size_loss(self, score, label, p):
        loss = 0
        target = score.new_zeros(score.shape[1])
        target[label] = 1.
        sizes = p.proposal_boxes.area()
        ind = sizes[:-1].argmax().item() if len(sizes) > 1 else 0
        if self.softmax_weak_loss:
            loss += F.cross_entropy(
                score[ind:ind+1],
                score.new_tensor(label, dtype=torch.long).view(1),
                reduction='sum')
        else:
            loss += F.binary_cross_entropy_with_logits(
                score[ind], target, reduction='sum')
        return loss, ind



def put_label_distribution(storage, hist_name, hist_counts, num_classes):
    """
    """
    ht_min, ht_max = 0, num_classes
    hist_edges = torch.linspace(
        start=ht_min, end=ht_max, steps=num_classes + 1, dtype=torch.float32)

    hist_params = dict(
        tag=hist_name,
        min=ht_min,
        max=ht_max,
        num=float(hist_counts.sum()),
        sum=float((hist_counts * torch.arange(len(hist_counts))).sum()),
        sum_squares=float(((hist_counts * torch.arange(len(hist_counts))) ** 2).sum()),
        bucket_limits=hist_edges[1:].tolist(),
        bucket_counts=hist_counts.tolist(),
        global_step=storage._iter,
    )
    storage._histograms.append(hist_params)
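A rough, self-contained sketch (not part of the commit) of how the sigmoid classification loss above masks categories when federated loss is on: only the sampled categories contribute to the per-class BCE term. The shapes and the hand-picked `appeared` set stand in for real proposals and for `get_fed_loss_inds`.

```python
import torch
import torch.nn.functional as F

B, C = 4, 6                           # proposals, foreground classes (toy values)
logits = torch.randn(B, C + 1)        # last column is background
gt = torch.tensor([0, 2, 5, C])       # index C marks background proposals

target = logits.new_zeros(B, C + 1)
target[range(B), gt] = 1
target = target[:, :C]                # drop the background column

appeared = torch.tensor([0, 2, 5])    # stand-in for get_fed_loss_inds(...)
mask = torch.zeros(C)
mask[appeared] = 1
weight = mask.view(1, C).expand(B, C)

loss = (F.binary_cross_entropy_with_logits(
    logits[:, :-1], target, reduction="none") * weight).sum() / B
print(loss.item())
```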
proxydet/modeling/roi_heads/proxydet_roi_heads.py
ADDED
@@ -0,0 +1,556 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
'''
|
| 3 |
+
Modifications Copyright (c) 2024-present NAVER Corp, Apache License v2.0
|
| 4 |
+
original source: https://github.com/facebookresearch/Detic/blob/main/detic/modeling/roi_heads/detic_roi_heads.py
|
| 5 |
+
'''
|
| 6 |
+
import copy
|
| 7 |
+
import numpy as np
|
| 8 |
+
import json
|
| 9 |
+
import math
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn
|
| 12 |
+
from torch.autograd.function import Function
|
| 13 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 14 |
+
from torch.nn import functional as F
|
| 15 |
+
|
| 16 |
+
from fvcore.nn import giou_loss
|
| 17 |
+
|
| 18 |
+
from detectron2.config import configurable
|
| 19 |
+
from detectron2.layers import ShapeSpec
|
| 20 |
+
from detectron2.layers import batched_nms, cat
|
| 21 |
+
from detectron2.structures import Boxes, Instances, pairwise_iou
|
| 22 |
+
from detectron2.utils.events import get_event_storage
|
| 23 |
+
|
| 24 |
+
from detectron2.modeling.box_regression import Box2BoxTransform
|
| 25 |
+
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
|
| 26 |
+
from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads
|
| 27 |
+
from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient
|
| 28 |
+
from detectron2.modeling.roi_heads.box_head import build_box_head
|
| 29 |
+
from .proxydet_fast_rcnn import ProxydetFastRCNNOutputLayers
|
| 30 |
+
from ..debug import debug_second_stage
|
| 31 |
+
|
| 32 |
+
from torch.cuda.amp import autocast
|
| 33 |
+
from copy import deepcopy
|
| 34 |
+
|
| 35 |
+
@ROI_HEADS_REGISTRY.register()
|
| 36 |
+
class ProxydetCascadeROIHeads(CascadeROIHeads):
|
| 37 |
+
@configurable
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
*,
|
| 41 |
+
mult_proposal_score: bool = False,
|
| 42 |
+
with_image_labels: bool = False,
|
| 43 |
+
add_image_box: bool = False,
|
| 44 |
+
image_box_size: float = 1.0,
|
| 45 |
+
ws_num_props: int = 512,
|
| 46 |
+
add_feature_to_prop: bool = False,
|
| 47 |
+
mask_weight: float = 1.0,
|
| 48 |
+
one_class_per_proposal: bool = False,
|
| 49 |
+
use_regional_embedding: bool = False,
|
| 50 |
+
base_cat_mask: str = None,
|
| 51 |
+
cmm_stage: list = [],
|
| 52 |
+
cmm_stage_test: list = None,
|
| 53 |
+
cmm_beta: float = 1.0,
|
| 54 |
+
cmm_loss: str = "l1",
|
| 55 |
+
cmm_loss_weight: float = 1.0,
|
| 56 |
+
cmm_separated_branch: bool = False,
|
| 57 |
+
cmm_base_alpha: float = 0.5,
|
| 58 |
+
cmm_novel_beta: float = 0.5,
|
| 59 |
+
cmm_use_inl: bool = False,
|
| 60 |
+
cmm_prototype: str = "center",
|
| 61 |
+
cmm_prototype_temp: float = 1.0,
|
| 62 |
+
cmm_classifier_temp: float = None,
|
| 63 |
+
cmm_use_sigmoid_ce: bool = True,
|
| 64 |
+
**kwargs,
|
| 65 |
+
):
|
| 66 |
+
super().__init__(**kwargs)
|
| 67 |
+
self.mult_proposal_score = mult_proposal_score
|
| 68 |
+
self.with_image_labels = with_image_labels
|
| 69 |
+
self.add_image_box = add_image_box
|
| 70 |
+
self.image_box_size = image_box_size
|
| 71 |
+
self.ws_num_props = ws_num_props
|
| 72 |
+
self.add_feature_to_prop = add_feature_to_prop
|
| 73 |
+
self.mask_weight = mask_weight
|
| 74 |
+
self.one_class_per_proposal = one_class_per_proposal
|
| 75 |
+
self.use_regional_embedding = use_regional_embedding
|
| 76 |
+
self.base_cat_mask = torch.tensor(np.load(base_cat_mask)).bool()
|
| 77 |
+
self.cmm_stage = cmm_stage
|
| 78 |
+
self.cmm_stage_test = cmm_stage_test
|
| 79 |
+
self.cmm_beta = cmm_beta
|
| 80 |
+
self.cmm_loss = cmm_loss
|
| 81 |
+
self.cmm_loss_weight = cmm_loss_weight
|
| 82 |
+
self.cmm_separated_branch = cmm_separated_branch
|
| 83 |
+
self.cmm_base_alpha = cmm_base_alpha
|
| 84 |
+
self.cmm_novel_beta = cmm_novel_beta
|
| 85 |
+
self.cmm_use_inl = cmm_use_inl
|
| 86 |
+
self.cmm_prototype = cmm_prototype
|
| 87 |
+
self.cmm_prototype_temp = cmm_prototype_temp
|
| 88 |
+
self.cmm_classifier_temp = cmm_classifier_temp
|
| 89 |
+
self.cmm_use_sigmoid_ce = cmm_use_sigmoid_ce
|
| 90 |
+
|
| 91 |
+
if self.cmm_separated_branch:
|
| 92 |
+
self.box_head_cmm = deepcopy(self.box_head)
|
| 93 |
+
self.box_predictor_cmm = deepcopy(self.box_predictor)
|
| 94 |
+
if self.cmm_classifier_temp is not None:
|
| 95 |
+
for k in range(self.num_cascade_stages):
|
| 96 |
+
self.box_predictor_cmm[k].cls_score.norm_temperature = self.cmm_classifier_temp
|
| 97 |
+
if not self.cmm_use_sigmoid_ce:
|
| 98 |
+
for k in range(self.num_cascade_stages):
|
| 99 |
+
self.box_predictor_cmm[k].use_sigmoid_ce = self.cmm_use_sigmoid_ce # using bce or ce
|
| 100 |
+
|
| 101 |
+
@classmethod
|
| 102 |
+
def from_config(cls, cfg, input_shape):
|
| 103 |
+
ret = super().from_config(cfg, input_shape)
|
| 104 |
+
ret.update({
|
| 105 |
+
'mult_proposal_score': cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE,
|
| 106 |
+
'with_image_labels': cfg.WITH_IMAGE_LABELS,
|
| 107 |
+
'add_image_box': cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX,
|
| 108 |
+
'image_box_size': cfg.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE,
|
| 109 |
+
'ws_num_props': cfg.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS,
|
| 110 |
+
'add_feature_to_prop': cfg.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP,
|
| 111 |
+
'mask_weight': cfg.MODEL.ROI_HEADS.MASK_WEIGHT,
|
| 112 |
+
'one_class_per_proposal': cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL,
|
| 113 |
+
"use_regional_embedding": cfg.MODEL.ROI_BOX_HEAD.USE_REGIONAL_EMBEDDING,
|
| 114 |
+
"base_cat_mask": cfg.MODEL.ROI_HEADS.BASE_CAT_MASK,
|
| 115 |
+
"cmm_stage": cfg.MODEL.ROI_HEADS.CMM.MIXUP_STAGE,
|
| 116 |
+
"cmm_stage_test": cfg.MODEL.ROI_HEADS.CMM.MIXUP_STAGE_TEST,
|
| 117 |
+
"cmm_beta": cfg.MODEL.ROI_HEADS.CMM.MIXUP_BETA,
|
| 118 |
+
"cmm_loss": cfg.MODEL.ROI_HEADS.CMM.LOSS,
|
| 119 |
+
"cmm_loss_weight": cfg.MODEL.ROI_HEADS.CMM.LOSS_WEIGHT,
|
| 120 |
+
"cmm_separated_branch": cfg.MODEL.ROI_HEADS.CMM.SEPARATED_BRANCH,
|
| 121 |
+
"cmm_base_alpha": cfg.MODEL.ROI_HEADS.CMM.BASE_ALPHA,
|
| 122 |
+
"cmm_novel_beta": cfg.MODEL.ROI_HEADS.CMM.NOVEL_BETA,
|
| 123 |
+
"cmm_use_inl": cfg.MODEL.ROI_HEADS.CMM.USE_INL,
|
| 124 |
+
"cmm_prototype": cfg.MODEL.ROI_HEADS.CMM.PROTOTYPE,
|
| 125 |
+
"cmm_prototype_temp": cfg.MODEL.ROI_HEADS.CMM.PROTOTYPE_TEMP,
|
| 126 |
+
"cmm_classifier_temp": cfg.MODEL.ROI_HEADS.CMM.CLASSIFIER_TEMP,
|
| 127 |
+
"cmm_use_sigmoid_ce": cfg.MODEL.ROI_HEADS.CMM.USE_SIGMOID_CE,
|
| 128 |
+
})
|
| 129 |
+
return ret
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@classmethod
|
| 133 |
+
def _init_box_head(self, cfg, input_shape):
|
| 134 |
+
ret = super()._init_box_head(cfg, input_shape)
|
| 135 |
+
del ret['box_predictors']
|
| 136 |
+
cascade_bbox_reg_weights = cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS
|
| 137 |
+
box_predictors = []
|
| 138 |
+
for box_head, bbox_reg_weights in zip(ret['box_heads'], \
|
| 139 |
+
cascade_bbox_reg_weights):
|
| 140 |
+
box_predictors.append(
|
| 141 |
+
ProxydetFastRCNNOutputLayers(
|
| 142 |
+
cfg, box_head.output_shape,
|
| 143 |
+
box2box_transform=Box2BoxTransform(weights=bbox_reg_weights)
|
| 144 |
+
))
|
| 145 |
+
ret['box_predictors'] = box_predictors
|
| 146 |
+
return ret
|
| 147 |
+
|
| 148 |
+
def _embed_interp(self, re_i, re_j, te_i, te_j, lam):
|
| 149 |
+
# mix image, text embedding
|
| 150 |
+
_mixed_re = lam * re_i + (1 - lam) * re_j
|
| 151 |
+
_mixed_te = lam * te_i + (1 - lam) * te_j
|
| 152 |
+
return _mixed_re, _mixed_te
|
| 153 |
+
|
| 154 |
+
def _get_head_outputs(self, features, proposals, image_sizes, _run_stage, box_predictor, targets=None, ann_type='box', classifier_info=(None,None,None)):
|
| 155 |
+
head_outputs = [] # (predictor, predictions, proposals)
|
| 156 |
+
for k in range(self.num_cascade_stages):
|
| 157 |
+
if k > 0:
|
| 158 |
+
proposals = self._create_proposals_from_boxes(
|
| 159 |
+
prev_pred_boxes, image_sizes,
|
| 160 |
+
logits=[p.objectness_logits for p in proposals])
|
| 161 |
+
if self.training and ann_type in ['box']:
|
| 162 |
+
proposals = self._match_and_label_boxes(
|
| 163 |
+
proposals, k, targets)
|
| 164 |
+
predictions = _run_stage(features, proposals, k,
|
| 165 |
+
classifier_info=classifier_info)
|
| 166 |
+
prev_pred_boxes = box_predictor[k].predict_boxes(
|
| 167 |
+
(predictions[0], predictions[1]), proposals)
|
| 168 |
+
head_outputs.append((box_predictor[k], predictions, proposals))
|
| 169 |
+
return head_outputs, proposals
|
| 170 |
+
|
| 171 |
+
def loss_mixup(self, mixed_re, mixed_te, text_embeddings):
|
| 172 |
+
# loss
|
| 173 |
+
if self.cmm_loss in ["l1", "l2"]:
|
| 174 |
+
if self.cmm_loss == "l1":
|
| 175 |
+
loss_type = F.l1_loss
|
| 176 |
+
elif self.cmm_loss == "l2":
|
| 177 |
+
loss_type = F.mse_loss
|
| 178 |
+
cmm_loss = loss_type(mixed_re, mixed_te)
|
| 179 |
+
else:
|
| 180 |
+
raise ValueError("No such loss is supported : ", self.cmm_loss)
|
| 181 |
+
return cmm_loss
|
| 182 |
+
|
| 183 |
+
def mixup(self, stage, regional_embeddings, text_embeddings, gt_classes, proto_weights=None):
|
| 184 |
+
# class-wise multi-modal mixup
|
| 185 |
+
try:
|
| 186 |
+
neg_class = text_embeddings.shape[0] - 1
|
| 187 |
+
|
| 188 |
+
# select positive text embeddings
|
| 189 |
+
all_classes = torch.unique(gt_classes)
|
| 190 |
+
pos_classes = all_classes[all_classes != neg_class]
|
| 191 |
+
|
| 192 |
+
# select class-wise regional embeddings & text embeddings
|
| 193 |
+
clswise_re = []
|
| 194 |
+
clswise_te = []
|
| 195 |
+
for p_c in pos_classes:
|
| 196 |
+
mask = (gt_classes == p_c)
|
| 197 |
+
if self.cmm_prototype == "center":
|
| 198 |
+
_clswise_re = torch.mean(regional_embeddings[mask], axis=0, keepdim=True)
|
| 199 |
+
elif self.cmm_prototype in ["obj_score", "iou"]:
|
| 200 |
+
soft_proto_weights = F.softmax(proto_weights[mask] / self.cmm_prototype_temp, dim=0)
|
| 201 |
+
_clswise_re = torch.sum(regional_embeddings[mask] * soft_proto_weights.unsqueeze(-1), 0, keepdim=True)
|
| 202 |
+
_clswise_te = text_embeddings[int(p_c.item())].unsqueeze(0)
|
| 203 |
+
clswise_re.append(_clswise_re)
|
| 204 |
+
clswise_te.append(_clswise_te)
|
| 205 |
+
|
| 206 |
+
if len(clswise_re) == 0:
|
| 207 |
+
raise ValueError("no positive base classes found for mixup.")
|
| 208 |
+
|
| 209 |
+
clswise_re = torch.cat(clswise_re, dim=0)
|
| 210 |
+
clswise_re = F.normalize(clswise_re, p=2, dim=1) # re-normalize
|
| 211 |
+
clswise_te = torch.cat(clswise_te, dim=0)
|
| 212 |
+
|
| 213 |
+
if self.cmm_beta == 0:
|
| 214 |
+
lam = float(np.random.randint(2))
|
| 215 |
+
else:
|
| 216 |
+
lam = np.random.beta(self.cmm_beta, self.cmm_beta)
|
| 217 |
+
|
| 218 |
+
# random shuffle for mixup pair
|
| 219 |
+
rand_index = torch.randperm(clswise_re.size()[0]).to(clswise_re.device)
|
| 220 |
+
|
| 221 |
+
# mixup
|
| 222 |
+
sf_clswise_re = clswise_re[rand_index]
|
| 223 |
+
sf_clswise_te = clswise_te[rand_index]
|
| 224 |
+
mixed_re, mixed_te = self._embed_interp(clswise_re, sf_clswise_re, clswise_te, sf_clswise_te, lam)
|
| 225 |
+
mixed_re = F.normalize(mixed_re, p=2, dim=1)
|
| 226 |
+
mixed_te = F.normalize(mixed_te, p=2, dim=1)
|
| 227 |
+
cmm_loss = self.loss_mixup(mixed_re, mixed_te, text_embeddings)
|
| 228 |
+
|
| 229 |
+
except Exception as e:
|
| 230 |
+
print("Caught this error in mixup: " + repr(e), "Thus skipping current batch w/o mixup...")
|
| 231 |
+
cmm_loss = text_embeddings[0].new_zeros([1])[0]
|
| 232 |
+
|
| 233 |
+
return cmm_loss
|
| 234 |
+
|
| 235 |
+
def _forward_box(self, features, proposals, targets=None,
|
| 236 |
+
ann_type='box', classifier_info=(None,None,None)):
|
| 237 |
+
"""
|
| 238 |
+
Add mult proposal scores at testing
|
| 239 |
+
Add ann_type
|
| 240 |
+
"""
|
| 241 |
+
if (not self.training) and self.mult_proposal_score:
|
| 242 |
+
if len(proposals) > 0 and proposals[0].has('scores'):
|
| 243 |
+
proposal_scores = [p.get('scores') for p in proposals]
|
| 244 |
+
else:
|
| 245 |
+
proposal_scores = [p.get('objectness_logits') for p in proposals]
|
| 246 |
+
|
| 247 |
+
features = [features[f] for f in self.box_in_features]
|
| 248 |
+
# head_outputs = [] # (predictor, predictions, proposals)
|
| 249 |
+
prev_pred_boxes = None
|
| 250 |
+
image_sizes = [x.image_size for x in proposals]
|
| 251 |
+
|
| 252 |
+
head_outputs, proposals = self._get_head_outputs(features, proposals, image_sizes, self._run_stage, self.box_predictor, targets, ann_type, classifier_info)
|
| 253 |
+
if self.cmm_separated_branch:
|
| 254 |
+
# separated forward
|
| 255 |
+
head_outputs_cmm, proposals_cmm = self._get_head_outputs(features, proposals, image_sizes, self._run_stage_cmm, self.box_predictor_cmm, targets, ann_type, classifier_info)
|
| 256 |
+
|
| 257 |
+
if self.training:
|
| 258 |
+
losses = {}
|
| 259 |
+
storage = get_event_storage()
|
| 260 |
+
for stage, (predictor, predictions, proposals) in enumerate(head_outputs):
|
| 261 |
+
with storage.name_scope("stage{}".format(stage)):
|
| 262 |
+
if ann_type != 'box':
|
| 263 |
+
stage_losses = {}
|
| 264 |
+
if ann_type in ['image', 'caption', 'captiontag']:
|
| 265 |
+
image_labels = [x._pos_category_ids for x in targets]
|
| 266 |
+
weak_losses = predictor.image_label_losses(
|
| 267 |
+
predictions, proposals, image_labels,
|
| 268 |
+
classifier_info=classifier_info,
|
| 269 |
+
ann_type=ann_type)
|
| 270 |
+
stage_losses.update(weak_losses)
|
| 271 |
+
|
| 272 |
+
if self.cmm_use_inl and len(self.cmm_stage) > 0 and stage in self.cmm_stage:
|
| 273 |
+
if self.cmm_separated_branch:
|
| 274 |
+
# get regional embeddings (l2 normalized) from separated branch
|
| 275 |
+
regional_embeddings = head_outputs_cmm[stage][1][2]
|
| 276 |
+
else:
|
| 277 |
+
# get regional embeddings (l2 normalized)
|
| 278 |
+
regional_embeddings = predictions[2]
|
| 279 |
+
|
| 280 |
+
# get text embeddings (L2 normalized), [C (1203 + 1), embedding dim]
|
| 281 |
+
text_embeddings = predictor.cls_score.zs_weight.t()
|
| 282 |
+
|
| 283 |
+
# get max-size proposal's regional embedding, per image
|
| 284 |
+
num_inst_per_image = [len(p) for p in proposals_cmm]
|
| 285 |
+
re_per_image = regional_embeddings.split(num_inst_per_image, dim=0)
|
| 286 |
+
|
| 287 |
+
maxsize_re_per_image = []
|
| 288 |
+
for p, re in zip(proposals_cmm, re_per_image):
|
| 289 |
+
sizes = p.proposal_boxes.area()
|
| 290 |
+
ind = sizes[:-1].argmax().item() if len(sizes) > 1 else 0
|
| 291 |
+
maxsize_re_per_image.append(re[ind].unsqueeze(0))
|
| 292 |
+
maxsize_re_per_image = torch.cat(maxsize_re_per_image, dim=0)
|
| 293 |
+
maxsize_re_per_image = maxsize_re_per_image.to(regional_embeddings.device)
|
| 294 |
+
|
| 295 |
+
# get gt classes per max-size proposal
|
| 296 |
+
# TODO: add best-label per image by cls loss from weak_losses (image-label loss)
|
| 297 |
+
# NOTE: image_labels are not multi-labels.
|
| 298 |
+
gt_classes = (
|
| 299 |
+
regional_embeddings.new_tensor(
|
| 300 |
+
[np.random.choice(labels, 1, replace=False)[0] for labels in image_labels],
|
| 301 |
+
dtype=torch.long
|
| 302 |
+
) if len(proposals_cmm)
|
| 303 |
+
else torch.empty(0)
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
proto_weights = None
|
| 307 |
+
|
| 308 |
+
# get text embeddings (L2 normalized), [C (1203 + 1), embedding dim]
|
| 309 |
+
text_embeddings = predictor.cls_score.zs_weight.t()
|
| 310 |
+
cmm_loss = self.mixup(stage, maxsize_re_per_image, text_embeddings, gt_classes, proto_weights)
|
| 311 |
+
stage_losses["cmm_image_loss"] = (
|
| 312 |
+
cmm_loss * self.cmm_loss_weight
|
| 313 |
+
)
|
| 314 |
+
stage_losses["cmm_loss"] = \
|
| 315 |
+
predictions[0].new_zeros([1])[0]
|
| 316 |
+
|
| 317 |
+
else: # supervised
|
| 318 |
+
stage_losses = predictor.losses(
|
| 319 |
+
(predictions[0], predictions[1]), proposals,
|
| 320 |
+
classifier_info=classifier_info)
|
| 321 |
+
if self.with_image_labels:
|
| 322 |
+
stage_losses['image_loss'] = \
|
| 323 |
+
predictions[0].new_zeros([1])[0]
|
| 324 |
+
|
| 325 |
+
if len(self.cmm_stage) > 0 and stage in self.cmm_stage:
|
| 326 |
+
assert self.use_regional_embedding
|
| 327 |
+
|
| 328 |
+
# get gt classes per proposal
|
| 329 |
+
# e.g. dtype: torch.int64, value: tensor([ 142, 111, 142, ..., 1203, 1203, 1203], device='cuda:6')
|
| 330 |
+
gt_classes = (
|
| 331 |
+
cat([p.gt_classes for p in proposals_cmm], dim=0)
|
| 332 |
+
if len(proposals_cmm)
|
| 333 |
+
else torch.empty(0)
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
if self.cmm_prototype in ["obj_score"]:
|
| 337 |
+
proto_weights = (
|
| 338 |
+
cat([p.objectness_logits for p in proposals_cmm], dim=0)
|
| 339 |
+
if len(proposals_cmm)
|
| 340 |
+
else torch.empty(0)
|
| 341 |
+
)
|
| 342 |
+
elif self.cmm_prototype in ["iou"]:
|
| 343 |
+
gt_boxes = (
|
| 344 |
+
cat([p.gt_boxes.tensor for p in proposals_cmm], dim=0)
|
| 345 |
+
if len(proposals_cmm)
|
| 346 |
+
else torch.empty(0)
|
| 347 |
+
)
|
| 348 |
+
proposal_boxes = (
|
| 349 |
+
cat([p.proposal_boxes.tensor for p in proposals_cmm], dim=0)
|
| 350 |
+
if len(proposals_cmm)
|
| 351 |
+
else torch.empty(0)
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
proto_weights = 1 - giou_loss(proposal_boxes, gt_boxes, reduction="none") # GIoU. (-1 < x < 1)
|
| 355 |
+
else:
|
| 356 |
+
proto_weights = None
|
| 357 |
+
|
| 358 |
+
if self.cmm_separated_branch:
|
| 359 |
+
# get regional embeddings (l2 normalized) from separated branch
|
| 360 |
+
regional_embeddings = head_outputs_cmm[stage][1][2]
|
| 361 |
+
else:
|
| 362 |
+
# get regional embeddings (l2 normalized)
|
| 363 |
+
regional_embeddings = predictions[2]
|
| 364 |
+
|
| 365 |
+
# get text embeddings (L2 normalized), [C (1203 + 1), embedding dim]
|
| 366 |
+
text_embeddings = predictor.cls_score.zs_weight.t()
|
| 367 |
+
|
| 368 |
+
cmm_loss = self.mixup(stage, regional_embeddings, text_embeddings, gt_classes, proto_weights)
|
| 369 |
+
stage_losses["cmm_loss"] = (
|
| 370 |
+
cmm_loss * self.cmm_loss_weight
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
if self.cmm_use_inl:
|
| 374 |
+
stage_losses["cmm_image_loss"] = \
|
| 375 |
+
predictions[0].new_zeros([1])[0]
|
| 376 |
+
|
| 377 |
+
losses.update({k + "_stage{}".format(stage): v \
|
| 378 |
+
for k, v in stage_losses.items()})
|
| 379 |
+
return losses
|
| 380 |
+
else:
|
| 381 |
+
# Each is a list[Tensor] of length #image. Each tensor is Ri x (K+1)
|
| 382 |
+
scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs]
|
| 383 |
+
scores = [
|
| 384 |
+
sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages)
|
| 385 |
+
for scores_per_image in zip(*scores_per_stage)
|
| 386 |
+
]
|
| 387 |
+
|
| 388 |
+
if self.cmm_separated_branch:
|
| 389 |
+
# scores from separated branch
|
| 390 |
+
if self.cmm_stage_test is None:
|
| 391 |
+
# average all stage's classification scores
|
| 392 |
+
scores_per_stage_cmm = [h[0].predict_probs(h[1], h[2]) for h in head_outputs_cmm]
|
| 393 |
+
scores_cmm = [
|
| 394 |
+
sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages)
|
| 395 |
+
for scores_per_image in zip(*scores_per_stage_cmm)
|
| 396 |
+
]
|
| 397 |
+
else:
|
| 398 |
+
# only using specific stages
|
| 399 |
+
scores_per_stage_cmm = [h[0].predict_probs(h[1], h[2]) for k, h in enumerate(head_outputs_cmm) if k in self.cmm_stage_test]
|
| 400 |
+
scores_cmm = [
|
| 401 |
+
sum(list(scores_per_image)) * (1.0 / len(self.cmm_stage_test))
|
| 402 |
+
for scores_per_image in zip(*scores_per_stage_cmm)
|
| 403 |
+
]
|
| 404 |
+
|
| 405 |
+
base_cat_mask = self.base_cat_mask
|
| 406 |
+
assert len(scores) == 1
|
| 407 |
+
bg_score = scores[0][:, -1].clone()
|
| 408 |
+
scores[0][:, base_cat_mask] = scores[0][:, base_cat_mask].pow(
|
| 409 |
+
1.0 - self.cmm_base_alpha
|
| 410 |
+
) * scores_cmm[0][:, base_cat_mask].pow(self.cmm_base_alpha)
|
| 411 |
+
scores[0][:, ~base_cat_mask] = scores[0][:, ~base_cat_mask].pow(
|
| 412 |
+
1.0 - self.cmm_novel_beta
|
| 413 |
+
) * scores_cmm[0][:, ~base_cat_mask].pow(self.cmm_novel_beta)
|
| 414 |
+
scores[0][:, -1] = bg_score
|
| 415 |
+
|
| 416 |
+
if self.mult_proposal_score:
|
| 417 |
+
scores = [(s * ps[:, None]) ** 0.5 \
|
| 418 |
+
for s, ps in zip(scores, proposal_scores)]
|
| 419 |
+
if self.one_class_per_proposal:
|
| 420 |
+
scores = [s * (s == s[:, :-1].max(dim=1)[0][:, None]).float() for s in scores]
|
| 421 |
+
predictor, predictions, proposals = head_outputs[-1]
|
| 422 |
+
boxes = predictor.predict_boxes(
|
| 423 |
+
(predictions[0], predictions[1]), proposals)
|
| 424 |
+
pred_instances, _ = fast_rcnn_inference(
|
| 425 |
+
boxes,
|
| 426 |
+
scores,
|
| 427 |
+
image_sizes,
|
| 428 |
+
predictor.test_score_thresh,
|
| 429 |
+
predictor.test_nms_thresh,
|
| 430 |
+
predictor.test_topk_per_image,
|
| 431 |
+
)
|
| 432 |
+
return pred_instances
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def forward(self, images, features, proposals, targets=None,
|
| 436 |
+
ann_type='box', classifier_info=(None,None,None)):
|
| 437 |
+
'''
|
| 438 |
+
enable debug and image labels
|
| 439 |
+
classifier_info is shared across the batch
|
| 440 |
+
'''
|
| 441 |
+
if self.training:
|
| 442 |
+
if ann_type in ['box', 'prop', 'proptag']:
|
| 443 |
+
proposals = self.label_and_sample_proposals(
|
| 444 |
+
proposals, targets)
|
| 445 |
+
else:
|
| 446 |
+
proposals = self.get_top_proposals(proposals)
|
| 447 |
+
|
| 448 |
+
losses = self._forward_box(features, proposals, targets, \
|
| 449 |
+
ann_type=ann_type, classifier_info=classifier_info)
|
| 450 |
+
if ann_type == 'box' and targets[0].has('gt_masks'):
|
| 451 |
                mask_losses = self._forward_mask(features, proposals)
                losses.update({k: v * self.mask_weight \
                    for k, v in mask_losses.items()})
                losses.update(self._forward_keypoint(features, proposals))
            else:
                losses.update(self._get_empty_mask_loss(
                    features, proposals,
                    device=proposals[0].objectness_logits.device))
            return proposals, losses
        else:
            pred_instances = self._forward_box(
                features, proposals, classifier_info=classifier_info)
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            return pred_instances, {}

    def get_top_proposals(self, proposals):
        for i in range(len(proposals)):
            proposals[i].proposal_boxes.clip(proposals[i].image_size)
        proposals = [p[:self.ws_num_props] for p in proposals]
        for i, p in enumerate(proposals):
            p.proposal_boxes.tensor = p.proposal_boxes.tensor.detach()
            if self.add_image_box:
                proposals[i] = self._add_image_box(p)
        return proposals

    def _add_image_box(self, p):
        image_box = Instances(p.image_size)
        n = 1
        h, w = p.image_size
        f = self.image_box_size
        image_box.proposal_boxes = Boxes(
            p.proposal_boxes.tensor.new_tensor(
                [w * (1. - f) / 2.,
                 h * (1. - f) / 2.,
                 w * (1. - (1. - f) / 2.),
                 h * (1. - (1. - f) / 2.)]
            ).view(n, 4))
        image_box.objectness_logits = p.objectness_logits.new_ones(n)
        return Instances.cat([p, image_box])

    def _get_empty_mask_loss(self, features, proposals, device):
        if self.mask_on:
            return {'loss_mask': torch.zeros(
                (1, ), device=device, dtype=torch.float32)[0]}
        else:
            return {}

    def _create_proposals_from_boxes(self, boxes, image_sizes, logits):
        """
        Add objectness_logits
        """
        boxes = [Boxes(b.detach()) for b in boxes]
        proposals = []
        for boxes_per_image, image_size, logit in zip(
                boxes, image_sizes, logits):
            boxes_per_image.clip(image_size)
            if self.training:
                inds = boxes_per_image.nonempty()
                boxes_per_image = boxes_per_image[inds]
                logit = logit[inds]
            prop = Instances(image_size)
            prop.proposal_boxes = boxes_per_image
            prop.objectness_logits = logit
            proposals.append(prop)
        return proposals

    def _run_stage(self, features, proposals, stage, \
        classifier_info=(None,None,None)):
        """
        Support classifier_info and add_feature_to_prop
        """
        pool_boxes = [x.proposal_boxes for x in proposals]
        box_features = self.box_pooler(features, pool_boxes)
        box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages)
        box_features = self.box_head[stage](box_features)
        if self.add_feature_to_prop:
            feats_per_image = box_features.split(
                [len(p) for p in proposals], dim=0)
            for feat, p in zip(feats_per_image, proposals):
                p.feat = feat
        return self.box_predictor[stage](
            box_features,
            classifier_info=classifier_info)

    def _run_stage_cmm(self, features, proposals, stage, \
        classifier_info=(None,None,None)):
        """
        Support classifier_info and add_feature_to_prop
        """
        pool_boxes = [x.proposal_boxes for x in proposals]
        box_features = self.box_pooler(features, pool_boxes)
        box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages)
        box_features = self.box_head_cmm[stage](box_features)
        if self.add_feature_to_prop:
            feats_per_image = box_features.split(
                [len(p) for p in proposals], dim=0)
            for feat, p in zip(feats_per_image, proposals):
                p.feat = feat
        return self.box_predictor_cmm[stage](
            box_features,
            classifier_info=classifier_info)
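The `_add_image_box` helper above appends one near-whole-image proposal per image, shrunk toward the center by the fraction `self.image_box_size`. A minimal standalone sketch of that box arithmetic, with a hypothetical image size and fraction (values are made up, not taken from the commit):

import torch

# Hypothetical 480 x 640 (H x W) image; f = 1.0 would give the full image,
# f < 1.0 shrinks the box symmetrically toward the center.
h, w, f = 480., 640., 0.9
image_box = torch.tensor([
    w * (1. - f) / 2.,          # x0
    h * (1. - f) / 2.,          # y0
    w * (1. - (1. - f) / 2.),   # x1
    h * (1. - (1. - f) / 2.),   # y1
]).view(1, 4)
print(image_box)  # tensor([[ 32.,  24., 608., 456.]])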
proxydet/modeling/roi_heads/zero_shot_classifier.py
ADDED
@@ -0,0 +1,111 @@
# Copyright (c) Facebook, Inc. and its affiliates.
'''
Modifications Copyright (c) 2024-present NAVER Corp, Apache License v2.0
original source: https://github.com/facebookresearch/Detic/blob/main/detic/modeling/roi_heads/zero_shot_classifier.py
'''
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.layers import Linear, ShapeSpec

class ZeroShotClassifier(nn.Module):
    @configurable
    def __init__(
        self,
        input_shape: ShapeSpec,
        *,
        num_classes: int,
        zs_weight_path: str,
        zs_weight_dim: int = 512,
        use_bias: float = 0.0,
        norm_weight: bool = True,
        norm_temperature: float = 50.0,
        use_regional_embedding: bool = False
    ):
        super().__init__()
        if isinstance(input_shape, int):  # some backward compatibility
            input_shape = ShapeSpec(channels=input_shape)
        input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1)
        self.norm_weight = norm_weight
        self.norm_temperature = norm_temperature

        self.use_bias = use_bias < 0
        if self.use_bias:
            self.cls_bias = nn.Parameter(torch.ones(1) * use_bias)

        self.linear = nn.Linear(input_size, zs_weight_dim)

        if zs_weight_path == 'rand':
            zs_weight = torch.randn((zs_weight_dim, num_classes))
            nn.init.normal_(zs_weight, std=0.01)
        else:
            zs_weight = torch.tensor(
                np.load(zs_weight_path),
                dtype=torch.float32).permute(1, 0).contiguous()  # D x C
        zs_weight = torch.cat(
            [zs_weight, zs_weight.new_zeros((zs_weight_dim, 1))],
            dim=1)  # D x (C + 1)

        if self.norm_weight:
            zs_weight = F.normalize(zs_weight, p=2, dim=0)

        if zs_weight_path == 'rand':
            self.zs_weight = nn.Parameter(zs_weight)
        else:
            self.register_buffer('zs_weight', zs_weight)

        assert self.zs_weight.shape[1] == num_classes + 1, self.zs_weight.shape

        self.use_regional_embedding = use_regional_embedding
        if self.use_regional_embedding:
            assert (
                self.norm_weight
            ), "norm_weight should be True for using regional embedding."

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {
            'input_shape': input_shape,
            'num_classes': cfg.MODEL.ROI_HEADS.NUM_CLASSES,
            'zs_weight_path': cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH,
            'zs_weight_dim': cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_DIM,
            'use_bias': cfg.MODEL.ROI_BOX_HEAD.USE_BIAS,
            'norm_weight': cfg.MODEL.ROI_BOX_HEAD.NORM_WEIGHT,
            'norm_temperature': cfg.MODEL.ROI_BOX_HEAD.NORM_TEMP,
            "use_regional_embedding": cfg.MODEL.ROI_BOX_HEAD.USE_REGIONAL_EMBEDDING,
        }

    def forward(self, x, classifier=None):
        '''
        Inputs:
            x: B x D'
            classifier_info: (C', C' x D)
        '''
        x = self.linear(x)
        if classifier is not None:
            zs_weight = classifier.permute(1, 0).contiguous()  # D x C'
            zs_weight = F.normalize(zs_weight, p=2, dim=0) \
                if self.norm_weight else zs_weight
        else:
            zs_weight = self.zs_weight
        if self.norm_weight:
            # NOTE: x shape is [Batch size * # proposals for each image, 512 (embedding dim)]
            x = F.normalize(x, p=2, dim=1)
            if self.use_regional_embedding:
                # NOTE: gradient of cloned tensor will be propagated to the original tensor (x): https://discuss.pytorch.org/t/how-does-clone-interact-with-backpropagation/8247/6?u=111368
                regional_embedding = x.clone()
            # NOTE: apply normalizing temperature
            x *= self.norm_temperature

        x = torch.mm(x, zs_weight)

        if self.use_bias:
            x = x + self.cls_bias

        if not self.use_regional_embedding:
            return x  # class logits
        else:
            return x, regional_embedding  # class logits & regional embeddings
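ZeroShotClassifier replaces learned per-class weights with (frozen) text embeddings: pooled region features are projected to the embedding dimension, L2-normalized, scaled by a temperature, and scored against the normalized per-class embeddings (plus an all-zero background column) with a matrix product. For illustration, a minimal standalone sketch of that scoring rule; all shapes and values below are made up, and a plain nn.Linear stands in for the module's projection:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_boxes, feat_dim, emb_dim, num_classes, temp = 8, 1024, 512, 5, 50.0

box_feats = torch.randn(num_boxes, feat_dim)        # pooled RoI features (stand-in)
proj = torch.nn.Linear(feat_dim, emb_dim)           # stands in for the module's self.linear

zs_weight = torch.randn(emb_dim, num_classes)                                 # CLIP text embeddings would go here (D x C)
zs_weight = torch.cat([zs_weight, zs_weight.new_zeros(emb_dim, 1)], dim=1)    # extra all-zero background column
zs_weight = F.normalize(zs_weight, p=2, dim=0)

x = F.normalize(proj(box_feats), p=2, dim=1)        # unit-norm regional embeddings
logits = temp * x @ zs_weight                        # temperature-scaled cosine similarity
print(logits.shape)                                  # torch.Size([8, 6])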
proxydet/modeling/text/text_encoder.py
ADDED
@@ -0,0 +1,189 @@
# This code is modified from https://github.com/openai/CLIP/blob/main/clip/clip.py
# Modified by Xingyi Zhou
# The original code is under MIT license
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Union, List
from collections import OrderedDict
import torch
from torch import nn
import torch

from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

__all__ = ["tokenize"]

count = 0

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(
            *[ResidualAttentionBlock(width, heads, attn_mask) \
                for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

class CLIPTEXT(nn.Module):
    def __init__(self,
        embed_dim=512,
        # text
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12
        ):
        super().__init__()

        self._tokenizer = _Tokenizer()
        self.context_length = context_length

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def device(self):
        return self.text_projection.device

    @property
    def dtype(self):
        return self.text_projection.dtype

    def tokenize(self,
            texts: Union[str, List[str]], \
            context_length: int = 77) -> torch.LongTensor:
        """
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self._tokenizer.encoder["<|startoftext|>"]
        eot_token = self._tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                st = torch.randint(
                    len(tokens) - context_length + 1, (1,))[0].item()
                tokens = tokens[st: st + context_length]
                # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x

    def forward(self, captions):
        '''
        captions: list of strings
        '''
        text = self.tokenize(captions).to(self.device)  # B x L x D
        features = self.encode_text(text)  # B x D
        return features


def build_text_encoder(pretrain=True):
    text_encoder = CLIPTEXT()
    if pretrain:
        import clip
        pretrained_model, _ = clip.load("ViT-B/32", device='cpu')
        state_dict = pretrained_model.state_dict()
        to_delete_keys = ["logit_scale", "input_resolution", \
            "context_length", "vocab_size"] + \
            [k for k in state_dict.keys() if k.startswith('visual.')]
        for k in to_delete_keys:
            if k in state_dict:
                del state_dict[k]
        print('Loading pretrained CLIP')
        text_encoder.load_state_dict(state_dict)
    # import pdb; pdb.set_trace()
    return text_encoder
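A rough usage sketch for this text encoder: building a D x C classifier matrix from a handful of prompts. It assumes the `clip` package and its ViT-B/32 weights can be downloaded; the vocabulary below is hypothetical.

import torch
from proxydet.modeling.text.text_encoder import build_text_encoder

text_encoder = build_text_encoder(pretrain=True).eval()
prompts = ['a cat', 'a remote control', 'a pikachu']   # hypothetical vocabulary
with torch.no_grad():
    emb = text_encoder(prompts)                        # C x D (here 3 x 512)
classifier = emb.permute(1, 0).contiguous()            # D x C, one column per class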
proxydet/modeling/utils.py
ADDED
@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
import json
import numpy as np
from torch.nn import functional as F

def load_class_freq(
    path='datasets/metadata/lvis_v1_train_cat_info.json', freq_weight=1.0):
    cat_info = json.load(open(path, 'r'))
    cat_info = torch.tensor(
        [c['image_count'] for c in sorted(cat_info, key=lambda x: x['id'])])
    freq_weight = cat_info.float() ** freq_weight
    return freq_weight


def get_fed_loss_inds(gt_classes, num_sample_cats, C, weight=None):
    appeared = torch.unique(gt_classes)  # C'
    prob = appeared.new_ones(C + 1).float()
    prob[-1] = 0
    if len(appeared) < num_sample_cats:
        if weight is not None:
            prob[:C] = weight.float().clone()
        prob[appeared] = 0
        more_appeared = torch.multinomial(
            prob, num_sample_cats - len(appeared),
            replacement=False)
        appeared = torch.cat([appeared, more_appeared])
    return appeared


def reset_cls_test(model, cls_path, num_classes):
    model.roi_heads.num_classes = num_classes
    if type(cls_path) == str:
        print('Resetting zs_weight', cls_path)
        zs_weight = torch.tensor(
            np.load(cls_path),
            dtype=torch.float32).permute(1, 0).contiguous()  # D x C
    else:
        zs_weight = cls_path
    zs_weight = torch.cat(
        [zs_weight, zs_weight.new_zeros((zs_weight.shape[0], 1))],
        dim=1)  # D x (C + 1)
    if model.roi_heads.box_predictor[0].cls_score.norm_weight:
        zs_weight = F.normalize(zs_weight, p=2, dim=0)
    zs_weight = zs_weight.to(model.device)
    for k in range(len(model.roi_heads.box_predictor)):
        del model.roi_heads.box_predictor[k].cls_score.zs_weight
        model.roi_heads.box_predictor[k].cls_score.zs_weight = zs_weight

    if hasattr(model.roi_heads, "box_predictor_cmm"):
        for k in range(len(model.roi_heads.box_predictor_cmm)):
            del model.roi_heads.box_predictor_cmm[k].cls_score.zs_weight
            model.roi_heads.box_predictor_cmm[k].cls_score.zs_weight = zs_weight
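The federated-loss sampler above pads the classes that actually appear in a batch with extra negatively-sampled categories. A small, self-contained sketch of how it behaves, with made-up class indices and counts:

import torch
from proxydet.modeling.utils import get_fed_loss_inds

# Hypothetical batch: 3 distinct ground-truth classes; pad up to 10 sampled
# classes out of C = 100 (index C is reserved for background and never drawn).
gt_classes = torch.tensor([2, 7, 7, 41])
inds = get_fed_loss_inds(gt_classes, num_sample_cats=10, C=100)
print(inds.shape)  # torch.Size([10]); the first entries are the appeared classes 2, 7, 41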
proxydet/predictor.py
ADDED
@@ -0,0 +1,295 @@
# Copyright (c) Facebook, Inc. and its affiliates.
'''
Modifications Copyright (c) 2024-present NAVER Corp, Apache License v2.0
original source: https://github.com/facebookresearch/Detic/blob/main/detic/predictor.py
'''
import atexit
import bisect
import numpy as np
import multiprocessing as mp
from collections import deque
import cv2
import torch
import torch.nn.functional as F

from detectron2.data import MetadataCatalog
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.video_visualizer import VideoVisualizer
from detectron2.utils.visualizer import ColorMode, Visualizer

from .modeling.utils import reset_cls_test

def get_text_encoder():
    from proxydet.modeling.text.text_encoder import build_text_encoder
    text_encoder = build_text_encoder(pretrain=True)
    text_encoder.eval()
    return text_encoder

def get_clip_embeddings(vocabulary, text_encoder, prompt='a '):
    texts = [prompt + x for x in vocabulary]
    emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
    return emb

BUILDIN_CLASSIFIER = {
    'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy',
    'objects365': 'datasets/metadata/o365_clip_a+cnamefix.npy',
    'openimages': 'datasets/metadata/oid_clip_a+cname.npy',
    'coco': 'datasets/metadata/coco_clip_a+cname.npy',
}

BUILDIN_METADATA_PATH = {
    'lvis': 'lvis_v1_val',
    'objects365': 'objects365_v2_val',
    'openimages': 'oid_val_expanded',
    'coco': 'coco_2017_val',
}

class VisualizationDemo(object):
    def __init__(self, cfg, args,
        instance_mode=ColorMode.IMAGE, parallel=False):
        """
        Args:
            cfg (CfgNode):
            instance_mode (ColorMode):
            parallel (bool): whether to run the model in different processes from visualization.
                Useful since the visualization logic can be slow.
        """
        if args.vocabulary == 'custom':
            self.text_encoder = get_text_encoder()
            self.metadata = MetadataCatalog.get("__unused")
            self.metadata.thing_classes = args.custom_vocabulary.split(',')
            classifier = get_clip_embeddings(self.metadata.thing_classes, self.text_encoder)
        else:
            self.metadata = MetadataCatalog.get(
                BUILDIN_METADATA_PATH[args.vocabulary])
            classifier = BUILDIN_CLASSIFIER[args.vocabulary]

        num_classes = len(self.metadata.thing_classes)
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode

        self.parallel = parallel
        if parallel:
            num_gpu = torch.cuda.device_count()
            self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
        else:
            self.predictor = DefaultPredictor(cfg)
        reset_cls_test(self.predictor.model, classifier, num_classes)

        # reset base category mask, based on the similarity
        self.zeroshot_weight_path = args.zeroshot_weight_path
        self.base_cat_threshold = args.base_cat_threshold
        if self.zeroshot_weight_path is not None:
            self.trained_zs_classifier = torch.tensor(
                np.load(self.zeroshot_weight_path),
                dtype=torch.float32
            ).permute(1, 0).contiguous().to(cfg.MODEL.DEVICE)  # D x C
            self.trained_zs_classifier = F.normalize(self.trained_zs_classifier, p=2, dim=0)

            self.base_cat_indices = self.predictor.model.roi_heads.base_cat_mask.nonzero().squeeze(-1)
            self.reset_base_cat_mask()

    def reset_base_cat_mask(self, new_base_cat_mask=None):
        if new_base_cat_mask is None:
            # L2 normalized
            if hasattr(self.predictor.model.roi_heads.box_predictor, "__getitem__"):
                custom_classifier = self.predictor.model.roi_heads.box_predictor[0].cls_score.zs_weight
            else:
                custom_classifier = self.predictor.model.roi_heads.box_predictor.cls_score.zs_weight

            base_cat_sim = custom_classifier.T[:-1, :] @ self.trained_zs_classifier[:, self.base_cat_indices]

            # reset base cat mask
            new_base_cat_mask = torch.cat([base_cat_sim.max(dim=1)[0] > self.base_cat_threshold, base_cat_sim.new_zeros((1))], dim=0)  # for bg
            new_base_cat_mask = new_base_cat_mask.bool()
        self.predictor.model.roi_heads.base_cat_mask = new_base_cat_mask
        return new_base_cat_mask

    def reset_classifier(self, vocabs):
        thing_classes = vocabs.split(',')
        classifier = get_clip_embeddings(thing_classes, self.text_encoder)
        reset_cls_test(self.predictor.model, classifier, len(thing_classes))

    def run_on_image(self, image):
        """
        Args:
            image (np.ndarray): an image of shape (H, W, C) (in BGR order).
                This is the format used by OpenCV.

        Returns:
            predictions (dict): the output of the model.
            vis_output (VisImage): the visualized image output.
        """
        vis_output = None
        predictions = self.predictor(image)
        # Convert image from OpenCV BGR format to Matplotlib RGB format.
        image = image[:, :, ::-1]
        visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
        if "panoptic_seg" in predictions:
            panoptic_seg, segments_info = predictions["panoptic_seg"]
            vis_output = visualizer.draw_panoptic_seg_predictions(
                panoptic_seg.to(self.cpu_device), segments_info
            )
        else:
            if "sem_seg" in predictions:
                vis_output = visualizer.draw_sem_seg(
                    predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
                )
            if "instances" in predictions:
                instances = predictions["instances"].to(self.cpu_device)
                vis_output = visualizer.draw_instance_predictions(predictions=instances)

        return predictions, vis_output

    def _frame_from_video(self, video):
        while video.isOpened():
            success, frame = video.read()
            if success:
                yield frame
            else:
                break

    def run_on_video(self, video):
        """
        Visualizes predictions on frames of the input video.

        Args:
            video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
                either a webcam or a video file.

        Yields:
            ndarray: BGR visualizations of each video frame.
        """
        video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)

        def process_predictions(frame, predictions):
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if "panoptic_seg" in predictions:
                panoptic_seg, segments_info = predictions["panoptic_seg"]
                vis_frame = video_visualizer.draw_panoptic_seg_predictions(
                    frame, panoptic_seg.to(self.cpu_device), segments_info
                )
            elif "instances" in predictions:
                predictions = predictions["instances"].to(self.cpu_device)
                vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
            elif "sem_seg" in predictions:
                vis_frame = video_visualizer.draw_sem_seg(
                    frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
                )

            # Converts Matplotlib RGB format to OpenCV BGR format
            vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
            return vis_frame

        frame_gen = self._frame_from_video(video)
        if self.parallel:
            buffer_size = self.predictor.default_buffer_size

            frame_data = deque()

            for cnt, frame in enumerate(frame_gen):
                frame_data.append(frame)
                self.predictor.put(frame)

                if cnt >= buffer_size:
                    frame = frame_data.popleft()
                    predictions = self.predictor.get()
                    yield process_predictions(frame, predictions)

            while len(frame_data):
                frame = frame_data.popleft()
                predictions = self.predictor.get()
                yield process_predictions(frame, predictions)
        else:
            for frame in frame_gen:
                yield process_predictions(frame, self.predictor(frame))


class AsyncPredictor:
    """
    A predictor that runs the model asynchronously, possibly on >1 GPUs.
    Because rendering the visualization takes considerably amount of time,
    this helps improve throughput a little bit when rendering videos.
    """

    class _StopToken:
        pass

    class _PredictWorker(mp.Process):
        def __init__(self, cfg, task_queue, result_queue):
            self.cfg = cfg
            self.task_queue = task_queue
            self.result_queue = result_queue
            super().__init__()

        def run(self):
            predictor = DefaultPredictor(self.cfg)

            while True:
                task = self.task_queue.get()
                if isinstance(task, AsyncPredictor._StopToken):
                    break
                idx, data = task
                result = predictor(data)
                self.result_queue.put((idx, result))

    def __init__(self, cfg, num_gpus: int = 1):
        """
        Args:
            cfg (CfgNode):
            num_gpus (int): if 0, will run on CPU
        """
        num_workers = max(num_gpus, 1)
        self.task_queue = mp.Queue(maxsize=num_workers * 3)
        self.result_queue = mp.Queue(maxsize=num_workers * 3)
        self.procs = []
        for gpuid in range(max(num_gpus, 1)):
            cfg = cfg.clone()
            cfg.defrost()
            cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
            self.procs.append(
                AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
            )

        self.put_idx = 0
        self.get_idx = 0
        self.result_rank = []
        self.result_data = []

        for p in self.procs:
            p.start()
        atexit.register(self.shutdown)

    def put(self, image):
        self.put_idx += 1
        self.task_queue.put((self.put_idx, image))

    def get(self):
        self.get_idx += 1  # the index needed for this request
        if len(self.result_rank) and self.result_rank[0] == self.get_idx:
            res = self.result_data[0]
            del self.result_data[0], self.result_rank[0]
            return res

        while True:
            # make sure the results are returned in the correct order
            idx, res = self.result_queue.get()
            if idx == self.get_idx:
                return res
            insert = bisect.bisect(self.result_rank, idx)
            self.result_rank.insert(insert, idx)
            self.result_data.insert(insert, res)

    def __len__(self):
        return self.put_idx - self.get_idx

    def __call__(self, image):
        self.put(image)
        return self.get()

    def shutdown(self):
        for _ in self.procs:
            self.task_queue.put(AsyncPredictor._StopToken())

    @property
    def default_buffer_size(self):
        return len(self.procs) * 5
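A rough usage sketch for VisualizationDemo with a custom vocabulary. Everything below is a placeholder rather than the repo's own entry point: the config-registration helpers, config path, checkpoint path, and image path are assumptions, and the exact helper names should be checked against proxydet/config.py in this commit.

import argparse
import cv2
from detectron2.config import get_cfg
from proxydet.config import add_proxydet_config   # helper name assumed
from proxydet.predictor import VisualizationDemo

cfg = get_cfg()
add_proxydet_config(cfg)                           # add custom keys before merging the YAML
cfg.merge_from_file('configs/ProxyDet_R50_Lbase_INL.yaml')
cfg.MODEL.WEIGHTS = 'models/proxydet_r50.pth'      # hypothetical checkpoint path
cfg.freeze()

args = argparse.Namespace(
    vocabulary='custom',
    custom_vocabulary='pikachu,remote control',    # comma-separated open vocabulary
    zeroshot_weight_path=None,                     # None skips the base-category mask reset
    base_cat_threshold=0.9,
)
demo = VisualizationDemo(cfg, args)
predictions, vis = demo.run_on_image(cv2.imread('assets/pikachu.jpg'))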
requirements.txt
CHANGED
@@ -7,10 +7,4 @@ mss<=6.1.0
 timm<=0.5.4
 lvis
 nltk<=3.7
-#git+https://github.com/facebookresearch/detectron2.git
-
 numpy>=1.18.5
-#torch>=1.7.0
-#torchvision>=0.8.1
-git+https://github.com/huggingface/transformers.git
-opencv-python
third_party/CenterNet2/.github/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,5 @@
# Code of Conduct

Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.fb.com/codeofconduct/)
so that you can understand what actions will and will not be tolerated.
third_party/CenterNet2/.github/CONTRIBUTING.md
ADDED
@@ -0,0 +1,68 @@
# Contributing to detectron2

## Issues
We use GitHub issues to track public bugs and questions.
Please make sure to follow one of the
[issue templates](https://github.com/facebookresearch/detectron2/issues/new/choose)
when reporting any issues.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## Pull Requests
We actively welcome pull requests.

However, if you're adding any significant features (e.g. > 50 lines), please
make sure to discuss with maintainers about your motivation and proposals in an issue
before sending a PR. This is to save your time so you don't spend time on a PR that we'll not accept.

We do not always accept new features, and we take the following
factors into consideration:

1. Whether the same feature can be achieved without modifying detectron2.
   Detectron2 is designed so that you can implement many extensions from the outside, e.g.
   those in [projects](https://github.com/facebookresearch/detectron2/tree/master/projects).
   * If some part of detectron2 is not extensible enough, you can also bring up a more general issue to
     improve it. Such feature request may be useful to more users.
2. Whether the feature is potentially useful to a large audience (e.g. an impactful detection paper, a popular dataset,
   a significant speedup, a widely useful utility),
   or only to a small portion of users (e.g., a less-known paper, an improvement not in the object
   detection field, a trick that's not very popular in the community, code to handle a non-standard type of data)
   * Adoption of additional models, datasets, new task are by default not added to detectron2 before they
     receive significant popularity in the community.
     We sometimes accept such features in `projects/`, or as a link in `projects/README.md`.
3. Whether the proposed solution has a good design / interface. This can be discussed in the issue prior to PRs, or
   in the form of a draft PR.
4. Whether the proposed solution adds extra mental/practical overhead to users who don't
   need such feature.
5. Whether the proposed solution breaks existing APIs.

To add a feature to an existing function/class `Func`, there are always two approaches:
(1) add new arguments to `Func`; (2) write a new `Func_with_new_feature`.
To meet the above criteria, we often prefer approach (2), because:

1. It does not involve modifying or potentially breaking existing code.
2. It does not add overhead to users who do not need the new feature.
3. Adding new arguments to a function/class is not scalable w.r.t. all the possible new research ideas in the future.

When sending a PR, please do:

1. If a PR contains multiple orthogonal changes, split it to several PRs.
2. If you've added code that should be tested, add tests.
3. For PRs that need experiments (e.g. adding a new model or new methods),
   you don't need to update model zoo, but do provide experiment results in the description of the PR.
4. If APIs are changed, update the documentation.
5. We use the [Google style docstrings](https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html) in python.
6. Make sure your code lints with `./dev/linter.sh`.


## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## License
By contributing to detectron2, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
third_party/CenterNet2/.github/Detectron2-Logo-Horz.svg
ADDED
third_party/CenterNet2/.github/ISSUE_TEMPLATE.md
ADDED
@@ -0,0 +1,5 @@

Please select an issue template from
https://github.com/facebookresearch/detectron2/issues/new/choose .

Otherwise your issue will be closed.
third_party/CenterNet2/.github/ISSUE_TEMPLATE/bugs.md
ADDED
@@ -0,0 +1,38 @@
---
name: "🐛 Bugs"
about: Report bugs in detectron2
title: Please read & provide the following

---

## Instructions To Reproduce the 🐛 Bug:
1. Full runnable code or full changes you made:
```
If making changes to the project itself, please use output of the following command:
git rev-parse HEAD; git diff

<put code or diff here>
```
2. What exact command you run:
3. __Full logs__ or other relevant observations:
```
<put logs here>
```
4. please simplify the steps as much as possible so they do not require additional resources to
   run, such as a private dataset.

## Expected behavior:

If there are no obvious error in "full logs" provided above,
please tell us the expected behavior.

## Environment:

Provide your environment information using the following command:
```
wget -nc -q https://github.com/facebookresearch/detectron2/raw/main/detectron2/utils/collect_env.py && python collect_env.py
```

If your issue looks like an installation issue / environment issue,
please first try to solve it yourself with the instructions in
https://detectron2.readthedocs.io/tutorials/install.html#common-installation-issues
third_party/CenterNet2/.github/ISSUE_TEMPLATE/config.yml
ADDED
@@ -0,0 +1,17 @@
# require an issue template to be chosen
blank_issues_enabled: false

contact_links:
  - name: How-To / All Other Questions
    url: https://github.com/facebookresearch/detectron2/discussions
    about: Use "github discussions" for community support on general questions that don't belong to the above issue categories
  - name: Detectron2 Documentation
    url: https://detectron2.readthedocs.io/index.html
    about: Check if your question is answered in tutorials or API docs

# Unexpected behaviors & bugs are split to two templates.
# When they are one template, users think "it's not a bug" and don't choose the template.
#
# But the file name is still "unexpected-problems-bugs.md" so that old references
# to this issue template still works.
# It's ok since this template should be a superset of "bugs.md" (unexpected behaviors is a superset of bugs)
third_party/CenterNet2/.github/ISSUE_TEMPLATE/documentation.md
ADDED
@@ -0,0 +1,14 @@
---
name: "\U0001F4DA Documentation Issue"
about: Report a problem about existing documentation, comments, website or tutorials.
labels: documentation

---

## 📚 Documentation Issue

This issue category is for problems about existing documentation, not for asking how-to questions.

* Provide a link to an existing documentation/comment/tutorial:

* How should the above documentation/comment/tutorial improve: