# Copyright (c) OpenMMLab. All rights reserved.
import codecs
from typing import List, Optional
from urllib.parse import urljoin

import mmengine.dist as dist
import numpy as np
import torch
from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path
from mmengine.logging import MMLogger

from mmpretrain.registry import DATASETS

from .base_dataset import BaseDataset
from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES
from .utils import (download_and_extract_archive, open_maybe_compressed_file,
                    rm_suffix)


@DATASETS.register_module()
class MNIST(BaseDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py

    Args:
        data_root (str): The root directory of the MNIST Dataset.
        split (str, optional): The dataset split, supports "train" and
            "test". Defaults to "train".
        metainfo (dict, optional): Meta information for the dataset, such as
            categories information. Defaults to None.
        download (bool): Whether to download the dataset if it does not
            exist. Defaults to True.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
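
    Examples:
        An illustrative usage sketch (assumes the archives already sit under
        ``data/mnist``, or that ``download=True`` can fetch them):

        >>> from mmpretrain.datasets import MNIST
        >>> dataset = MNIST(data_root='data/mnist', split='test')
        >>> len(dataset)
        10000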
| """ # noqa: E501 | |

    url_prefix = 'http://yann.lecun.com/exdb/mnist/'
    # train images and labels
    train_list = [
        ['train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'],
        ['train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'],
    ]
    # test images and labels
    test_list = [
        ['t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'],
        ['t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'],
    ]
    METAINFO = {'classes': MNIST_CATEGORITES}

    def __init__(self,
                 data_root: str = '',
                 split: str = 'train',
                 metainfo: Optional[dict] = None,
                 download: bool = True,
                 data_prefix: str = '',
                 test_mode: bool = False,
                 **kwargs):
        splits = ['train', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but got '{split}'"
        self.split = split

        # To handle the BC-breaking
        if split == 'train' and test_mode:
            logger = MMLogger.get_current_instance()
            logger.warning('split="train" but test_mode=True. '
                           'The training set will be used.')

        if not data_root and not data_prefix:
            raise RuntimeError('Please set ``data_root`` to '
                               'specify the dataset path.')

        self.download = download
        super().__init__(
            # The MNIST dataset doesn't need an annotation file
            ann_file='',
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=dict(root=data_prefix),
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""
        root = self.data_prefix['root']
        backend = get_file_backend(root, enable_singleton=True)

        # Only the main process prepares the data; the other ranks wait at
        # the barrier below until the files exist.
        if dist.is_main_process() and not self._check_exists():
            if not isinstance(backend, LocalBackend):
                raise RuntimeError(f'The dataset on {root} is not integrated, '
                                   'please manually handle it.')

            if self.download:
                self._download()
            else:
                raise RuntimeError(
                    f'Cannot find {self.__class__.__name__} dataset in '
                    f"{self.data_prefix['root']}, you can specify "
                    '`download=True` to download automatically.')

        dist.barrier()
        assert self._check_exists(), \
            'Download failed or shared storage is unavailable. Please ' \
            f'download the dataset manually through {self.url_prefix}.'

        if not self.test_mode:
            file_list = self.train_list
        else:
            file_list = self.test_list

        # load data from SN3 files
        imgs = read_image_file(join_path(root, rm_suffix(file_list[0][0])))
        gt_labels = read_label_file(
            join_path(root, rm_suffix(file_list[1][0])))

        data_infos = []
        for img, gt_label in zip(imgs, gt_labels):
            gt_label = np.array(gt_label, dtype=np.int64)
            # each image is a (28, 28) uint8 array
            info = {'img': img.numpy(), 'gt_label': gt_label}
            data_infos.append(info)
        return data_infos

    def _check_exists(self):
        """Check the existence of the extracted data files."""
        root = self.data_prefix['root']

        for filename, _ in (self.train_list + self.test_list):
            # get extracted filename of data
            extract_filename = rm_suffix(filename)
            fpath = join_path(root, extract_filename)
            if not exists(fpath):
                return False
        return True

    def _download(self):
        """Download and extract data files."""
        root = self.data_prefix['root']

        for filename, md5 in (self.train_list + self.test_list):
            url = urljoin(self.url_prefix, filename)
            download_and_extract_archive(
                url, download_root=root, filename=filename, md5=md5)

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [f"Prefix of data: \t{self.data_prefix['root']}"]
        return body


@DATASETS.register_module()
class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_
    Dataset.

    Args:
        data_root (str): The root directory of the Fashion-MNIST Dataset.
        split (str, optional): The dataset split, supports "train" and
            "test". Defaults to "train".
        metainfo (dict, optional): Meta information for the dataset, such as
            categories information. Defaults to None.
        download (bool): Whether to download the dataset if it does not
            exist. Defaults to True.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
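
    Examples:
        An illustrative usage sketch (assumes ``download=True`` can fetch the
        archives into ``data/fashion_mnist``):

        >>> from mmpretrain.datasets import FashionMNIST
        >>> dataset = FashionMNIST(data_root='data/fashion_mnist')
        >>> len(dataset)
        60000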
| """ | |
| url_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' | |
| # train images and labels | |
| train_list = [ | |
| ['train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'], | |
| ['train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'], | |
| ] | |
| # test images and labels | |
| test_list = [ | |
| ['t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'], | |
| ['t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310'], | |
| ] | |
| METAINFO = {'classes': FASHIONMNIST_CATEGORITES} | |


def get_int(b: bytes) -> int:
    """Convert big-endian bytes to int."""
    return int(codecs.encode(b, 'hex'), 16)


def read_sn3_pascalvincent_tensor(path: str,
                                  strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file
    'libidx/idx-io.lsh').

    Argument may be a filename, compressed filename, or file object.
    """
    # typemap: type code in the magic number -> (torch dtype, big-endian
    # numpy dtype used for parsing, native numpy dtype for the output)
    if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
        read_sn3_pascalvincent_tensor.typemap = {
            8: (torch.uint8, np.uint8, np.uint8),
            9: (torch.int8, np.int8, np.int8),
            11: (torch.int16, np.dtype('>i2'), 'i2'),
            12: (torch.int32, np.dtype('>i4'), 'i4'),
            13: (torch.float32, np.dtype('>f4'), 'f4'),
            14: (torch.float64, np.dtype('>f8'), 'f8')
        }

    # read
    with open_maybe_compressed_file(path) as f:
        data = f.read()

    # parse the 4-byte magic number: the low byte is the number of
    # dimensions, the next byte is the element type code
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert 1 <= nd <= 3
    assert 8 <= ty <= 14
    m = read_sn3_pascalvincent_tensor.typemap[ty]
    # each dimension size is a 4-byte big-endian integer after the magic
    s = [get_int(data[4 * (i + 1):4 * (i + 2)]) for i in range(nd)]
    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    # convert to native byte order before handing the buffer to torch
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
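
# For reference, the SN3/IDX layout parsed above can be exercised with an
# in-memory buffer, since ``open_maybe_compressed_file`` passes file objects
# through unchanged. An illustrative sketch (``io`` and ``struct`` are used
# only for this example):
#
#   >>> import io, struct
#   >>> header = struct.pack('>ii', 8 * 256 + 1, 3)  # type=uint8, nd=1, dim=3
#   >>> read_sn3_pascalvincent_tensor(io.BytesIO(header + bytes([7, 8, 9])))
#   tensor([7, 8, 9], dtype=torch.uint8)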


def read_label_file(path: str) -> torch.Tensor:
    """Read labels from SN3 label file."""
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert x.dtype == torch.uint8
    assert x.ndimension() == 1
    return x.long()


def read_image_file(path: str) -> torch.Tensor:
    """Read images from SN3 image file."""
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert x.dtype == torch.uint8
    assert x.ndimension() == 3
    return x
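
# A minimal usage sketch for the readers above (assumes the MNIST archives
# have already been downloaded and extracted under ``data/mnist``):
#
#   imgs = read_image_file('data/mnist/train-images-idx3-ubyte')
#   labels = read_label_file('data/mnist/train-labels-idx1-ubyte')
#   imgs.shape    # torch.Size([60000, 28, 28])
#   labels.shape  # torch.Size([60000])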