mkthoma commited on
Commit
1edae1c
·
1 Parent(s): b079c9c

Update resnet.py

Browse files
Files changed (1) hide show
  1. resnet.py +4 -0
resnet.py CHANGED
@@ -23,6 +23,10 @@ import torchvision.datasets as datasets
23
  import pytorch_lightning as pl
24
  import matplotlib.pyplot as plt
25
 
 
 
 
 
26
  # Model
27
  class custom_ResNet(pl.LightningModule):
28
  def __init__(self, data_dir=PATH_DATASETS, learning_rate=2e-4):
 
23
  import pytorch_lightning as pl
24
  import matplotlib.pyplot as plt
25
 
26
+
27
+ PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
28
+ BATCH_SIZE = 256 if AVAIL_GPUS else 64
29
+
30
  # Model
31
  class custom_ResNet(pl.LightningModule):
32
  def __init__(self, data_dir=PATH_DATASETS, learning_rate=2e-4):