Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from multiprocessing.reduction import ForkingPickler | |
| from typing import Union | |
| import numpy as np | |
| import torch | |
| from mmengine.structures import BaseDataElement | |
| from .utils import LABEL_TYPE, SCORE_TYPE, format_label, format_score | |
class DataSample(BaseDataElement):
    """A general data structure interface.

    It's used as the interface between different components.

    The following fields are convention names in MMPretrain, and we will set or
    get these fields in data transforms, models, and metrics if needed. You can
    also set any new fields for your need.

    Meta fields:
        img_shape (Tuple): The shape of the corresponding input image.
        ori_shape (Tuple): The original shape of the corresponding image.
        sample_idx (int): The index of the sample in the dataset.
        num_classes (int): The number of all categories.

    Data fields:
        gt_label (tensor): The ground truth label.
        gt_score (tensor): The ground truth score.
        pred_label (tensor): The predicted label.
        pred_score (tensor): The predicted score.
        mask (tensor): The mask used in masked image modeling.

    Examples:
        >>> import torch
        >>> from mmpretrain.structures import DataSample
        >>>
        >>> img_meta = dict(img_shape=(960, 720), num_classes=5)
        >>> data_sample = DataSample(metainfo=img_meta)
        >>> data_sample.set_gt_label(3)
        >>> print(data_sample)
        <DataSample(
        META INFORMATION
            num_classes: 5
            img_shape: (960, 720)
        DATA FIELDS
            gt_label: tensor([3])
        ) at 0x7ff64c1c1d30>
        >>>
        >>> # For multi-label data
        >>> data_sample = DataSample().set_gt_label([0, 1, 4])
        >>> print(data_sample)
        <DataSample(
        DATA FIELDS
            gt_label: tensor([0, 1, 4])
        ) at 0x7ff5b490e100>
        >>>
        >>> # Set one-hot format score
        >>> data_sample = DataSample().set_pred_score([0.1, 0.1, 0.6, 0.1])
        >>> print(data_sample)
        <DataSample(
        META INFORMATION
            num_classes: 4
        DATA FIELDS
            pred_score: tensor([0.1000, 0.1000, 0.6000, 0.1000])
        ) at 0x7ff5b48ef6a0>
        >>>
        >>> # Set custom field
        >>> data_sample = DataSample()
        >>> data_sample.my_field = [1, 2, 3]
        >>> print(data_sample)
        <DataSample(
        DATA FIELDS
            my_field: [1, 2, 3]
        ) at 0x7f8e9603d3a0>
        >>> print(data_sample.my_field)
        [1, 2, 3]
    """

    def set_gt_label(self, value: LABEL_TYPE) -> 'DataSample':
        """Set ``gt_label``.

        Args:
            value: The label(s); passed through ``format_label`` before
                being stored.

        Returns:
            DataSample: ``self``, so that calls can be chained.
        """
        self.set_field(format_label(value), 'gt_label', dtype=torch.Tensor)
        return self

    def set_gt_score(self, value: SCORE_TYPE) -> 'DataSample':
        """Set ``gt_score``.

        Args:
            value: The score; passed through ``format_score`` before being
                stored.

        Returns:
            DataSample: ``self``, so that calls can be chained.

        Raises:
            AssertionError: If ``num_classes`` is already set and differs
                from the length of the score.
        """
        return self._set_score('gt_score', value)

    def set_pred_label(self, value: LABEL_TYPE) -> 'DataSample':
        """Set ``pred_label``.

        Args:
            value: The label(s); passed through ``format_label`` before
                being stored.

        Returns:
            DataSample: ``self``, so that calls can be chained.
        """
        self.set_field(format_label(value), 'pred_label', dtype=torch.Tensor)
        return self

    def set_pred_score(self, value: SCORE_TYPE) -> 'DataSample':
        """Set ``pred_score``.

        Args:
            value: The score; passed through ``format_score`` before being
                stored.

        Returns:
            DataSample: ``self``, so that calls can be chained.

        Raises:
            AssertionError: If ``num_classes`` is already set and differs
                from the length of the score.
        """
        return self._set_score('pred_score', value)

    def _set_score(self, name: str, value: SCORE_TYPE) -> 'DataSample':
        """Store a formatted score under ``name`` and sync ``num_classes``."""
        score = format_score(value)
        has_num_classes = hasattr(self, 'num_classes')
        if has_num_classes:
            # Check before storing, so that a rejected score does not leave
            # the sample holding a field inconsistent with ``num_classes``.
            assert len(score) == self.num_classes, \
                f'The length of score {len(score)} should be '\
                f'equal to the num_classes {self.num_classes}.'
        self.set_field(score, name, dtype=torch.Tensor)
        if not has_num_classes:
            # The first score seen defines the number of classes.
            self.set_field(
                name='num_classes', value=len(score), field_type='metainfo')
        return self

    def set_mask(self, value: Union[torch.Tensor, np.ndarray]) -> 'DataSample':
        """Set ``mask``.

        Args:
            value: The mask. An ``np.ndarray`` is converted with
                ``torch.from_numpy``; a ``torch.Tensor`` is stored as is.

        Returns:
            DataSample: ``self``, so that calls can be chained.

        Raises:
            TypeError: If ``value`` is neither an ``np.ndarray`` nor a
                ``torch.Tensor``.
        """
        if isinstance(value, np.ndarray):
            value = torch.from_numpy(value)
        elif not isinstance(value, torch.Tensor):
            raise TypeError(f'Invalid mask type {type(value)}')
        self.set_field(value, 'mask', dtype=torch.Tensor)
        return self

    def __repr__(self) -> str:
        """Represent the object."""

        def dump_items(items, prefix=''):
            # One "key: value" line per item, each starting with ``prefix``.
            return '\n'.join(f'{prefix}{k}: {v}' for k, v in items)

        repr_ = ''
        # A section is emitted only when it has at least one field.
        if len(self._metainfo_fields) > 0:
            repr_ += '\n\nMETA INFORMATION\n'
            repr_ += dump_items(self.metainfo_items(), prefix=' ' * 4)
        if len(self._data_fields) > 0:
            repr_ += '\n\nDATA FIELDS\n'
            repr_ += dump_items(self.items(), prefix=' ' * 4)

        repr_ = f'<{self.__class__.__name__}({repr_}\n\n) at {hex(id(self))}>'
        return repr_
def _reduce_datasample(data_sample):
    """Reduce a ``DataSample`` for pickling.

    Every ``torch.Tensor`` attribute is replaced by the array returned from
    ``Tensor.numpy()`` in the reduced state, and the converted keys are
    recorded so that ``_rebuild_datasample`` can turn them back into tensors.

    Args:
        data_sample: The object whose ``__dict__`` is reduced.

    Returns:
        tuple: ``(_rebuild_datasample, (attr_dict, convert_keys))``, the
        callable-and-arguments pair expected from a reduce function.
    """
    source = data_sample.__dict__
    convert_keys = [
        k for k, v in source.items() if isinstance(v, torch.Tensor)
    ]
    # Convert on a shallow copy: writing into ``data_sample.__dict__`` itself
    # would swap the live object's tensors for arrays as a pickling side
    # effect.
    attr_dict = dict(source)
    for k in convert_keys:
        attr_dict[k] = source[k].numpy()
    return _rebuild_datasample, (attr_dict, convert_keys)
def _rebuild_datasample(attr_dict, convert_keys):
    """Rebuild a ``DataSample`` from its reduced state.

    Args:
        attr_dict: The attribute dictionary to install on the new sample.
            It is updated in place.
        convert_keys: Keys of ``attr_dict`` whose values are passed through
            ``torch.from_numpy`` before the dictionary is installed.

    Returns:
        DataSample: A new sample whose ``__dict__`` is ``attr_dict``.
    """
    for key in convert_keys:
        array = attr_dict[key]
        attr_dict[key] = torch.from_numpy(array)
    restored = DataSample()
    restored.__dict__ = attr_dict
    return restored
# Due to the multi-processing strategy of PyTorch, DataSample may consume many
# file descriptors because it contains multiple tensors. Here we override the
# reduce function of DataSample in ForkingPickler and convert these tensors to
# np.ndarray during pickling. This may slightly affect the performance of the
# dataloader.
ForkingPickler.register(DataSample, _reduce_datasample)