Spaces:
Sleeping
Sleeping
Update resnet.py
Browse files
resnet.py
CHANGED
@@ -184,12 +184,12 @@ class custom_ResNet(pl.LightningModule):
|
|
184 |
|
185 |
# Assign train/val datasets for use in dataloaders
|
186 |
if stage == "fit" or stage is None:
|
187 |
-
cifar_full = CIFAR10(self.data_dir, train=True, transform=self.train_transform)
|
188 |
self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])
|
189 |
|
190 |
# Assign test dataset for use in dataloader(s)
|
191 |
if stage == "test" or stage is None:
|
192 |
-
self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.test_transform)
|
193 |
|
194 |
def train_dataloader(self):
|
195 |
return DataLoader(self.cifar_train, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
|
|
|
184 |
|
185 |
# Assign train/val datasets for use in dataloaders
|
186 |
if stage == "fit" or stage is None:
|
187 |
+
cifar_full = CIFAR10(self.data_dir, train=True, download=True, transform=self.train_transform)
|
188 |
self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])
|
189 |
|
190 |
# Assign test dataset for use in dataloader(s)
|
191 |
if stage == "test" or stage is None:
|
192 |
+
self.cifar_test = CIFAR10(self.data_dir, train=False, download=True, transform=self.test_transform)
|
193 |
|
194 |
def train_dataloader(self):
|
195 |
return DataLoader(self.cifar_train, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
|