marco.aversa commited on
Commit
4f8704e
·
1 Parent(s): 3fe7825

updated libraries

Browse files
model.py CHANGED
@@ -187,10 +187,6 @@ class TrackImagesCallback(pl.callbacks.base.Callback):
187
  self.callback_track_images(trainer.model, save_loc)
188
 
189
 
190
- from utils.debug import debug
191
-
192
-
193
- # @debug
194
  def log_tensor(batch, path, save_tensors=True, nrow=8):
195
  if save_tensors:
196
  torch.save(batch, path)
 
187
  self.callback_track_images(trainer.model, save_loc)
188
 
189
 
 
 
 
 
190
  def log_tensor(batch, path, save_tensors=True, nrow=8):
191
  if save_tensors:
192
  torch.save(batch, path)
processing/pipeline_torch.py CHANGED
@@ -10,8 +10,6 @@ from utils.base import np2torch, torch2np
10
 
11
  import segmentation_models_pytorch as smp
12
 
13
- from utils.debug import debug
14
-
15
  K_G = torch.Tensor([[0, 1, 0],
16
  [1, 4, 1],
17
  [0, 1, 0]]) / 4
 
10
 
11
  import segmentation_models_pytorch as smp
12
 
 
 
13
  K_G = torch.Tensor([[0, 1, 0],
14
  [1, 4, 1],
15
  [0, 1, 0]]) / 4
train.py CHANGED
@@ -16,7 +16,6 @@ import pytorch_lightning as pl
16
  from pytorch_lightning.callbacks import ModelCheckpoint
17
 
18
  from utils.base import AuxLoss, WeightedLoss, display_mlflow_run_info, l2_regularization, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
19
- from utils.debug import debug
20
  from utils.dataset_utils import k_fold
21
  from utils.augmentation import get_augmentation
22
  from dataset import Subset, get_dataset
 
16
  from pytorch_lightning.callbacks import ModelCheckpoint
17
 
18
  from utils.base import AuxLoss, WeightedLoss, display_mlflow_run_info, l2_regularization, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
 
19
  from utils.dataset_utils import k_fold
20
  from utils.augmentation import get_augmentation
21
  from dataset import Subset, get_dataset
utils/dataset_utils.py CHANGED
@@ -1,11 +1,18 @@
 
 
 
1
 
 
2
  import random
3
  import numpy as np
 
 
4
 
5
  import torch
6
 
7
  from skimage.util.shape import view_as_windows
8
 
 
9
 
10
  def load_image(path):
11
  file_type = path.split('.')[-1].lower()
 
1
+ """
2
+ Dataset Import/Download Tools
3
+ """
4
 
5
+ import os
6
  import random
7
  import numpy as np
8
+ import rawpy
9
+ from PIL import Image
10
 
11
  import torch
12
 
13
  from skimage.util.shape import view_as_windows
14
 
15
+ IMAGE_FILE_TYPES = ['dng', 'png', 'tif', 'tiff']
16
 
17
  def load_image(path):
18
  file_type = path.split('.')[-1].lower()