|
"""
|
|
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
|
|
|
|
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
|
|
"""
|
|
|
|
import faster_coco_eval
|
|
import faster_coco_eval.core.mask as coco_mask
|
|
import torch
|
|
import torch.utils.data
|
|
import torchvision
|
|
import os
|
|
from PIL import Image
|
|
|
|
from ...core import register
|
|
from .._misc import convert_to_tv_tensor
|
|
from ._dataset import DetDataset
|
|
|
|
# Suppress torchvision's warning about its beta transforms API.
torchvision.disable_beta_transforms_warning()

# Install faster_coco_eval as a stand-in for pycocotools (presumably so that
# `import pycocotools` elsewhere resolves to it — confirm against the
# faster_coco_eval docs).
faster_coco_eval.init_as_pycocotools()

# Disable PIL's decompression-bomb size limit so very large images can be opened.
# NOTE(review): this removes PIL's safeguard against maliciously huge images.
Image.MAX_IMAGE_PIXELS = None

# Public API of this module.
__all__ = ["CocoDetection"]
|
|
|
|
|
|
@register()
class CocoDetection(torchvision.datasets.CocoDetection, DetDataset):
    """COCO-format detection dataset yielding ``(image, target)`` pairs.

    Builds on :class:`torchvision.datasets.CocoDetection`. For each sample it:

    1. converts the raw annotation list with :class:`ConvertCocoPolysToMask`;
    2. wraps ``boxes`` (and ``masks``, if present) with ``convert_to_tv_tensor``;
    3. applies the injected ``transforms`` in :meth:`__getitem__`.

    Args:
        img_folder: Directory containing the images.
        ann_file: Path to the COCO-style annotation file.
        transforms: Callable invoked as ``transforms(img, target, dataset)``.
            Its result is unpacked into three values, of which the first two
            become the new image and target. ``None`` disables transforms.
        return_masks: If True, segmentations are rasterized into ``masks``.
        remap_mscoco_category: If True, COCO category ids are mapped to
            contiguous labels through ``mscoco_category2label``. Otherwise
            the raw category ids are used as labels.
    """

    # Attribute names consumed by the project's ``register`` machinery
    # (presumably resolved/shared via config — confirm in ``...core``).
    __inject__ = [
        "transforms",
    ]
    __share__ = ["remap_mscoco_category"]

    def __init__(
        self, img_folder, ann_file, transforms, return_masks=False, remap_mscoco_category=False
    ):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)
        self.img_folder = img_folder
        self.ann_file = ann_file
        self.return_masks = return_masks
        self.remap_mscoco_category = remap_mscoco_category

    def __getitem__(self, idx):
        """Return the transformed ``(image, target)`` pair at position ``idx``."""
        img, target = self.load_item(idx)
        if self._transforms is not None:
            img, target, _ = self._transforms(img, target, self)
        return img, target

    def load_item(self, idx):
        """Load sample ``idx`` and build its target dict, without applying transforms.

        Returns:
            ``(image, target)``. ``target`` is produced by ``self.prepare`` and
            also carries ``idx`` as a 1-element tensor.
        """
        image, annotations = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        file_name = self.coco.loadImgs(image_id)[0]["file_name"]
        image_path = os.path.join(self.img_folder, file_name)
        target = {"image_id": image_id, "image_path": image_path, "annotations": annotations}

        if self.remap_mscoco_category:
            image, target = self.prepare(image, target, category2label=mscoco_category2label)
        else:
            image, target = self.prepare(image, target)

        target["idx"] = torch.tensor([idx])

        if "boxes" in target:
            # PIL's ``size`` is (width, height); reverse it to (height, width).
            target["boxes"] = convert_to_tv_tensor(
                target["boxes"], key="boxes", spatial_size=image.size[::-1]
            )

        if "masks" in target:
            target["masks"] = convert_to_tv_tensor(target["masks"], key="masks")

        return image, target

    def extra_repr(self) -> str:
        """Describe the dataset configuration, one newline-terminated field per line."""
        s = f" img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n"
        s += f" return_masks: {self.return_masks}\n"
        s += f" remap_mscoco_category: {getattr(self, 'remap_mscoco_category', None)}\n"
        # Each section ends with a newline so that the next section starts on
        # its own line instead of being glued to the previous repr.
        if getattr(self, "_transforms", None) is not None:
            s += f" transforms:\n {repr(self._transforms)}\n"
        if getattr(self, "_preset", None) is not None:
            s += f" preset:\n {repr(self._preset)}\n"
        return s

    @property
    def categories(
        self,
    ):
        """Category records from the annotation file, in file order."""
        return self.coco.dataset["categories"]

    @property
    def category2name(
        self,
    ):
        """Map of COCO category id to category name."""
        return {cat["id"]: cat["name"] for cat in self.categories}

    @property
    def category2label(
        self,
    ):
        """Map of COCO category id to its 0-based position in :attr:`categories`."""
        return {cat["id"]: i for i, cat in enumerate(self.categories)}

    @property
    def label2category(
        self,
    ):
        """Inverse of :attr:`category2label`: 0-based position to COCO category id."""
        return {i: cat["id"] for i, cat in enumerate(self.categories)}
|
|
|
|
|
|
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize per-object COCO segmentations into a stack of masks.

    Args:
        segmentations: One entry per object, each in a form accepted by
            ``coco_mask.frPyObjects`` (e.g. a list of polygons).
        height: Image height in pixels.
        width: Image width in pixels.

    Returns:
        Tensor of shape ``(num_objects, height, width)`` holding the union of
        each object's parts. When there are no objects, an empty ``uint8``
        tensor of shape ``(0, height, width)`` is returned.
    """
    per_object = []
    for obj_segm in segmentations:
        decoded = coco_mask.decode(coco_mask.frPyObjects(obj_segm, height, width))
        # decode may return a 2-D (H, W) array; give it a trailing axis so the
        # union below always reduces over axis 2.
        if len(decoded.shape) < 3:
            decoded = decoded[..., None]
        per_object.append(torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2))

    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
|
|
|
|
|
|
class ConvertCocoPolysToMask(object):
    """Convert one image's raw COCO annotations into a dict of tensors.

    Processing steps:

    * Crowd annotations (``iscrowd != 0``) are dropped.
    * ``[x, y, w, h]`` boxes are converted to ``[x1, y1, x2, y2]`` and
      clipped to the image.
    * Objects whose clipped box has zero width or height are removed from
      every per-object field.
    """

    def __init__(self, return_masks=False):
        """
        Args:
            return_masks: If True, also rasterize each object's
                ``segmentation`` with :func:`convert_coco_poly_to_mask`.
        """
        self.return_masks = return_masks

    def __call__(self, image: "Image.Image", target, **kwargs):
        """Build the tensor target for one image.

        Args:
            image: Image object. Only ``image.size``, as ``(width, height)``,
                is read.
            target: Dict with ``image_id``, ``image_path`` and
                ``annotations`` (a list of COCO annotation dicts).
            **kwargs: Optional ``category2label`` mapping from COCO
                ``category_id`` to label. When it is absent, raw category ids
                are used as labels.

        Returns:
            ``(image, new_target)``. The image is returned unchanged.
            ``new_target`` contains:

            * ``boxes``: ``(N, 4)`` float32, xyxy.
            * ``labels``: ``(N,)`` int64.
            * ``masks`` (optional) and ``keypoints`` (optional, ``(N, K, 3)``).
            * ``image_id``: shape ``(1,)``.
            * ``image_path``.
            * ``area``: ``(N,)`` float32.
            * ``iscrowd``: ``(N,)`` int64.
            * ``orig_size``: ``[w, h]``.

        Raises:
            KeyError: If an annotation lacks a required field, or its
                ``category_id`` is missing from ``category2label``.
        """
        w, h = image.size

        image_id = torch.tensor([target["image_id"]])
        image_path = target["image_path"]

        # Drop crowd regions; a missing "iscrowd" key counts as 0.
        anno = [obj for obj in target["annotations"] if obj.get("iscrowd", 0) == 0]

        # COCO boxes are [x, y, w, h]. reshape keeps shape (0, 4) when empty.
        boxes = torch.as_tensor([obj["bbox"] for obj in anno], dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]  # -> [x1, y1, x2, y2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        category2label = kwargs.get("category2label", None)
        if category2label is not None:
            labels = [category2label[obj["category_id"]] for obj in anno]
        else:
            labels = [obj["category_id"] for obj in anno]
        labels = torch.tensor(labels, dtype=torch.int64)

        if self.return_masks:
            masks = convert_coco_poly_to_mask([obj["segmentation"] for obj in anno], h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = torch.as_tensor([obj["keypoints"] for obj in anno], dtype=torch.float32)
            num_objects = keypoints.shape[0]
            if num_objects:
                # Flat [x1, y1, v1, x2, y2, v2, ...] -> (N, K, 3).
                keypoints = keypoints.view(num_objects, -1, 3)

        # Keep only boxes with positive width and height after clipping.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

        # Explicit dtypes make images with no annotations produce the same
        # dtypes as non-empty ones; torch.tensor([]) would default to float32
        # for iscrowd, and integer areas would otherwise become int64.
        area = torch.tensor([obj["area"] for obj in anno], dtype=torch.float32)
        iscrowd = torch.tensor([obj.get("iscrowd", 0) for obj in anno], dtype=torch.int64)

        new_target = {}
        new_target["boxes"] = boxes[keep]
        new_target["labels"] = labels[keep]
        if self.return_masks:
            new_target["masks"] = masks[keep]
        new_target["image_id"] = image_id
        new_target["image_path"] = image_path
        if keypoints is not None:
            new_target["keypoints"] = keypoints[keep]
        new_target["area"] = area[keep]
        new_target["iscrowd"] = iscrowd[keep]
        new_target["orig_size"] = torch.as_tensor([int(w), int(h)])

        return image, new_target
|
|
|
|
|
|
# MS-COCO detection categories: category_id -> name (80 entries).
# The ids are not contiguous: 12, 26, 29, 30, 45, 66, 68, 69, 71 and 83 are
# absent, which is why the contiguous label maps below are needed.
mscoco_category2name = {
    1: "person",
    2: "bicycle",
    3: "car",
    4: "motorcycle",
    5: "airplane",
    6: "bus",
    7: "train",
    8: "truck",
    9: "boat",
    10: "traffic light",
    11: "fire hydrant",
    13: "stop sign",
    14: "parking meter",
    15: "bench",
    16: "bird",
    17: "cat",
    18: "dog",
    19: "horse",
    20: "sheep",
    21: "cow",
    22: "elephant",
    23: "bear",
    24: "zebra",
    25: "giraffe",
    27: "backpack",
    28: "umbrella",
    31: "handbag",
    32: "tie",
    33: "suitcase",
    34: "frisbee",
    35: "skis",
    36: "snowboard",
    37: "sports ball",
    38: "kite",
    39: "baseball bat",
    40: "baseball glove",
    41: "skateboard",
    42: "surfboard",
    43: "tennis racket",
    44: "bottle",
    46: "wine glass",
    47: "cup",
    48: "fork",
    49: "knife",
    50: "spoon",
    51: "bowl",
    52: "banana",
    53: "apple",
    54: "sandwich",
    55: "orange",
    56: "broccoli",
    57: "carrot",
    58: "hot dog",
    59: "pizza",
    60: "donut",
    61: "cake",
    62: "chair",
    63: "couch",
    64: "potted plant",
    65: "bed",
    67: "dining table",
    70: "toilet",
    72: "tv",
    73: "laptop",
    74: "mouse",
    75: "remote",
    76: "keyboard",
    77: "cell phone",
    78: "microwave",
    79: "oven",
    80: "toaster",
    81: "sink",
    82: "refrigerator",
    84: "book",
    85: "clock",
    86: "vase",
    87: "scissors",
    88: "teddy bear",
    89: "hair drier",
    90: "toothbrush",
}

# COCO category id -> contiguous label in [0, 80), assigned in the dict's
# insertion order (ascending id). CocoDetection passes this map to
# ConvertCocoPolysToMask when remap_mscoco_category is enabled.
mscoco_category2label = {k: i for i, k in enumerate(mscoco_category2name.keys())}

# Inverse mapping: contiguous label -> COCO category id.
mscoco_label2category = {v: k for k, v in mscoco_category2label.items()}
|
|
|