Spaces:
Running
Running
Add support for MNIST
Browse files- dataset_tool.py +44 -11
dataset_tool.py
CHANGED
|
@@ -13,6 +13,7 @@ import os
|
|
| 13 |
import pickle
|
| 14 |
import sys
|
| 15 |
import tarfile
|
|
|
|
| 16 |
import zipfile
|
| 17 |
from pathlib import Path
|
| 18 |
from typing import Callable, Optional, Tuple, Union
|
|
@@ -165,6 +166,36 @@ def open_cifar10(tarball: str, *, max_images: Optional[int]):
|
|
| 165 |
|
| 166 |
#----------------------------------------------------------------------------
|
| 167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
def make_transform(
|
| 169 |
transform: Optional[str],
|
| 170 |
output_width: Optional[int],
|
|
@@ -225,10 +256,11 @@ def open_dataset(source, *, max_images: Optional[int]):
|
|
| 225 |
else:
|
| 226 |
return open_image_folder(source, max_images=max_images)
|
| 227 |
elif os.path.isfile(source):
|
| 228 |
-
if
|
| 229 |
return open_cifar10(source, max_images=max_images)
|
| 230 |
-
|
| 231 |
-
|
|
|
|
| 232 |
return open_image_zip(source, max_images=max_images)
|
| 233 |
else:
|
| 234 |
assert False, 'unknown archive type'
|
|
@@ -293,17 +325,18 @@ def convert_dataset(
|
|
| 293 |
The input dataset format is guessed from the --source argument:
|
| 294 |
|
| 295 |
\b
|
| 296 |
-
--source *_lmdb/
|
| 297 |
-
--source cifar-10-python.tar.gz
|
| 298 |
-
--source
|
| 299 |
-
--source
|
|
|
|
| 300 |
|
| 301 |
-
The output dataset format can be either an image folder or a zip archive.
|
| 302 |
-
the output format and path:
|
| 303 |
|
| 304 |
\b
|
| 305 |
-
--dest /path/to/dir
|
| 306 |
-
--dest /path/to/dataset.zip
|
| 307 |
|
| 308 |
Images within the dataset archive will be stored as uncompressed PNG.
|
| 309 |
|
|
|
|
| 13 |
import pickle
|
| 14 |
import sys
|
| 15 |
import tarfile
|
| 16 |
+
import gzip
|
| 17 |
import zipfile
|
| 18 |
from pathlib import Path
|
| 19 |
from typing import Callable, Optional, Tuple, Union
|
|
|
|
| 166 |
|
| 167 |
#----------------------------------------------------------------------------
|
| 168 |
|
| 169 |
+
def open_mnist(images_gz: str, *, max_images: Optional[int]):
|
| 170 |
+
labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
|
| 171 |
+
assert labels_gz != images_gz
|
| 172 |
+
images = []
|
| 173 |
+
labels = []
|
| 174 |
+
|
| 175 |
+
with gzip.open(images_gz, 'rb') as f:
|
| 176 |
+
images = np.frombuffer(f.read(), np.uint8, offset=16)
|
| 177 |
+
with gzip.open(labels_gz, 'rb') as f:
|
| 178 |
+
labels = np.frombuffer(f.read(), np.uint8, offset=8)
|
| 179 |
+
|
| 180 |
+
images = images.reshape(-1, 28, 28)
|
| 181 |
+
images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
|
| 182 |
+
assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
|
| 183 |
+
assert labels.shape == (60000,) and labels.dtype == np.uint8
|
| 184 |
+
assert np.min(images) == 0 and np.max(images) == 255
|
| 185 |
+
assert np.min(labels) == 0 and np.max(labels) == 9
|
| 186 |
+
|
| 187 |
+
max_idx = maybe_min(len(images), max_images)
|
| 188 |
+
|
| 189 |
+
def iterate_images():
|
| 190 |
+
for idx, img in enumerate(images):
|
| 191 |
+
yield dict(img=img, label=int(labels[idx]))
|
| 192 |
+
if idx >= max_idx-1:
|
| 193 |
+
break
|
| 194 |
+
|
| 195 |
+
return max_idx, iterate_images()
|
| 196 |
+
|
| 197 |
+
#----------------------------------------------------------------------------
|
| 198 |
+
|
| 199 |
def make_transform(
|
| 200 |
transform: Optional[str],
|
| 201 |
output_width: Optional[int],
|
|
|
|
| 256 |
else:
|
| 257 |
return open_image_folder(source, max_images=max_images)
|
| 258 |
elif os.path.isfile(source):
|
| 259 |
+
if os.path.basename(source) == 'cifar-10-python.tar.gz':
|
| 260 |
return open_cifar10(source, max_images=max_images)
|
| 261 |
+
elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
|
| 262 |
+
return open_mnist(source, max_images=max_images)
|
| 263 |
+
elif file_ext(source) == 'zip':
|
| 264 |
return open_image_zip(source, max_images=max_images)
|
| 265 |
else:
|
| 266 |
assert False, 'unknown archive type'
|
|
|
|
| 325 |
The input dataset format is guessed from the --source argument:
|
| 326 |
|
| 327 |
\b
|
| 328 |
+
--source *_lmdb/ Load LSUN dataset
|
| 329 |
+
--source cifar-10-python.tar.gz Load CIFAR-10 dataset
|
| 330 |
+
--source train-images-idx3-ubyte.gz Load MNIST dataset
|
| 331 |
+
--source path/ Recursively load all images from path/
|
| 332 |
+
--source dataset.zip Recursively load all images from dataset.zip
|
| 333 |
|
| 334 |
+
The output dataset format can be either an image folder or a zip archive.
|
| 335 |
+
Specifying the output format and path:
|
| 336 |
|
| 337 |
\b
|
| 338 |
+
--dest /path/to/dir Save output files under /path/to/dir
|
| 339 |
+
--dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
|
| 340 |
|
| 341 |
Images within the dataset archive will be stored as uncompressed PNG.
|
| 342 |
|