# Source path: D-FINE/src/data/dataset/coco_dataset.py
# Hosting-page residue kept for provenance:
#   "developer0hye's picture" / "Upload 76 files" / "e85fecb verified"
"""
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import faster_coco_eval
import faster_coco_eval.core.mask as coco_mask
import torch
import torch.utils.data
import torchvision
import os
from PIL import Image
from ...core import register
from .._misc import convert_to_tv_tensor
from ._dataset import DetDataset
# Import-time side effects; they run once when this module is first imported.
# NOTE(review): presumably silences torchvision's beta-transforms warning —
# confirm against the installed torchvision version.
torchvision.disable_beta_transforms_warning()
# NOTE(review): by its name this appears to let faster_coco_eval stand in for
# pycocotools; confirm against the faster_coco_eval documentation.
faster_coco_eval.init_as_pycocotools()
# Lift Pillow's pixel-count limit for every image opened in this process
# (NOTE(review): process-wide setting — confirm this is intended for all callers).
Image.MAX_IMAGE_PIXELS = None
# Only the dataset class is exported by a star-import of this module.
__all__ = ["CocoDetection"]
@register()
class CocoDetection(torchvision.datasets.CocoDetection, DetDataset):
    """Detection dataset layered on ``torchvision.datasets.CocoDetection``.

    A sample is read through the torchvision parent class, converted into a
    dict of tensors by ``ConvertCocoPolysToMask`` and, when a transform
    pipeline was supplied, passed through that pipeline.

    Args:
        img_folder: image root, forwarded to the torchvision parent and used
            to build ``target["image_path"]``.
        ann_file: annotation file, forwarded to the torchvision parent.
        transforms: callable invoked as ``transforms(img, target, dataset)``
            and expected to return a 3-tuple; ``None`` disables it.
        return_masks: forwarded to ``ConvertCocoPolysToMask``.
        remap_mscoco_category: when true, category ids are translated to
            labels through the module-level ``mscoco_category2label`` table.
    """

    # NOTE(review): these two attributes look like hooks for the project's
    # registry/config machinery (`register`); their exact semantics are not
    # visible in this file.
    __inject__ = ["transforms"]
    __share__ = ["remap_mscoco_category"]

    def __init__(
        self,
        img_folder,
        ann_file,
        transforms,
        return_masks=False,
        remap_mscoco_category=False,
    ):
        # Explicit two-argument super is kept on purpose: the class name is
        # rebound by the @register() decorator.
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)
        self.img_folder = img_folder
        self.ann_file = ann_file
        self.return_masks = return_masks
        self.remap_mscoco_category = remap_mscoco_category

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx``, transformed when configured."""
        img, target = self.load_item(idx)
        if self._transforms is None:
            return img, target
        # The third element returned by the transform pipeline is discarded.
        img, target, _ = self._transforms(img, target, self)
        return img, target

    def load_item(self, idx):
        """Load sample ``idx`` and convert its annotations, without transforms."""
        image, raw_annotations = super(CocoDetection, self).__getitem__(idx)

        image_id = self.ids[idx]
        file_name = self.coco.loadImgs(image_id)[0]["file_name"]
        target = {
            "image_id": image_id,
            "image_path": os.path.join(self.img_folder, file_name),
            "annotations": raw_annotations,
        }

        # Only pass the remapping table when remapping was requested.
        prepare_kwargs = {}
        if self.remap_mscoco_category:
            prepare_kwargs["category2label"] = mscoco_category2label
        image, target = self.prepare(image, target, **prepare_kwargs)

        target["idx"] = torch.tensor([idx])

        if "boxes" in target:
            # image.size is (width, height); reversed to (height, width).
            height_width = image.size[::-1]
            target["boxes"] = convert_to_tv_tensor(
                target["boxes"], key="boxes", spatial_size=height_width
            )

        if "masks" in target:
            target["masks"] = convert_to_tv_tensor(target["masks"], key="masks")

        return image, target

    def extra_repr(self) -> str:
        """Describe the dataset configuration as a multi-line string."""
        pieces = [
            f" img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n",
            f" return_masks: {self.return_masks}\n",
        ]
        transforms = getattr(self, "_transforms", None)
        if transforms is not None:
            pieces.append(f" transforms:\n {repr(transforms)}")
        preset = getattr(self, "_preset", None)
        if preset is not None:
            pieces.append(f" preset:\n {repr(preset)}")
        return "".join(pieces)

    @property
    def categories(self):
        """The ``"categories"`` list stored in the loaded annotation dataset."""
        return self.coco.dataset["categories"]

    @property
    def category2name(self):
        """Map each category id to its name."""
        return {entry["id"]: entry["name"] for entry in self.categories}

    @property
    def category2label(self):
        """Map each category id to its position in ``categories``."""
        return {entry["id"]: position for position, entry in enumerate(self.categories)}

    @property
    def label2category(self):
        """Map each position in ``categories`` back to the category id."""
        return {position: entry["id"] for position, entry in enumerate(self.categories)}
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterise per-object segmentations into one stacked mask tensor.

    Each entry of ``segmentations`` is encoded with ``coco_mask.frPyObjects``
    and decoded with ``coco_mask.decode``; the decoded array is collapsed
    over its last axis with ``any`` so every object yields a single 2-D mask.

    Args:
        segmentations: iterable with one segmentation per object.
        height: image height handed to the encoder.
        width: image width handed to the encoder.

    Returns:
        The per-object masks stacked along a new leading dimension, or an
        empty ``(0, height, width)`` uint8 tensor when there are no objects.
    """
    per_object = []
    for polys in segmentations:
        decoded = coco_mask.decode(coco_mask.frPyObjects(polys, height, width))
        # A 2-D result gets a trailing axis so the reduction below is uniform.
        if len(decoded.shape) < 3:
            decoded = decoded[..., None]
        per_object.append(torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2))

    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
class ConvertCocoPolysToMask(object):
    """Convert a raw COCO-style annotation list into a dict of tensors.

    Args:
        return_masks: when true, the ``"segmentation"`` of every kept
            annotation is rasterised with ``convert_coco_poly_to_mask`` and
            returned under ``target["masks"]``.
    """

    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __call__(self, image: "Image.Image", target, **kwargs):
        """Build the tensor target for one image.

        Args:
            image: object exposing ``size`` as ``(width, height)``.
            target: dict with ``"image_id"``, ``"image_path"`` and
                ``"annotations"`` (a list of per-object dicts carrying at
                least ``"bbox"``, ``"category_id"`` and ``"area"``).
            **kwargs: optional ``category2label`` mapping used to translate
                each ``"category_id"`` into a label.

        Returns:
            ``(image, new_target)``. ``image`` is returned untouched.
            ``new_target`` holds ``boxes`` (float32, ``(N, 4)``), ``labels``
            (int64), ``image_id``, ``image_path``, ``area``, ``iscrowd``
            (int64), ``orig_size`` (``[width, height]``) and, when
            available, ``masks`` and ``keypoints``.

        Raises:
            KeyError: if a required annotation key is missing, or a
                category id is absent from ``category2label``.
        """
        w, h = image.size

        image_id = torch.tensor([target["image_id"]])
        image_path = target["image_path"]

        # Annotations flagged as crowd are discarded up front.
        anno = [
            obj
            for obj in target["annotations"]
            if "iscrowd" not in obj or obj["iscrowd"] == 0
        ]

        # reshape(-1, 4) keeps a well-formed (0, 4) tensor when nothing is left.
        boxes = torch.as_tensor(
            [obj["bbox"] for obj in anno], dtype=torch.float32
        ).reshape(-1, 4)
        # Add the first two columns to the last two (x, y, w, h -> x1, y1, x2, y2),
        # then clamp the corners to the image bounds.
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        category2label = kwargs.get("category2label", None)
        if category2label is not None:
            labels = [category2label[obj["category_id"]] for obj in anno]
        else:
            labels = [obj["category_id"] for obj in anno]
        labels = torch.tensor(labels, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        # Only the first annotation is inspected to decide whether keypoints exist.
        if anno and "keypoints" in anno[0]:
            keypoints = torch.as_tensor(
                [obj["keypoints"] for obj in anno], dtype=torch.float32
            )
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        # Keep only boxes that still have positive width and height after clamping.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

        new_target = {}
        new_target["boxes"] = boxes[keep]
        new_target["labels"] = labels[keep]
        if self.return_masks:
            new_target["masks"] = masks[keep]
        new_target["image_id"] = image_id
        new_target["image_path"] = image_path
        if keypoints is not None:
            new_target["keypoints"] = keypoints[keep]

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        # An explicit dtype keeps iscrowd int64 even when `anno` is empty;
        # torch.tensor([]) would otherwise default to float32 for such images.
        iscrowd = torch.tensor(
            [obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno],
            dtype=torch.int64,
        )
        new_target["area"] = area[keep]
        new_target["iscrowd"] = iscrowd[keep]

        new_target["orig_size"] = torch.as_tensor([int(w), int(h)])

        return image, new_target
# MS COCO category id -> category name. The ids are not contiguous
# (e.g. 12, 26, 29 and 30 are absent), hence the label tables derived below.
mscoco_category2name = {
    1: "person",
    2: "bicycle",
    3: "car",
    4: "motorcycle",
    5: "airplane",
    6: "bus",
    7: "train",
    8: "truck",
    9: "boat",
    10: "traffic light",
    11: "fire hydrant",
    13: "stop sign",
    14: "parking meter",
    15: "bench",
    16: "bird",
    17: "cat",
    18: "dog",
    19: "horse",
    20: "sheep",
    21: "cow",
    22: "elephant",
    23: "bear",
    24: "zebra",
    25: "giraffe",
    27: "backpack",
    28: "umbrella",
    31: "handbag",
    32: "tie",
    33: "suitcase",
    34: "frisbee",
    35: "skis",
    36: "snowboard",
    37: "sports ball",
    38: "kite",
    39: "baseball bat",
    40: "baseball glove",
    41: "skateboard",
    42: "surfboard",
    43: "tennis racket",
    44: "bottle",
    46: "wine glass",
    47: "cup",
    48: "fork",
    49: "knife",
    50: "spoon",
    51: "bowl",
    52: "banana",
    53: "apple",
    54: "sandwich",
    55: "orange",
    56: "broccoli",
    57: "carrot",
    58: "hot dog",
    59: "pizza",
    60: "donut",
    61: "cake",
    62: "chair",
    63: "couch",
    64: "potted plant",
    65: "bed",
    67: "dining table",
    70: "toilet",
    72: "tv",
    73: "laptop",
    74: "mouse",
    75: "remote",
    76: "keyboard",
    77: "cell phone",
    78: "microwave",
    79: "oven",
    80: "toaster",
    81: "sink",
    82: "refrigerator",
    84: "book",
    85: "clock",
    86: "vase",
    87: "scissors",
    88: "teddy bear",
    89: "hair drier",
    90: "toothbrush",
}

# Category id -> contiguous zero-based label, in the insertion order above.
mscoco_category2label = {
    category_id: label for label, category_id in enumerate(mscoco_category2name)
}

# Inverse table: contiguous label -> original category id.
mscoco_label2category = {
    label: category_id for category_id, label in mscoco_category2label.items()
}