mkthoma commited on
Commit
54ebd81
·
1 Parent(s): c957a3d

Update resnet.py

Browse files
Files changed (1) hide show
  1. resnet.py +2 -2
resnet.py CHANGED
@@ -184,12 +184,12 @@ class custom_ResNet(pl.LightningModule):
184
 
185
  # Assign train/val datasets for use in dataloaders
186
  if stage == "fit" or stage is None:
187
- cifar_full = CIFAR10(self.data_dir, train=True, transform=self.train_transform)
188
  self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])
189
 
190
  # Assign test dataset for use in dataloader(s)
191
  if stage == "test" or stage is None:
192
- self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.test_transform)
193
 
194
  def train_dataloader(self):
195
  return DataLoader(self.cifar_train, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
 
184
 
185
  # Assign train/val datasets for use in dataloaders
186
  if stage == "fit" or stage is None:
187
+ cifar_full = CIFAR10(self.data_dir, train=True, download=True, transform=self.train_transform)
188
  self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])
189
 
190
  # Assign test dataset for use in dataloader(s)
191
  if stage == "test" or stage is None:
192
+ self.cifar_test = CIFAR10(self.data_dir, train=False, download=True, transform=self.test_transform)
193
 
194
  def train_dataloader(self):
195
  return DataLoader(self.cifar_train, batch_size=BATCH_SIZE, num_workers=os.cpu_count())