|
import copy |
|
import os |
|
|
|
import torch |
|
import torch.utils.data |
|
import torchvision |
|
from PIL import Image |
|
from pycocotools import mask as coco_mask |
|
from transforms import Compose |
|
|
|
|
|
class FilterAndRemapCocoCategories:
    """Transform that keeps only annotations belonging to chosen categories.

    With ``remap`` enabled, the ``category_id`` of every kept annotation is
    replaced by the position of its original id inside ``categories``.

    Args:
        categories: Sequence of category ids to keep. It must support ``in``
            and ``.index()``.
        remap: Whether to rewrite ``category_id`` to its index in
            ``categories``.
    """

    def __init__(self, categories, remap=True):
        self.categories = categories
        self.remap = remap

    def __call__(self, image, anno):
        """Filter (and optionally remap) ``anno``; ``image`` passes through.

        Returns:
            A ``(image, annotations)`` tuple. Without remapping, the returned
            list holds the caller's own annotation dicts; with remapping, it
            holds deep copies so the caller's dicts are never modified.
        """
        kept = [entry for entry in anno if entry["category_id"] in self.categories]
        if self.remap:
            # Work on a deep copy so the ids stored in the caller's
            # annotations stay untouched.
            kept = copy.deepcopy(kept)
            for entry in kept:
                entry["category_id"] = self.categories.index(entry["category_id"])
        return image, kept
|
|
|
|
|
def convert_coco_poly_to_mask(segmentations, height, width):
    """Turn per-object COCO segmentations into one stacked mask tensor.

    Args:
        segmentations: Iterable with one segmentation entry per object, each
            passed as-is to ``pycocotools.mask.frPyObjects``.
        height: Height of the image the segmentations refer to.
        width: Width of the image the segmentations refer to.

    Returns:
        The per-object masks stacked along a new first dimension. For empty
        input this is a ``uint8`` tensor of shape ``(0, height, width)``;
        otherwise presumably ``(num_objects, height, width)`` (relies on the
        array layout returned by ``pycocotools.mask.decode``; confirm).
    """
    per_object = []
    for polys in segmentations:
        decoded = coco_mask.decode(coco_mask.frPyObjects(polys, height, width))
        if len(decoded.shape) < 3:
            # Give 2-D results a trailing axis so the reduction below can
            # always collapse dim 2.
            decoded = decoded[..., None]
        # Collapse the last axis: a pixel is set if any part covers it.
        merged = torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2)
        per_object.append(merged)

    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
|
|
|
|
|
class ConvertCocoPolysToMask:
    """Transform that rasterises COCO annotations into one label image."""

    def __call__(self, image, anno):
        """Build a single-channel label map for ``image`` from ``anno``.

        Args:
            image: Object exposing ``size`` as ``(width, height)``
                (presumably a PIL image; confirm against the caller).
            anno: List of annotation dicts with ``"segmentation"`` and
                ``"category_id"`` keys.

        Returns:
            ``(image, target)`` where ``target`` is a PIL image built from the
            label map. With no annotations the map is all zeros.
        """
        width, height = image.size
        polygons = [entry["segmentation"] for entry in anno]
        labels = [entry["category_id"] for entry in anno]

        if not polygons:
            label_map = torch.zeros((height, width), dtype=torch.uint8)
        else:
            instance_masks = convert_coco_poly_to_mask(polygons, height, width)
            label_tensor = torch.as_tensor(labels, dtype=instance_masks.dtype)
            # Merge all instance masks into one map: scale each mask by its
            # category id and keep the per-pixel maximum.
            weighted = instance_masks * label_tensor[:, None, None]
            label_map = weighted.max(dim=0)[0]
            # Pixels where the masks sum to more than 1 (several objects
            # overlap, given 0/1 masks) are overwritten with 255,
            # presumably the "ignore" label of the training loss (confirm).
            overlap = instance_masks.sum(0) > 1
            label_map[overlap] = 255

        return image, Image.fromarray(label_map.numpy())
|
|
|
|
|
def _coco_remove_images_without_annotations(dataset, cat_list=None):
    """Restrict a COCO detection dataset to images with enough annotated area.

    An image is kept when the summed ``area`` of its annotations exceeds 1000.
    If ``cat_list`` is given and non-empty, only annotations whose
    ``category_id`` is in it are counted.

    Args:
        dataset: A ``torchvision.datasets.CocoDetection`` instance.
        cat_list: Optional collection of (hashable) category ids to count.
            ``None`` or an empty collection disables the category filter.

    Returns:
        A ``torch.utils.data.Subset`` of ``dataset`` containing the positions
        (indices into ``dataset.ids``) of the images that were kept.

    Raises:
        TypeError: If ``dataset`` is not a ``CocoDetection``.
    """

    def _has_valid_annotation(anno):
        # No annotations at all: nothing to learn from this image.
        if not anno:
            return False
        # Require more than 1000 units of total annotated area.
        return sum(obj["area"] for obj in anno) > 1000

    if not isinstance(dataset, torchvision.datasets.CocoDetection):
        raise TypeError(
            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
        )

    # Build the membership set once: the test below runs for every annotation
    # of every image, and a list would be scanned linearly each time.
    wanted = set(cat_list) if cat_list else None

    ids = []
    for ds_idx, img_id in enumerate(dataset.ids):
        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.coco.loadAnns(ann_ids)
        if wanted is not None:
            anno = [obj for obj in anno if obj["category_id"] in wanted]
        if _has_valid_annotation(anno):
            ids.append(ds_idx)

    return torch.utils.data.Subset(dataset, ids)
|
|
|
|
|
def get_coco(root, image_set, transforms):
    """Create the COCO dataset used for segmentation.

    Args:
        root: Directory containing the ``train2017``/``val2017`` image folders
            and the ``annotations`` folder.
        image_set: Split name, ``"train"`` or ``"val"``.
        transforms: Transform applied after the annotations have been filtered
            and converted into a label image.

    Returns:
        A ``torchvision.datasets.CocoDetection`` for ``"val"``. For
        ``"train"``, whatever ``_coco_remove_images_without_annotations``
        returns for that dataset.

    Raises:
        KeyError: If ``image_set`` is not one of the known splits.
    """
    split_locations = {
        "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
        "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
    }
    # Category ids to keep; the position in this list is the label each id is
    # remapped to. NOTE(review): presumably 0 stands for background and the
    # rest are the COCO ids of the Pascal VOC classes; confirm.
    category_ids = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

    pipeline = Compose(
        [
            FilterAndRemapCocoCategories(category_ids, remap=True),
            ConvertCocoPolysToMask(),
            transforms,
        ]
    )

    image_dir_name, annotation_rel_path = split_locations[image_set]
    dataset = torchvision.datasets.CocoDetection(
        os.path.join(root, image_dir_name),
        os.path.join(root, annotation_rel_path),
        transforms=pipeline,
    )

    # Only the training split is filtered by annotated area.
    if image_set == "train":
        return _coco_remove_images_without_annotations(dataset, category_ids)
    return dataset
|
|