Spaces:
Build error
Build error
| import os.path as osp | |
| import xml.etree.ElementTree as ET | |
| import mmcv | |
| import numpy as np | |
| from PIL import Image | |
| from .builder import DATASETS | |
| from .custom import CustomDataset | |
class XMLDataset(CustomDataset):
    """XML dataset for detection.

    Image ids are read from a list file; for every id an XML annotation file
    ``<img_prefix>/<ann_subdir>/<img_id>.xml`` is parsed and the image is
    referenced as ``<img_subdir>/<img_id>.jpg``.

    Args:
        min_size (int | float, optional): The minimum size of bounding
            boxes in the images. If the width or height of a bounding box is
            less than ``min_size``, it is added to the ignored field.
        img_subdir (str): Sub-directory of ``img_prefix`` that holds the
            images. Default: 'JPEGImages'.
        ann_subdir (str): Sub-directory of ``img_prefix`` that holds the XML
            annotation files. Default: 'Annotations'.
        **kwargs: Forwarded unchanged to ``CustomDataset.__init__``.
    """

    def __init__(self,
                 min_size=None,
                 img_subdir='JPEGImages',
                 ann_subdir='Annotations',
                 **kwargs):
        assert self.CLASSES or kwargs.get(
            'classes', None), 'CLASSES in `XMLDataset` can not be None.'
        # Set before the parent constructor runs, in case it invokes
        # ``load_annotations`` / ``_filter_imgs``, which read these fields.
        self.img_subdir = img_subdir
        self.ann_subdir = ann_subdir
        super(XMLDataset, self).__init__(**kwargs)
        self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
        self.min_size = min_size

    def _load_xml_root(self, img_id):
        """Parse the annotation file of ``img_id`` and return its root.

        Args:
            img_id (str): Image id, i.e. the XML file name without extension.

        Returns:
            xml.etree.ElementTree.Element: Root element of the parsed file.
        """
        xml_path = osp.join(self.img_prefix, self.ann_subdir,
                            f'{img_id}.xml')
        return ET.parse(xml_path).getroot()

    @staticmethod
    def _pack_annotations(bboxes, labels):
        """Convert python lists of boxes/labels to the output arrays.

        Args:
            bboxes (list[list[int]]): Boxes as ``[xmin, ymin, xmax, ymax]``.
            labels (list[int]): One label per box.

        Returns:
            tuple[np.ndarray, np.ndarray]: ``float32`` boxes of shape (n, 4)
                and ``int64`` labels of shape (n, ).
        """
        if not bboxes:
            return (np.zeros((0, 4), dtype=np.float32),
                    np.zeros((0, ), dtype=np.int64))
        # Every coordinate is shifted by one.
        # NOTE(review): presumably converts 1-based XML pixel coordinates to
        # 0-based ones - confirm against the annotation format in use.
        boxes = np.array(bboxes, ndmin=2) - 1
        return boxes.astype(np.float32), np.array(labels).astype(np.int64)

    def load_annotations(self, ann_file):
        """Load image infos for the image ids listed in ``ann_file``.

        Args:
            ann_file (str): Path of the list file. It is read with
                ``mmcv.list_from_file`` and every returned entry is used as
                an image id.

        Returns:
            list[dict]: One dict per image id with the keys ``id``,
                ``filename``, ``width`` and ``height``.
        """
        data_infos = []
        for img_id in mmcv.list_from_file(ann_file):
            filename = f'{self.img_subdir}/{img_id}.jpg'
            size = self._load_xml_root(img_id).find('size')
            if size is not None:
                width = int(size.find('width').text)
                height = int(size.find('height').text)
            else:
                # No <size> element in the XML: fall back to the image file.
                # The context manager closes the file handle right away.
                img_path = osp.join(self.img_prefix, self.img_subdir,
                                    f'{img_id}.jpg')
                with Image.open(img_path) as img:
                    width, height = img.size
            data_infos.append(
                dict(id=img_id, filename=filename, width=width,
                     height=height))
        return data_infos

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without annotation.

        Args:
            min_size (int): Images whose shorter side is below this value
                are dropped. Default: 32.

        Returns:
            list[int]: Indices into ``self.data_infos`` of the kept images.
        """
        valid_inds = []
        for i, img_info in enumerate(self.data_infos):
            if min(img_info['width'], img_info['height']) < min_size:
                continue
            if not self.filter_empty_gt:
                valid_inds.append(i)
                continue
            root = self._load_xml_root(img_info['id'])
            # Membership is tested against ``CLASSES`` rather than
            # ``cat2label`` because the latter is only built after the
            # parent constructor has returned.
            if any(
                    obj.find('name').text in self.CLASSES
                    for obj in root.findall('object')):
                valid_inds.append(i)
        return valid_inds

    def get_ann_info(self, idx):
        """Get annotation from XML file by index.

        Objects whose name is not in ``CLASSES`` are skipped. Objects flagged
        ``difficult`` or smaller than ``self.min_size`` go to the ignore
        fields.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index with the keys
                ``bboxes`` / ``bboxes_ignore`` (float32, shape (n, 4)) and
                ``labels`` / ``labels_ignore`` (int64, shape (n, )).
        """
        root = self._load_xml_root(self.data_infos[idx]['id'])
        bboxes = []
        labels = []
        bboxes_ignore = []
        labels_ignore = []
        for obj in root.findall('object'):
            name = obj.find('name').text
            if name not in self.CLASSES:
                continue
            label = self.cat2label[name]
            difficult = obj.find('difficult')
            difficult = 0 if difficult is None else int(difficult.text)
            bnd_box = obj.find('bndbox')
            # TODO: check whether it is necessary to use int
            # Coordinates may be float type
            bbox = [
                int(float(bnd_box.find(tag).text))
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            ]
            ignore = False
            if self.min_size:
                # Size-based ignoring is not allowed in test mode.
                assert not self.test_mode
                w = bbox[2] - bbox[0]
                h = bbox[3] - bbox[1]
                if w < self.min_size or h < self.min_size:
                    ignore = True
            if difficult or ignore:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)
        bboxes, labels = self._pack_annotations(bboxes, labels)
        bboxes_ignore, labels_ignore = self._pack_annotations(
            bboxes_ignore, labels_ignore)
        return dict(
            bboxes=bboxes,
            labels=labels,
            bboxes_ignore=bboxes_ignore,
            labels_ignore=labels_ignore)

    def get_cat_ids(self, idx):
        """Get category ids in XML file by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: Labels of all objects in the image whose name is in
                ``CLASSES``, in file order (duplicates are kept).
        """
        root = self._load_xml_root(self.data_infos[idx]['id'])
        names = (obj.find('name').text for obj in root.findall('object'))
        return [self.cat2label[name] for name in names
                if name in self.CLASSES]