File size: 1,211 Bytes
2fdce3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from pathlib import Path
import shutil
import tempfile

from torchvision.datasets import MNIST

TEMPDIR = tempfile.gettempdir()


def setup_cached_mnist(max_attempts: int = 5) -> None:
    """Make sure the MNIST dataset is downloaded and intact under ``TEMPDIR``.

    Uses torchvision's ``MNIST(download=True)``. Per the original code, this
    skips the download when the data is already present and checks it
    against the expected checksum. Before downloading, an extra mirror is
    prepended to ``MNIST.mirrors``.

    Args:
        max_attempts: How many times to try the download before giving up.

    Raises:
        SystemExit: With code ``-1`` if every attempt fails. ``SystemExit(-1)``
            is the same exception the original ``exit(-1)`` raised.
    """
    # Monkey patch the resource URLs to work around a possible blacklist.
    # Do it once, outside the retry loop, and only if the mirror is missing.
    # Otherwise each retry (or each call of this function) would prepend the
    # same URL again to the shared class-level list.
    mirror = "https://github.com/blefaudeux/mnist_dataset/raw/main/"
    if mirror not in MNIST.mirrors:
        MNIST.mirrors = [mirror] + MNIST.mirrors

    # torchvision stores MNIST under <root>/MNIST
    mnist_root = Path(TEMPDIR) / "MNIST"

    for _ in range(max_attempts):
        try:
            # This will automatically skip the download if the dataset is already there, and check the checksum
            MNIST(transform=None, download=True, root=TEMPDIR)
        except RuntimeError as e:
            logging.warning(e)
            # Corrupted data, erase and restart. The directory may not exist
            # if the download failed before creating it; rmtree would then
            # raise FileNotFoundError and abort the retry loop.
            if mnist_root.exists():
                shutil.rmtree(mnist_root)
        else:
            logging.info("Dataset downloaded")
            return

    logging.error("Could not download MNIST dataset")
    raise SystemExit(-1)