willis committed on
Commit
712050d
·
1 Parent(s): 0b2696b
Files changed (4) hide show
  1. model.py +8 -9
  2. processing/pipeline_torch.py +41 -6
  3. train.py +93 -132
  4. utils/base.py +65 -24
model.py CHANGED
@@ -3,7 +3,7 @@ from collections import defaultdict
3
 
4
  import torch
5
  import torch.optim
6
- from torchvision.models import resnet18
7
  from torchvision.utils import make_grid, save_image
8
  import torch.nn.functional as F
9
 
@@ -12,10 +12,13 @@ import pytorch_lightning as pl
12
  import mlflow.pytorch
13
 
14
 
15
- def resnet_model(model=resnet18, pretrained=True, in_channels=3, fc_out_features=2):
16
- resnet = model(pretrained=pretrained)
17
- # if not pretrained: # TODO: add case for in_channels=4
18
- # resnet.conv1 = torch.nn.Conv2d(channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
 
 
 
19
  resnet.fc = torch.nn.Linear(in_features=512, out_features=fc_out_features, bias=True)
20
  return resnet
21
 
@@ -81,10 +84,6 @@ class LitModel(pl.LightningModule):
81
 
82
  def update_step(self, batch, step_name):
83
  x, y = batch
84
- # debug(self.processor)
85
- # debug(self.processor.parameters())
86
- # debug.pause()
87
- # print('type', type(self.processor).__name__)
88
 
89
  logits = self(x)
90
 
 
3
 
4
  import torch
5
  import torch.optim
6
+ from torchvision.models import resnet18, resnet34, resnet50
7
  from torchvision.utils import make_grid, save_image
8
  import torch.nn.functional as F
9
 
 
12
  import mlflow.pytorch
13
 
14
 
15
def resnet_model(model='resnet18', pretrained=True, in_channels=3, fc_out_features=2):
    """Build a ResNet classifier whose final layer has ``fc_out_features`` outputs.

    Args:
        model (str or callable, optional): architecture name, one of
            {'resnet18', 'resnet34', 'resnet50'} (case-insensitive), or a
            constructor accepting a ``pretrained`` keyword (the form this
            function accepted before names were introduced).
        pretrained (bool, optional): forwarded unchanged to the constructor.
        in_channels (int, optional): kept for interface compatibility; this
            function does not read it, so the network stem is left untouched.
        fc_out_features (int, optional): output size of the replacement
            ``fc`` layer.

    Returns:
        The constructed network with its ``fc`` layer replaced.

    Raises:
        ValueError: if ``model`` is a name that is not one of the supported
            architectures.
    """
    if callable(model):
        resnet = model(pretrained=pretrained)
    else:
        name = str(model).lower()
        if name == 'resnet18':
            resnet = resnet18(pretrained=pretrained)
        elif name == 'resnet34':
            resnet = resnet34(pretrained=pretrained)
        elif name == 'resnet50':
            resnet = resnet50(pretrained=pretrained)
        else:
            raise ValueError(
                f"Unknown model '{model}', expected one of: resnet18, resnet34, resnet50")

    # The width feeding the head is not 512 for every architecture
    # (ResNet-50 uses 2048), so take it from the layer being replaced.
    resnet.fc = torch.nn.Linear(in_features=resnet.fc.in_features,
                                out_features=fc_out_features, bias=True)
    return resnet
24
 
 
84
 
85
  def update_step(self, batch, step_name):
86
  x, y = batch
 
 
 
 
87
 
88
  logits = self(x)
89
 
processing/pipeline_torch.py CHANGED
@@ -43,6 +43,18 @@ DEFAULT_CAMERA_PARAMS = (
43
 
44
 
45
  class RawToRGB(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
46
  def __init__(self, reduce_size=True, out_channels=3, track_stages=False, normalize_mosaic=None):
47
  super().__init__()
48
  self.stages = None
@@ -71,6 +83,14 @@ class RawToRGB(nn.Module):
71
 
72
 
73
  class NNProcessing(nn.Module):
 
 
 
 
 
 
 
 
74
  def __init__(self, track_stages=False, normalize_mosaic=None, batch_norm_output=True):
75
  super().__init__()
76
  self.stages = None
@@ -89,7 +109,7 @@ class NNProcessing(nn.Module):
89
  def forward(self, raw):
90
  self.stages = {}
91
  self.buffer = {}
92
- # self.stages['raw'] = raw
93
  rgb = raw2rgb(raw)
94
  if self.normalize_mosaic:
95
  rgb = self.normalize_mosaic(rgb)
@@ -108,18 +128,29 @@ class NNProcessing(nn.Module):
108
  return rgb
109
 
110
 
111
- def add_additive_layer(processor):
112
  processor.additive_layer = nn.Parameter(torch.zeros((1, 3, 256, 256)))
113
  # processor.additive_layer = nn.Parameter(0.001 * torch.randn((1, 3, 256, 256)))
114
 
115
 
116
  class ParametrizedProcessing(nn.Module):
117
- def __init__(self, camera_parameters, track_stages=False, batch_norm_output=True):
 
 
 
 
 
 
 
 
118
  super().__init__()
119
  self.stages = None
120
  self.buffer = None
121
  self.track_stages = track_stages
122
 
 
 
 
123
  black_level, white_balance, colour_matrix = camera_parameters
124
 
125
  self.black_level = nn.Parameter(torch.as_tensor(black_level))
@@ -197,8 +228,11 @@ class ParametrizedProcessing(nn.Module):
197
 
198
 
199
  class Debayer(nn.Conv2d):
 
 
 
200
  def __init__(self):
201
- super().__init__(3, 3, kernel_size=3, padding=1, padding_mode='reflect', bias=False) # default_pipeline uses 'replicate'
202
  self.weight.data.fill_(0)
203
  self.weight.data[0, 0] = K_RB.clone()
204
  self.weight.data[1, 1] = K_G.clone()
@@ -206,15 +240,16 @@ class Debayer(nn.Conv2d):
206
 
207
 
208
  def raw2rgb(raw, black_level=None, reduce_size=True, out_channels=3):
209
- """transform raw image with 1 channel to rgb with 3 channels
 
210
  Args:
211
  raw (Tensor): raw Tensor of shape (B, H, W)
212
  black_level (iterable, optional): RGGB black level values to subtract
213
  reduce_size (bool, optional): if False, the output image will have the same height and width
214
  as the raw input, i.e. (B, C, H, W), empty values are filled with zeros.
215
  if True, the output dimensions are reduced by half (B, C, H//2, W//2),
216
- the two green channels are averaged.
217
  out_channels (int, optional): number of output channels. One of {3, 4}.
 
218
  """
219
  assert out_channels in [3, 4]
220
  if black_level is None:
 
43
 
44
 
45
  class RawToRGB(nn.Module):
46
+ """transforms a raw image with 1 channel to rgb with 3 channels
47
+
48
+ Args:
49
+ reduce_size (bool, optional): if False, the output image will have the same height and width
50
+ as the raw input, i.e. (B, C, H, W), empty values are filled with zeros.
51
+ if True, the output dimensions are reduced by half (B, C, H//2, W//2),
52
+ out_channels (int, optional): number of output channels. One of {3, 4}.
53
+ for 3 channels, the two green channels are averaged.
54
+ track_stages (bool, optional): whether or not to retain intermediary steps in processing
55
+ normalize_mosaic (function, optional): applies normalization transformation to rgb image
56
+ """
57
+
58
  def __init__(self, reduce_size=True, out_channels=3, track_stages=False, normalize_mosaic=None):
59
  super().__init__()
60
  self.stages = None
 
83
 
84
 
85
  class NNProcessing(nn.Module):
86
+ """Transforms raw images to processed rgb via a segmentation Unet
87
+
88
+ Args:
89
+ track_stages (bool, optional): whether or not to retain intermediary steps in processing
90
+ normalize_mosaic (function, optional): applies normalization transformation to rgb image
91
+ batch_norm_output (bool, optional): adds a BatchNorm layer to the end of the processing
92
+ """
93
+
94
  def __init__(self, track_stages=False, normalize_mosaic=None, batch_norm_output=True):
95
  super().__init__()
96
  self.stages = None
 
109
  def forward(self, raw):
110
  self.stages = {}
111
  self.buffer = {}
112
+
113
  rgb = raw2rgb(raw)
114
  if self.normalize_mosaic:
115
  rgb = self.normalize_mosaic(rgb)
 
128
  return rgb
129
 
130
 
131
def append_additive_layer(processor, shape=(1, 3, 256, 256)):
    """Attach a zero-initialised learnable ``additive_layer`` to ``processor``.

    Args:
        processor (nn.Module): module that receives the new
            ``additive_layer`` attribute; it is modified in place.
        shape (tuple of int, optional): shape of the created parameter.
            Defaults to the previously hard-coded ``(1, 3, 256, 256)``.
    """
    # Zeros make the layer a no-op until it is trained; a small random
    # initialisation (0.001 * torch.randn(shape)) is the alternative.
    processor.additive_layer = nn.Parameter(torch.zeros(shape))
134
 
135
 
136
  class ParametrizedProcessing(nn.Module):
137
+ """Differentiable processing pipeline via torch transformations
138
+
139
+ Args:
140
+ camera_parameters (tuple(list), optional): applies given camera parameters in processing
141
+ track_stages (bool, optional): whether or not to retain intermediary steps in processing
142
+ batch_norm_output (bool, optional): adds a BatchNorm layer to the end of the processing
143
+ """
144
+
145
+ def __init__(self, camera_parameters=None, track_stages=False, batch_norm_output=True):
146
  super().__init__()
147
  self.stages = None
148
  self.buffer = None
149
  self.track_stages = track_stages
150
 
151
+ if camera_parameters is None:
152
+ camera_parameters = DEFAULT_CAMERA_PARAMS
153
+
154
  black_level, white_balance, colour_matrix = camera_parameters
155
 
156
  self.black_level = nn.Parameter(torch.as_tensor(black_level))
 
228
 
229
 
230
  class Debayer(nn.Conv2d):
231
+ """Separates the mosaiced raw image into its channels and interpolates bilinearly. Output is of same size as input.
232
+ """
233
+
234
  def __init__(self):
235
+ super().__init__(3, 3, kernel_size=3, padding=1, padding_mode='reflect', bias=False) # pipeline_numpy uses 'replicate'
236
  self.weight.data.fill_(0)
237
  self.weight.data[0, 0] = K_RB.clone()
238
  self.weight.data[1, 1] = K_G.clone()
 
240
 
241
 
242
  def raw2rgb(raw, black_level=None, reduce_size=True, out_channels=3):
243
+ """Transforms a raw image with 1 channel to rgb with 3 channels
244
+
245
  Args:
246
  raw (Tensor): raw Tensor of shape (B, H, W)
247
  black_level (iterable, optional): RGGB black level values to subtract
248
  reduce_size (bool, optional): if False, the output image will have the same height and width
249
  as the raw input, i.e. (B, C, H, W), empty values are filled with zeros.
250
  if True, the output dimensions are reduced by half (B, C, H//2, W//2),
 
251
  out_channels (int, optional): number of output channels. One of {3, 4}.
252
+ The two green channels are averaged if out_channels == 3.
253
  """
254
  assert out_channels in [3, 4]
255
  if black_level is None:
train.py CHANGED
@@ -15,14 +15,14 @@ from pytorch_lightning.metrics.functional import accuracy
15
  import pytorch_lightning as pl
16
  from pytorch_lightning.callbacks import ModelCheckpoint
17
 
18
- from utils.base import display_mlflow_run_info, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
19
  from utils.debug import debug
20
  from utils.dataset_utils import k_fold
21
  from utils.augmentation import get_augmentation
22
  from dataset import Subset, get_dataset
23
 
24
  from processing.pipeline_numpy import RawProcessingPipeline
25
- from processing.pipeline_torch import add_additive_layer, raw2rgb, RawToRGB, ParametrizedProcessing, NNProcessing
26
 
27
  from model import log_tensor, resnet_model, LitModel, TrackImagesCallback
28
 
@@ -31,89 +31,88 @@ import segmentation_models_pytorch as smp
31
  from utils.ssim import SSIM
32
 
33
  # args to set up task
34
- parser = argparse.ArgumentParser(description="classification_task")
35
- parser.add_argument("--tracking_uri", type=str,
36
- default="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com", help='URI of the mlflow server on AWS')
37
- parser.add_argument("--processor_uri", type=str, default=None,
38
  help='URI of the processing model (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/processing-model)')
39
- parser.add_argument("--classifier_uri", type=str, default=None,
40
  help='URI of the net (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/prediction-model)')
41
- parser.add_argument("--state_dict_uri", type=str,
42
  default=None, help='URI of the indices you want to load (e.g. s3://mlflow-artifacts-601883093460/7/4326da05aca54107be8c554de0674a14/artifacts/training')
43
 
44
- parser.add_argument("--experiment_name", type=str,
45
  default='classification learnable pipeline', help='Specify the experiment you are running, e.g. end2end segmentation')
46
- parser.add_argument("--run_name", type=str,
47
  default='test run', help='Specify the name of your run')
48
 
49
- parser.add_argument("--log_model", type=str2bool, default=True, help='Enables model logging')
50
- parser.add_argument("--save_locally", action='store_true',
51
  help='Model will be saved locally if action is taken') # TODO: bypass mlflow
52
 
53
- parser.add_argument("--track_processing", action='store_true',
54
  help='Save images after each trasformation of the pipeline for the test set')
55
- parser.add_argument("--track_processing_gradients", action='store_true',
56
  help='Save images of gradients after each trasformation of the pipeline for the test set')
57
- parser.add_argument("--track_save_tensors", action='store_true',
58
  help='Save the torch tensors after each trasformation of the pipeline for the test set')
59
- parser.add_argument("--track_predictions", action='store_true',
60
  help='Save images after each trasformation of the pipeline for the test set + input gradient')
61
- parser.add_argument("--track_n_images", default=5,
62
  help='Track the n first elements of dataset. Only used for args.track_processing=True')
63
- parser.add_argument("--track_every_epoch", action='store_true', help='Track images every epoch or once after training')
64
 
65
  # args to create dataset
66
- parser.add_argument("--seed", type=int, default=1, help='Global seed')
67
- parser.add_argument("--dataset", type=str, default='Microscopy',
68
- choices=["Drone", "DroneSegmentation", "Microscopy"], help='Select dataset')
69
 
70
- parser.add_argument("--n_splits", type=int, default=1, help='Number of splits used for training')
71
- parser.add_argument("--train_size", type=float, default=0.8, help='Fraction of training points in dataset')
72
 
73
  # args for training
74
- parser.add_argument("--lr", type=float, default=1e-5, help="learning rate used for training")
75
- parser.add_argument("--epochs", type=int, default=3, help="numper of epochs")
76
- parser.add_argument("--batch_size", type=int, default=32, help="Training batch size")
77
- parser.add_argument("--augmentation", type=str, default='none',
78
- choices=["none", "weak", "strong"], help="Applies augmentation to training")
79
- parser.add_argument("--augmentation_on_valid_epoch", action='store_true',
80
- help='Track images every epoch or once after training') # TODO: implement, actually should be disabled by default for 'val' and 'test
81
- parser.add_argument("--check_val_every_n_epoch", type=int, default=1)
82
 
83
  # args to specify the processing
84
- parser.add_argument("--processing_mode", type=str, default="parametrized",
85
- choices=["parametrized", "static", "neural_network", "none"],
86
- help="Which type of raw to rgb processing should be used")
87
 
88
  # args to specify model
89
- parser.add_argument("--classifier_network", type=str, default='ResNet18',
90
- help='Type of pretrained network') # TODO: implement different choices
91
- parser.add_argument("--classifier_pretrained", action='store_true',
92
  help='Whether to use a pre-trained model or not')
93
- parser.add_argument("--smp_encoder", type=str, default='resnet34', help='segmentation model encoder')
94
 
95
- parser.add_argument("--freeze_processor", action='store_true', help="Freeze raw to rgb processing model weights")
96
- parser.add_argument("--freeze_classifier", action='store_true', help="Freeze classification model weights")
97
 
98
  # args to specify static pipeline transformations
99
- parser.add_argument("--sp_debayer", type=str, default='bilinear',
100
- choices=['bilinear', 'malvar2004', 'menon2007'], help="Specify algorithm used as debayer")
101
- parser.add_argument("--sp_sharpening", type=str, default='sharpening_filter',
102
- choices=['sharpening_filter', 'unsharp_masking'], help="Specify algorithm used for sharpening")
103
- parser.add_argument("--sp_denoising", type=str, default='gaussian_denoising',
104
- choices=['gaussian_denoising', 'median_denoising', 'fft_denoising'], help="Specify algorithm used for denoising")
105
 
106
  # args to choose training mode
107
- parser.add_argument("--adv_training", action='store_true', help="Enable adversarial training")
108
- parser.add_argument("--adv_aux_weight", type=float, default=1, help="Weighting of the adversarial auxilliary loss")
109
- parser.add_argument("--adv_aux_loss", type=str, default='ssim', choices=['l2', 'ssim'],
110
- help="Type of adversarial auxilliary regularization loss")
111
- parser.add_argument("--adv_noise_layer", action='store_true', help="Adds an additive layer to Parametrized Processing")
112
- parser.add_argument("--adv_track_differences", action='store_true', help='Save difference to default pipeline')
113
  parser.add_argument('--adv_parameters', choices=['all', 'black_level', 'white_balance',
114
- 'colour_correction', 'gamma_correct', 'sharpening_filter', 'gaussian_blur', 'additive_layer'])
 
115
 
116
- parser.add_argument("--cache_downloaded_models", type=str2bool, default=True)
117
 
118
  parser.add_argument('--test_run', action='store_true')
119
 
@@ -133,10 +132,8 @@ def run_train(args):
133
  # set tracking uri, this is the address of the mlflow server where light experimental data will be stored
134
  mlflow.set_tracking_uri(args.tracking_uri)
135
  mlflow.set_experiment(args.experiment_name)
136
- os.environ["AWS_ACCESS_KEY_ID"] = "#TODO: fill in your aws access key id for mlflow server here"
137
- os.environ["AWS_SECRET_ACCESS_KEY"] = "#TODO: fill in your aws secret access key for mlflow server here"
138
-
139
- # dataset
140
 
141
  dataset = get_dataset(args.dataset)
142
 
@@ -147,52 +144,57 @@ def run_train(args):
147
  pl.seed_everything(args.seed)
148
  idxs_kfold = k_fold(dataset, n_splits=args.n_splits, seed=args.seed, train_size=args.train_size)
149
 
 
150
  with mlflow.start_run(run_name=args.run_name) as parent_run:
151
 
152
- for k_iter, idxs in enumerate(idxs_kfold):
 
153
 
154
- print(f"K_fold subset: {k_iter+1}/{args.n_splits}")
155
 
156
  if args.processing_mode == 'static':
157
- if args.dataset == "Drone" or args.dataset == "DroneSegmentation":
 
158
  mean = torch.tensor([0.35, 0.36, 0.35])
159
  std = torch.tensor([0.12, 0.11, 0.12])
160
- elif args.dataset == "Microscopy":
161
  mean = torch.tensor([0.91, 0.84, 0.94])
162
  std = torch.tensor([0.08, 0.12, 0.05])
163
 
 
164
  dataset.transform = T.Compose([RawProcessingPipeline(
165
  camera_parameters=dataset.camera_parameters,
166
  debayer=args.sp_debayer,
167
  sharpening=args.sp_sharpening,
168
  denoising=args.sp_denoising,
169
- ), T.Normalize(mean, std)])
170
- # XXX: Not clean
 
171
 
172
  processor = nn.Identity()
173
 
 
174
  if args.processor_uri is not None and args.processing_mode != 'none':
175
  print('Fetching processor: ', end='')
176
- model = fetch_from_mlflow(args.processor_uri, use_cache=args.cache_downloaded_models)
177
- processor = model.processor
178
- for param in processor.parameters():
179
- param.requires_grad = True
180
- model.processor = None
181
- del model
182
  else:
183
  print(f'processing_mode: {args.processing_mode}')
184
- normalize_mosaic = None # normalize after raw has been passed to raw2rgb
185
- if args.dataset == "Microscopy":
 
 
 
 
186
  mosaic_mean = [0.5663, 0.1401, 0.0731]
187
  mosaic_std = [0.097, 0.0423, 0.008]
188
  normalize_mosaic = T.Normalize(mosaic_mean, mosaic_std)
189
 
 
190
  track_stages = args.track_processing or args.track_processing_gradients
191
  if args.processing_mode == 'parametrized':
192
  processor = ParametrizedProcessing(
193
- camera_parameters=dataset.camera_parameters, track_stages=track_stages, batch_norm_output=True,
194
- # noise_layer=args.adv_noise_layer, # this has to be added manually afterwards for when a model is loaded that doesn't have one yet
195
- )
196
 
197
  elif args.processing_mode == 'neural_network':
198
  processor = NNProcessing(track_stages=track_stages,
@@ -201,22 +203,19 @@ def run_train(args):
201
  processor = RawToRGB(reduce_size=True, out_channels=3, track_stages=track_stages,
202
  normalize_mosaic=normalize_mosaic)
203
 
204
- if args.classifier_uri: # fetch classifier
205
  print('Fetching classifier: ', end='')
206
- model = fetch_from_mlflow(args.classifier_uri, use_cache=args.cache_downloaded_models)
207
- classifier = model.classifier
208
- model.classifier = None
209
- del model
210
  else:
211
  if dataset.task == 'classification':
212
  classifier = resnet_model(
213
- model=resnet18,
214
  pretrained=args.classifier_pretrained,
215
  in_channels=3,
216
  fc_out_features=len(dataset.classes)
217
  )
218
  else:
219
- # XXX: add other network choices to args.smp_network (FPN) and args.network
220
  classifier = smp.UnetPlusPlus(
221
  encoder_name=args.smp_encoder,
222
  encoder_depth=5,
@@ -240,26 +239,23 @@ def run_train(args):
240
 
241
  loss_aux = None
242
 
243
- if args.adv_training:
244
 
245
  assert args.processing_mode == 'parametrized', f"Processing mode ({args.processing_mode}) should be set to 'parametrized' for adversarial training"
246
- assert args.freeze_classifier, "Classifier should be frozen for adversarial training"
247
- assert not args.freeze_processor, "Processor should not be frozen for adversarial training"
248
 
 
249
  processor_default = copy.deepcopy(processor)
250
  processor_default.track_stages = args.track_processing
251
  processor_default.eval()
252
  processor_default.to(DEVICE)
253
- # debug(processor_default)
254
  for p in processor_default.parameters():
255
  p.requires_grad = False
256
 
257
- if args.adv_noise_layer:
258
- add_additive_layer(processor)
259
-
260
- def l2_regularization(x, y):
261
- return ((x - y) ** 2).sum()
262
- # return (x - y).norm()
263
 
264
  if args.adv_aux_loss == 'l2':
265
  regularization = l2_regularization
@@ -268,34 +264,12 @@ def run_train(args):
268
  else:
269
  NotImplementedError(args.adv_aux_loss)
270
 
271
- class AuxLoss(nn.Module):
272
- def __init__(self, loss_aux, weight=1):
273
- super().__init__()
274
- self.loss_aux = loss_aux
275
- self.weight = weight
276
-
277
- def forward(self, x):
278
- with torch.no_grad():
279
- x_reference = processor_default(x)
280
- x_processed = processor.buffer['processed_rgb']
281
- return self.weight * self.loss_aux(x_reference, x_processed)
282
-
283
- class WeightedLoss(nn.Module):
284
- def __init__(self, loss, weight=1):
285
- super().__init__()
286
- self.loss = loss
287
- self.weight = weight
288
-
289
- def forward(self, x, y):
290
- return self.weight * self.loss(x, y)
291
-
292
- def __repr__(self):
293
- return f'{self.weight} * {get_name(self.loss)}'
294
-
295
  loss = WeightedLoss(loss=loss, weight=-1)
296
- # loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=0)
297
  loss_aux = AuxLoss(
298
  loss_aux=regularization,
 
 
299
  weight=args.adv_aux_weight,
300
  )
301
 
@@ -316,15 +290,13 @@ def run_train(args):
316
  freeze_processor=args.freeze_processor,
317
  )
318
 
 
 
319
  # get train_set_dict
320
  if args.state_dict_uri:
321
  state_dict = mlflow.pytorch.load_state_dict(args.state_dict_uri)
322
  train_indices = state_dict['train_indices']
323
  valid_indices = state_dict['valid_indices']
324
- else:
325
- train_indices = idxs[0]
326
- valid_indices = idxs[1]
327
- state_dict = vars(args).copy()
328
 
329
  track_indices = list(range(args.track_n_images))
330
 
@@ -350,8 +322,6 @@ def run_train(args):
350
 
351
  with mlflow.start_run(run_name=f"{args.run_name}_{k_iter}", nested=True) as child_run:
352
 
353
- # mlflow.pytorch.autolog(silent=True)
354
-
355
  if k_iter == 0:
356
  display_mlflow_run_info(child_run)
357
 
@@ -389,14 +359,6 @@ def run_train(args):
389
  track_predictions=args.track_predictions,
390
  save_tensors=args.track_save_tensors)]
391
 
392
- # if True: #args.save_best:
393
- # if dataset.task == 'classification':
394
- #checkpoint_callback = ModelCheckpoint(pathmonitor="val_accuracy", mode='max')
395
- # checkpoint_callback = ModelCheckpoint(dirpath=args.tracking_uri, save_top_k=1, verbose=True, monitor="val_accuracy", mode="max") #dirpath=args.tracking_uri,
396
- # else:
397
- # checkpoint_callback = ModelCheckpoint(monitor="val_iou_score")
398
- #callbacks += [checkpoint_callback]
399
-
400
  trainer = pl.Trainer(
401
  gpus=1 if DEVICE == 'cuda' else 0,
402
  min_epochs=args.epochs,
@@ -404,7 +366,6 @@ def run_train(args):
404
  logger=mlf_logger,
405
  callbacks=callbacks,
406
  check_val_every_n_epoch=args.check_val_every_n_epoch,
407
- # checkpoint_callback=True,
408
  )
409
 
410
  if args.log_model:
 
15
  import pytorch_lightning as pl
16
  from pytorch_lightning.callbacks import ModelCheckpoint
17
 
18
+ from utils.base import AuxLoss, WeightedLoss, display_mlflow_run_info, l2_regularization, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
19
  from utils.debug import debug
20
  from utils.dataset_utils import k_fold
21
  from utils.augmentation import get_augmentation
22
  from dataset import Subset, get_dataset
23
 
24
  from processing.pipeline_numpy import RawProcessingPipeline
25
+ from processing.pipeline_torch import append_additive_layer, raw2rgb, RawToRGB, ParametrizedProcessing, NNProcessing
26
 
27
  from model import log_tensor, resnet_model, LitModel, TrackImagesCallback
28
 
 
31
  from utils.ssim import SSIM
32
 
33
  # args to set up task
34
+ parser = argparse.ArgumentParser(description='classification_task')
35
+ parser.add_argument('--tracking_uri', type=str,
36
+ default='http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com', help='URI of the mlflow server on AWS')
37
+ parser.add_argument('--processor_uri', type=str, default=None,
38
  help='URI of the processing model (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/processing-model)')
39
+ parser.add_argument('--classifier_uri', type=str, default=None,
40
  help='URI of the net (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/prediction-model)')
41
+ parser.add_argument('--state_dict_uri', type=str,
42
  default=None, help='URI of the indices you want to load (e.g. s3://mlflow-artifacts-601883093460/7/4326da05aca54107be8c554de0674a14/artifacts/training')
43
 
44
+ parser.add_argument('--experiment_name', type=str,
45
  default='classification learnable pipeline', help='Specify the experiment you are running, e.g. end2end segmentation')
46
+ parser.add_argument('--run_name', type=str,
47
  default='test run', help='Specify the name of your run')
48
 
49
+ parser.add_argument('--log_model', type=str2bool, default=True, help='Enables model logging')
50
+ parser.add_argument('--save_locally', action='store_true',
51
  help='Model will be saved locally if action is taken') # TODO: bypass mlflow
52
 
53
+ parser.add_argument('--track_processing', action='store_true',
54
  help='Save images after each trasformation of the pipeline for the test set')
55
+ parser.add_argument('--track_processing_gradients', action='store_true',
56
  help='Save images of gradients after each trasformation of the pipeline for the test set')
57
+ parser.add_argument('--track_save_tensors', action='store_true',
58
  help='Save the torch tensors after each trasformation of the pipeline for the test set')
59
+ parser.add_argument('--track_predictions', action='store_true',
60
  help='Save images after each trasformation of the pipeline for the test set + input gradient')
61
# type=int is required: without it a value given on the command line is kept
# as a string, while the default (5) is an int.
parser.add_argument('--track_n_images', type=int, default=5,
                    help='Track the n first elements of dataset. Only used for args.track_processing=True')
63
+ parser.add_argument('--track_every_epoch', action='store_true', help='Track images every epoch or once after training')
64
 
65
  # args to create dataset
66
+ parser.add_argument('--seed', type=int, default=1, help='Global seed')
67
+ parser.add_argument('--dataset', type=str, default='Microscopy',
68
+ choices=['Drone', 'DroneSegmentation', 'Microscopy'], help='Select dataset')
69
 
70
+ parser.add_argument('--n_splits', type=int, default=1, help='Number of splits used for training')
71
+ parser.add_argument('--train_size', type=float, default=0.8, help='Fraction of training points in dataset')
72
 
73
  # args for training
74
+ parser.add_argument('--lr', type=float, default=1e-5, help='learning rate used for training')
75
+ parser.add_argument('--epochs', type=int, default=3, help='numper of epochs')
76
+ parser.add_argument('--batch_size', type=int, default=32, help='Training batch size')
77
+ parser.add_argument('--augmentation', type=str, default='none',
78
+ choices=['none', 'weak', 'strong'], help='Applies augmentation to training')
79
+ parser.add_argument('--check_val_every_n_epoch', type=int, default=1)
 
 
80
 
81
  # args to specify the processing
82
+ parser.add_argument('--processing_mode', type=str, default='parametrized',
83
+ choices=['parametrized', 'static', 'neural_network', 'none'],
84
+ help='Which type of raw to rgb processing should be used')
85
 
86
  # args to specify model
87
+ parser.add_argument('--classifier_network', type=str, default='ResNet18', choices=['ResNet18', 'ResNet34', 'Resnet50'],
88
+ help='Type of pretrained network')
89
+ parser.add_argument('--classifier_pretrained', action='store_true',
90
  help='Whether to use a pre-trained model or not')
91
+ parser.add_argument('--smp_encoder', type=str, default='resnet34', help='segmentation models pytorch encoder')
92
 
93
+ parser.add_argument('--freeze_processor', action='store_true', help='Freeze raw to rgb processing model weights')
94
+ parser.add_argument('--freeze_classifier', action='store_true', help='Freeze classification model weights')
95
 
96
  # args to specify static pipeline transformations
97
+ parser.add_argument('--sp_debayer', type=str, default='bilinear',
98
+ choices=['bilinear', 'malvar2004', 'menon2007'], help='Specify algorithm used as debayer')
99
+ parser.add_argument('--sp_sharpening', type=str, default='sharpening_filter',
100
+ choices=['sharpening_filter', 'unsharp_masking'], help='Specify algorithm used for sharpening')
101
+ parser.add_argument('--sp_denoising', type=str, default='gaussian_denoising',
102
+ choices=['gaussian_denoising', 'median_denoising', 'fft_denoising'], help='Specify algorithm used for denoising')
103
 
104
  # args to choose training mode
105
+ parser.add_argument('--adv_training', action='store_true', help='Enable adversarial training')
106
+ parser.add_argument('--adv_aux_weight', type=float, default=1, help='Weighting of the adversarial auxilliary loss')
107
+ parser.add_argument('--adv_aux_loss', type=str, default='ssim', choices=['l2', 'ssim'],
108
+ help='Type of adversarial auxilliary regularization loss')
109
+ parser.add_argument('--adv_noise_layer', action='store_true', help='Adds an additive layer to Parametrized Processing')
110
+ parser.add_argument('--adv_track_differences', action='store_true', help='Save difference to default pipeline')
111
  parser.add_argument('--adv_parameters', choices=['all', 'black_level', 'white_balance',
112
+ 'colour_correction', 'gamma_correct', 'sharpening_filter', 'gaussian_blur', 'additive_layer'],
113
+ help='Target individual parameters for adversarial training.')
114
 
115
+ parser.add_argument('--cache_downloaded_models', type=str2bool, default=True)
116
 
117
  parser.add_argument('--test_run', action='store_true')
118
 
 
132
  # set tracking uri, this is the address of the mlflow server where light experimental data will be stored
133
  mlflow.set_tracking_uri(args.tracking_uri)
134
  mlflow.set_experiment(args.experiment_name)
135
+ os.environ['AWS_ACCESS_KEY_ID'] = '#TODO: fill in your aws access key id for mlflow server here'
136
+ os.environ['AWS_SECRET_ACCESS_KEY'] = '#TODO: fill in your aws secret access key for mlflow server here'
 
 
137
 
138
  dataset = get_dataset(args.dataset)
139
 
 
144
  pl.seed_everything(args.seed)
145
  idxs_kfold = k_fold(dataset, n_splits=args.n_splits, seed=args.seed, train_size=args.train_size)
146
 
147
+ # start mlflow parent run for k-fold validation (optional)
148
  with mlflow.start_run(run_name=args.run_name) as parent_run:
149
 
150
+ # start mlflow child run
151
+ for k_iter, (train_indices, valid_indices) in enumerate(idxs_kfold):
152
 
153
+ print(f'K_fold subset: {k_iter+1}/{args.n_splits}')
154
 
155
  if args.processing_mode == 'static':
156
+ # only needed if processor outputs should be normalized (might help for classifier training / testing against torch pipeline)
157
+ if args.dataset == 'Drone' or args.dataset == 'DroneSegmentation':
158
  mean = torch.tensor([0.35, 0.36, 0.35])
159
  std = torch.tensor([0.12, 0.11, 0.12])
160
+ elif args.dataset == 'Microscopy':
161
  mean = torch.tensor([0.91, 0.84, 0.94])
162
  std = torch.tensor([0.08, 0.12, 0.05])
163
 
164
+ # numpy pipeline doesn't use torch batched transformations. Transformations are applied individually to dataloader
165
  dataset.transform = T.Compose([RawProcessingPipeline(
166
  camera_parameters=dataset.camera_parameters,
167
  debayer=args.sp_debayer,
168
  sharpening=args.sp_sharpening,
169
  denoising=args.sp_denoising,
170
+ ),
171
+ T.Normalize(mean, std)
172
+ ])
173
 
174
  processor = nn.Identity()
175
 
176
+ # fetch processor from mlflow
177
  if args.processor_uri is not None and args.processing_mode != 'none':
178
  print('Fetching processor: ', end='')
179
+ processor = fetch_from_mlflow(args.processor_uri, type='processor',
180
+ use_cache=args.cache_downloaded_models)
 
 
 
 
181
  else:
182
  print(f'processing_mode: {args.processing_mode}')
183
+ normalize_mosaic = None # normalize after raw has been transformed to rgb image via raw2rgb
184
+ # not strictly necessary, but for processing_mode=='none' this will ensure normalized outputs for the classifier
185
+ # and for processing_mode=='neural_network', the processing segmentation model receives normalized inputs
186
+ # could be evaded via an additional batchnorm!
187
+ # XXX
188
+ if args.dataset == 'Microscopy':
189
  mosaic_mean = [0.5663, 0.1401, 0.0731]
190
  mosaic_std = [0.097, 0.0423, 0.008]
191
  normalize_mosaic = T.Normalize(mosaic_mean, mosaic_std)
192
 
193
+ # track individual processing steps for visualization
194
  track_stages = args.track_processing or args.track_processing_gradients
195
  if args.processing_mode == 'parametrized':
196
  processor = ParametrizedProcessing(
197
+ camera_parameters=dataset.camera_parameters, track_stages=track_stages, batch_norm_output=True)
 
 
198
 
199
  elif args.processing_mode == 'neural_network':
200
  processor = NNProcessing(track_stages=track_stages,
 
203
  processor = RawToRGB(reduce_size=True, out_channels=3, track_stages=track_stages,
204
  normalize_mosaic=normalize_mosaic)
205
 
206
+ if args.classifier_uri: # fetch classifier from mlflow
207
  print('Fetching classifier: ', end='')
208
+ classifier = fetch_from_mlflow(args.classifier_uri, type='classifier',
209
+ use_cache=args.cache_downloaded_models)
 
 
210
  else:
211
  if dataset.task == 'classification':
212
  classifier = resnet_model(
213
+ model=args.classifier_network,
214
  pretrained=args.classifier_pretrained,
215
  in_channels=3,
216
  fc_out_features=len(dataset.classes)
217
  )
218
  else:
 
219
  classifier = smp.UnetPlusPlus(
220
  encoder_name=args.smp_encoder,
221
  encoder_depth=5,
 
239
 
240
  loss_aux = None
241
 
242
+ if args.adv_training: # setup for failure mode search
243
 
244
  assert args.processing_mode == 'parametrized', f"Processing mode ({args.processing_mode}) should be set to 'parametrized' for adversarial training"
245
+ assert args.freeze_classifier, 'Classifier should be frozen for adversarial training'
246
+ assert not args.freeze_processor, 'Processor should not be frozen for adversarial training'
247
 
248
+ # copy, so that regularization in rgb space between adversarial and original processor can be computed
249
  processor_default = copy.deepcopy(processor)
250
  processor_default.track_stages = args.track_processing
251
  processor_default.eval()
252
  processor_default.to(DEVICE)
253
+
254
  for p in processor_default.parameters():
255
  p.requires_grad = False
256
 
257
+ if args.adv_noise_layer: # optional additional "noise" layer in processor
258
+ append_additive_layer(processor)
 
 
 
 
259
 
260
  if args.adv_aux_loss == 'l2':
261
  regularization = l2_regularization
 
264
  else:
265
  NotImplementedError(args.adv_aux_loss)
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  loss = WeightedLoss(loss=loss, weight=-1)
268
+
269
  loss_aux = AuxLoss(
270
  loss_aux=regularization,
271
+ processor_adv=processor,
272
+ processor_default=processor_default,
273
  weight=args.adv_aux_weight,
274
  )
275
 
 
290
  freeze_processor=args.freeze_processor,
291
  )
292
 
293
+ state_dict = vars(args).copy()
294
+
295
  # get train_set_dict
296
  if args.state_dict_uri:
297
  state_dict = mlflow.pytorch.load_state_dict(args.state_dict_uri)
298
  train_indices = state_dict['train_indices']
299
  valid_indices = state_dict['valid_indices']
 
 
 
 
300
 
301
  track_indices = list(range(args.track_n_images))
302
 
 
322
 
323
  with mlflow.start_run(run_name=f"{args.run_name}_{k_iter}", nested=True) as child_run:
324
 
 
 
325
  if k_iter == 0:
326
  display_mlflow_run_info(child_run)
327
 
 
359
  track_predictions=args.track_predictions,
360
  save_tensors=args.track_save_tensors)]
361
 
 
 
 
 
 
 
 
 
362
  trainer = pl.Trainer(
363
  gpus=1 if DEVICE == 'cuda' else 0,
364
  min_epochs=args.epochs,
 
366
  logger=mlf_logger,
367
  callbacks=callbacks,
368
  check_val_every_n_epoch=args.check_val_every_n_epoch,
 
369
  )
370
 
371
  if args.log_model:
utils/base.py CHANGED
@@ -18,14 +18,7 @@ from b2sdk.v1 import *
18
 
19
  import argparse
20
 
21
-
22
- class SmartFormatter(argparse.HelpFormatter):
23
-
24
- def _split_lines(self, text, width):
25
- if text.startswith('R|'):
26
- return text[2:].splitlines()
27
- # this is the RawTextHelpFormatter._split_lines
28
- return argparse.HelpFormatter._split_lines(self, text, width)
29
 
30
 
31
  def str2bool(string):
@@ -193,6 +186,7 @@ def b2_download_folder(b2_dir, local_dir, force_download=False, mirror_folder=Tr
193
  def get_name(obj):
194
  return obj.__name__ if hasattr(obj, '__name__') else type(obj).__name__
195
 
 
196
  def get_mlflow_model_by_name(experiment_name, run_name,
197
  tracking_uri="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
198
  download_model=True):
@@ -234,6 +228,7 @@ def get_mlflow_model_by_name(experiment_name, run_name,
234
 
235
  return state_dict, model
236
 
 
237
  def data_loader_mean_and_std(data_loader, transform=None):
238
  means = []
239
  stds = []
@@ -244,23 +239,35 @@ def data_loader_mean_and_std(data_loader, transform=None):
244
  stds.append(x.std(dim=(0, 2, 3)).unsqueeze(0))
245
  return torch.cat(means).mean(dim=0), torch.cat(stds).mean(dim=0)
246
 
 
247
  def fetch_runs_list_mlflow(experiment):
248
  runs = mlflow.search_runs(experiment.experiment_id)
249
  runs.to_pickle('cache/runs_names.pkl') # where to save it, usually as a .pkl
250
  return runs
251
 
252
- def fetch_from_mlflow(uri, use_cache=True, download_model=True):
 
253
  cache_loc = os.path.join('cache', uri.split('//')[1]) + '.pt'
254
  if use_cache and os.path.exists(cache_loc):
255
  print(f'loading cached model from {cache_loc} ...')
256
- return torch.load(cache_loc)
257
  else:
258
  print(f'fetching model from {uri} ...')
259
  model = mlflow.pytorch.load_model(uri)
260
  os.makedirs(os.path.dirname(cache_loc), exist_ok=True)
261
  if download_model:
262
  torch.save(model, cache_loc, pickle_module=mlflow.pytorch.pickle_module)
263
- return model
 
 
 
 
 
 
 
 
 
 
264
 
265
 
266
  def display_mlflow_run_info(run):
@@ -315,16 +322,50 @@ def get_train_test_indices_drone(df, frac, seed=None):
315
  return train_indices, test_indices
316
 
317
 
318
- def smp_get_loss(loss):
319
- if loss == "Dice":
320
- return smp.losses.DiceLoss(mode='binary', from_logits=True)
321
- if loss == "BCE":
322
- return nn.BCELoss()
323
- elif loss == "BCEWithLogits":
324
- return smp.losses.BCEWithLogitsLoss()
325
- elif loss == "DicyBCE":
326
- from pytorch_toolbelt import losses as ptbl
327
- return ptbl.JointLoss(ptbl.DiceLoss(mode='binary', from_logits=False),
328
- nn.BCELoss(),
329
- first_weight=args.dice_weight,
330
- second_weight=args.bce_weight)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  import argparse
20
 
21
+ from torch import nn
 
 
 
 
 
 
 
22
 
23
 
24
  def str2bool(string):
 
186
def get_name(obj):
    """Return a printable name for ``obj``.

    Uses ``obj.__name__`` when the object has one (functions, classes, ...),
    otherwise falls back to the name of the object's type.
    """
    return getattr(obj, '__name__', type(obj).__name__)
188
 
189
+
190
  def get_mlflow_model_by_name(experiment_name, run_name,
191
  tracking_uri="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
192
  download_model=True):
 
228
 
229
  return state_dict, model
230
 
231
+
232
  def data_loader_mean_and_std(data_loader, transform=None):
233
  means = []
234
  stds = []
 
239
  stds.append(x.std(dim=(0, 2, 3)).unsqueeze(0))
240
  return torch.cat(means).mean(dim=0), torch.cat(stds).mean(dim=0)
241
 
242
+
243
def fetch_runs_list_mlflow(experiment):
    """Query the runs of an MLflow experiment and pickle the result locally.

    Args:
        experiment: object exposing ``experiment_id`` (e.g. an MLflow
            experiment).

    Returns:
        The object returned by ``mlflow.search_runs`` (presumably a pandas
        DataFrame, since ``to_pickle`` is called on it -- confirm against the
        installed MLflow version).
    """
    runs = mlflow.search_runs(experiment.experiment_id)
    # to_pickle does not create missing parent directories, so make sure the
    # cache folder exists before writing.
    os.makedirs('cache', exist_ok=True)
    runs.to_pickle('cache/runs_names.pkl')  # where to save it, usually as a .pkl
    return runs
247
 
248
+
249
def fetch_from_mlflow(uri, type='', use_cache=True, download_model=True):
    """Load a model from the local cache or from MLflow, optionally returning one part.

    Args:
        uri: MLflow model URI. The text after the first ``//`` is used as the
            relative path of the cache file below ``cache/``.
        type: ``'processor'`` or ``'classifier'`` to return only that attribute
            of the loaded model; any other value returns the whole model.
            (The name shadows the builtin; kept because callers pass it by
            keyword.)
        use_cache: when True and the cache file exists, load it instead of
            fetching from MLflow.
        download_model: when True, save a freshly fetched model to the cache
            file.

    Returns:
        The requested sub-module, or the complete model.
    """
    cache_loc = os.path.join('cache', uri.split('//')[1]) + '.pt'
    cache_hit = use_cache and os.path.exists(cache_loc)

    if cache_hit:
        print(f'loading cached model from {cache_loc} ...')
        model = torch.load(cache_loc)
    else:
        print(f'fetching model from {uri} ...')
        model = mlflow.pytorch.load_model(uri)
        os.makedirs(os.path.dirname(cache_loc), exist_ok=True)
        if download_model:
            torch.save(model, cache_loc, pickle_module=mlflow.pytorch.pickle_module)

    if type not in ('processor', 'classifier'):
        return model

    # Detach the requested part from its container so the rest of the model
    # can be garbage collected.
    component = getattr(model, type)
    setattr(model, type, None)
    del model  # free up memory space
    return component
271
 
272
 
273
  def display_mlflow_run_info(run):
 
322
  return train_indices, test_indices
323
 
324
 
325
+ # def smp_get_loss(loss):
326
+ # if loss == "Dice":
327
+ # return smp.losses.DiceLoss(mode='binary', from_logits=True)
328
+ # if loss == "BCE":
329
+ # return nn.BCELoss()
330
+ # elif loss == "BCEWithLogits":
331
+ # return smp.losses.BCEWithLogitsLoss()
332
+ # elif loss == "DicyBCE":
333
+ # from pytorch_toolbelt import losses as ptbl
334
+ # return ptbl.JointLoss(ptbl.DiceLoss(mode='binary', from_logits=False),
335
+ # nn.BCELoss(),
336
+ # first_weight=args.dice_weight,
337
+ # second_weight=args.bce_weight)
338
+
339
+
340
+ # Adversarial setup
341
+
342
def l2_regularization(x, y):
    """Return the sum of squared element-wise differences between ``x`` and ``y``."""
    squared_error = (x - y) ** 2
    return squared_error.sum()
344
+
345
+
346
class AuxLoss(nn.Module):
    """Weighted auxiliary loss comparing the default and adversarial processors.

    Args:
        loss_aux: callable ``(reference, processed) -> loss value``.
        processor_adv: adversarial processor; must expose a ``buffer`` mapping
            containing the key ``'processed_rgb'`` (presumably filled during
            its own forward pass -- confirm against the processor class).
        processor_default: reference processor, called on the input inside
            ``torch.no_grad()``.
        weight: factor the auxiliary loss is multiplied with.
    """

    def __init__(self, loss_aux, processor_adv, processor_default, weight=1):
        super().__init__()
        self.loss_aux = loss_aux
        self.weight = weight
        self.processor_adv = processor_adv
        self.processor_default = processor_default

    def forward(self, x):
        """Return ``weight * loss_aux(reference, processed)`` for input ``x``."""
        # The reference output is a fixed target, so no gradients are tracked.
        with torch.no_grad():
            x_reference = self.processor_default(x)
        # Read from ``processor_adv``: the previous ``self.processor`` is never
        # assigned in ``__init__`` and raised AttributeError on every call.
        x_processed = self.processor_adv.buffer['processed_rgb']
        return self.weight * self.loss_aux(x_reference, x_processed)
359
+
360
+
361
class WeightedLoss(nn.Module):
    """Wrap a loss callable and scale its result by a constant factor.

    Args:
        loss: callable ``(x, y) -> loss value``.
        weight: factor the loss value is multiplied with.
    """

    def __init__(self, loss, weight=1):
        super().__init__()
        self.loss, self.weight = loss, weight

    def forward(self, x, y):
        """Return ``weight * loss(x, y)``."""
        unweighted = self.loss(x, y)
        return self.weight * unweighted

    def __repr__(self):
        return '{} * {}'.format(self.weight, get_name(self.loss))