Mojo commited on
Commit
cb87fb1
·
1 Parent(s): ff213a4

Added new files

Browse files
Files changed (1) hide show
  1. utilities/dataset.py +52 -21
utilities/dataset.py CHANGED
@@ -1,9 +1,7 @@
 
1
  import pytorch_lightning as pl
2
- from torch.utils.data import DataLoader
3
- from torchvision.datasets import CIFAR10
4
- from torchvision import transforms
5
  import torch
6
- import numpy as np
7
 
8
 
9
  class CIFAR10(torch.utils.data.Dataset):
@@ -29,33 +27,66 @@ class CIFAR10(torch.utils.data.Dataset):
29
 
30
  return (image, label)
31
 
 
32
  class CIFAR10DataModule(pl.LightningDataModule):
33
- def __init__(self, train_set_transforms,test_set_transforms, data_dir: str = "./data",batch_size: int = 64, num_workers: int = 4):
 
 
 
 
 
 
 
 
 
34
  super().__init__()
 
35
  self.data_dir = data_dir
36
  self.batch_size = batch_size
37
  self.num_workers = num_workers
38
- self.train_set_transforms =train_set_transforms
39
- self.test_set_transforms = test_set_transforms
 
 
 
40
 
41
  def prepare_data(self):
42
- # Download the CIFAR10 dataset
43
- CIFAR10(self.data_dir, train=True, download=True)
44
- CIFAR10(self.data_dir, train=False, download=True)
45
-
46
- def setup(self, stage: str = None):
47
- # Load the dataset
48
- if stage == "fit" or stage is None:
49
- self.cifar10_train = CIFAR10(self.data_dir, train=True, transform=self.train_set_transforms)
50
- self.cifar10_val = CIFAR10(self.data_dir, train=False, transform=self.train_set_transforms)
51
- if stage == "test" or stage is None:
52
- self.cifar10_test = CIFAR10(self.data_dir, train=False, transform=self.test_set_transforms)
 
53
 
54
  def train_dataloader(self):
55
- return DataLoader(self.cifar10_train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
 
 
 
 
 
 
56
 
57
  def val_dataloader(self):
58
- return DataLoader(self.cifar10_val, batch_size=self.batch_size, num_workers=self.num_workers)
 
 
 
 
 
 
59
 
60
  def test_dataloader(self):
61
- return DataLoader(self.cifar10_test, batch_size=self.batch_size, num_workers=self.num_workers)
 
 
 
 
 
 
 
1
+ import numpy as np
2
  import pytorch_lightning as pl
 
 
 
3
  import torch
4
+ from torchvision import datasets
5
 
6
 
7
  class CIFAR10(torch.utils.data.Dataset):
 
27
 
28
  return (image, label)
29
 
30
+
31
  class CIFAR10DataModule(pl.LightningDataModule):
32
+ def __init__(
33
+ self,
34
+ train_transforms,
35
+ val_transforms,
36
+ shuffle=True,
37
+ data_dir="../data",
38
+ batch_size=64,
39
+ num_workers=-1,
40
+ pin_memory=True,
41
+ ):
42
  super().__init__()
43
+ self.shuffle = shuffle
44
  self.data_dir = data_dir
45
  self.batch_size = batch_size
46
  self.num_workers = num_workers
47
+ self.pin_memory = pin_memory
48
+ self.train_transforms = train_transforms
49
+ self.val_transforms = val_transforms
50
+ self.train_data = None
51
+ self.val_data = None
52
 
53
  def prepare_data(self):
54
+ datasets.CIFAR10(self.data_dir, train=True, download=True)
55
+ datasets.CIFAR10(self.data_dir, train=False, download=True)
56
+
57
+ def setup(self, stage):
58
+ self.train_data = CIFAR10(
59
+ datasets.CIFAR10(root=self.data_dir, train=True, download=False),
60
+ transform=self.train_transforms,
61
+ )
62
+ self.val_data = CIFAR10(
63
+ datasets.CIFAR10(root=self.data_dir, train=False, download=False),
64
+ transform=self.val_transforms,
65
+ )
66
 
67
  def train_dataloader(self):
68
+ return torch.utils.data.DataLoader(
69
+ self.train_data,
70
+ batch_size=self.batch_size,
71
+ shuffle=self.shuffle,
72
+ num_workers=self.num_workers,
73
+ pin_memory=self.pin_memory,
74
+ )
75
 
76
  def val_dataloader(self):
77
+ return torch.utils.data.DataLoader(
78
+ self.val_data,
79
+ batch_size=self.batch_size,
80
+ shuffle=False,
81
+ num_workers=self.num_workers,
82
+ pin_memory=self.pin_memory,
83
+ )
84
 
85
  def test_dataloader(self):
86
+ return torch.utils.data.DataLoader(
87
+ self.val_data,
88
+ batch_size=self.batch_size,
89
+ shuffle=False,
90
+ num_workers=self.num_workers,
91
+ pin_memory=self.pin_memory,
92
+ )