Mojo
commited on
Commit
·
cb87fb1
1
Parent(s):
ff213a4
Added new files
Browse files- utilities/dataset.py +52 -21
utilities/dataset.py
CHANGED
@@ -1,9 +1,7 @@
|
|
|
|
1 |
import pytorch_lightning as pl
|
2 |
-
from torch.utils.data import DataLoader
|
3 |
-
from torchvision.datasets import CIFAR10
|
4 |
-
from torchvision import transforms
|
5 |
import torch
|
6 |
-
|
7 |
|
8 |
|
9 |
class CIFAR10(torch.utils.data.Dataset):
|
@@ -29,33 +27,66 @@ class CIFAR10(torch.utils.data.Dataset):
|
|
29 |
|
30 |
return (image, label)
|
31 |
|
|
|
32 |
class CIFAR10DataModule(pl.LightningDataModule):
|
33 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
super().__init__()
|
|
|
35 |
self.data_dir = data_dir
|
36 |
self.batch_size = batch_size
|
37 |
self.num_workers = num_workers
|
38 |
-
self.
|
39 |
-
self.
|
|
|
|
|
|
|
40 |
|
41 |
def prepare_data(self):
|
42 |
-
|
43 |
-
CIFAR10(self.data_dir, train=
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
53 |
|
54 |
def train_dataloader(self):
|
55 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
def val_dataloader(self):
|
58 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
def test_dataloader(self):
|
61 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
import pytorch_lightning as pl
|
|
|
|
|
|
|
3 |
import torch
|
4 |
+
from torchvision import datasets
|
5 |
|
6 |
|
7 |
class CIFAR10(torch.utils.data.Dataset):
|
|
|
27 |
|
28 |
return (image, label)
|
29 |
|
30 |
+
|
31 |
class CIFAR10DataModule(pl.LightningDataModule):
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
train_transforms,
|
35 |
+
val_transforms,
|
36 |
+
shuffle=True,
|
37 |
+
data_dir="../data",
|
38 |
+
batch_size=64,
|
39 |
+
num_workers=-1,
|
40 |
+
pin_memory=True,
|
41 |
+
):
|
42 |
super().__init__()
|
43 |
+
self.shuffle = shuffle
|
44 |
self.data_dir = data_dir
|
45 |
self.batch_size = batch_size
|
46 |
self.num_workers = num_workers
|
47 |
+
self.pin_memory = pin_memory
|
48 |
+
self.train_transforms = train_transforms
|
49 |
+
self.val_transforms = val_transforms
|
50 |
+
self.train_data = None
|
51 |
+
self.val_data = None
|
52 |
|
53 |
def prepare_data(self):
|
54 |
+
datasets.CIFAR10(self.data_dir, train=True, download=True)
|
55 |
+
datasets.CIFAR10(self.data_dir, train=False, download=True)
|
56 |
+
|
57 |
+
def setup(self, stage):
|
58 |
+
self.train_data = CIFAR10(
|
59 |
+
datasets.CIFAR10(root=self.data_dir, train=True, download=False),
|
60 |
+
transform=self.train_transforms,
|
61 |
+
)
|
62 |
+
self.val_data = CIFAR10(
|
63 |
+
datasets.CIFAR10(root=self.data_dir, train=False, download=False),
|
64 |
+
transform=self.val_transforms,
|
65 |
+
)
|
66 |
|
67 |
def train_dataloader(self):
|
68 |
+
return torch.utils.data.DataLoader(
|
69 |
+
self.train_data,
|
70 |
+
batch_size=self.batch_size,
|
71 |
+
shuffle=self.shuffle,
|
72 |
+
num_workers=self.num_workers,
|
73 |
+
pin_memory=self.pin_memory,
|
74 |
+
)
|
75 |
|
76 |
def val_dataloader(self):
|
77 |
+
return torch.utils.data.DataLoader(
|
78 |
+
self.val_data,
|
79 |
+
batch_size=self.batch_size,
|
80 |
+
shuffle=False,
|
81 |
+
num_workers=self.num_workers,
|
82 |
+
pin_memory=self.pin_memory,
|
83 |
+
)
|
84 |
|
85 |
def test_dataloader(self):
|
86 |
+
return torch.utils.data.DataLoader(
|
87 |
+
self.val_data,
|
88 |
+
batch_size=self.batch_size,
|
89 |
+
shuffle=False,
|
90 |
+
num_workers=self.num_workers,
|
91 |
+
pin_memory=self.pin_memory,
|
92 |
+
)
|