willis committed on
Commit
0220054
·
1 Parent(s): 839dc8c

reorganize

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
utils/dataset.py → dataset.py RENAMED
@@ -15,7 +15,7 @@ from sklearn.model_selection import StratifiedShuffleSplit
15
  if not os.path.exists('README.md'): # set pwd to root
16
  os.chdir('..')
17
 
18
- from utils.splitting import split_img
19
  from utils.base import np2torch, torch2np, b2_download_folder
20
 
21
  IMAGE_FILE_TYPES = ['dng', 'png', 'tif', 'tiff']
@@ -41,24 +41,6 @@ def get_dataset(name, I_ratio=1.0):
41
  raise ValueError(name)
42
 
43
 
44
- def load_image(path):
45
- file_type = path.split('.')[-1].lower()
46
- if file_type == 'dng':
47
- img = rawpy.imread(path).raw_image_visible
48
- elif file_type == 'tiff' or file_type == 'tif':
49
- img = np.array(tiff.imread(path), dtype=np.float32)
50
- else:
51
- img = np.array(Image.open(path), dtype=np.float32)
52
- return img
53
-
54
-
55
- def list_images_in_dir(path):
56
- image_list = [os.path.join(path, img_name)
57
- for img_name in sorted(os.listdir(path))
58
- if img_name.split('.')[-1].lower() in IMAGE_FILE_TYPES]
59
- return image_list
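A minimal usage sketch of the two helpers above (not part of the commit): after this change they are imported from utils.dataset_utils, as the new import line further down shows. The directory path is illustrative, borrowed from elsewhere in this file.

from utils.dataset_utils import list_images_in_dir, load_image

paths = list_images_in_dir('data/drone/masks_full')   # keeps only IMAGE_FILE_TYPES entries, sorted by name
imgs = [load_image(p) for p in paths]                  # dng -> rawpy, tif/tiff -> tifffile, everything else -> PIL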
60
-
61
-
62
  class ImageFolderDataset(Dataset):
63
  """Creates a dataset of images in img_dir and corresponding masks in mask_dir.
64
  Corresponding mask files need to contain the filename of the image.
@@ -166,18 +148,20 @@ class ImageFolderDatasetSegmentation(Dataset):
166
 
167
  return img, mask
168
 
 
169
  class MultiIntensity(Dataset):
170
  """Wrap datasets with different intesities
171
 
172
  Args:
173
  datasets (list): list of datasets to wrap
174
  """
 
175
  def __init__(self, datasets):
176
  self.dataset = datasets[0]
177
 
178
- for d in range(1,len(datasets)):
179
- self.dataset.images = self.dataset.images+datasets[d].images
180
- self.dataset.labels = self.dataset.labels+datasets[d].labels
181
 
182
  def __len__(self):
183
  return len(self.dataset)
@@ -191,6 +175,7 @@ class MultiIntensity(Dataset):
191
  x = self.transform(x)
192
  return x, y
193
 
 
194
  class Subset(Dataset):
195
  """Define a subset of a dataset by only selecting given indices.
196
 
@@ -228,8 +213,8 @@ class DroneDatasetSegmentationFull(ImageFolderDatasetSegmentation):
228
  camera_parameters = black_level, white_balance, colour_matrix
229
 
230
  def __init__(self, I_ratio=1.0, transform=None, force_download=False, bits=16):
231
-
232
- assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
233
 
234
  img_dir = f'data/drone/images_full/raw_scale{int(I_ratio*100):03d}'
235
  mask_dir = 'data/drone/masks_full'
@@ -411,7 +396,7 @@ def download_microscopy_dataset(force_download):
411
 
412
 
413
  def unzip_microscopy_images():
414
-
415
  if os.path.isfile('data/microscopy/labels/.bzEmpty'):
416
  os.remove('data/microscopy/labels/.bzEmpty')
417
 
@@ -421,6 +406,7 @@ def unzip_microscopy_images():
421
  zip.extractall('data/microscopy/images')
422
  os.remove(os.path.join('data/microscopy/images', file))
423
 
 
424
  def unzip_drone_images():
425
 
426
  if os.path.isfile('data/drone/masks_full/.bzEmpty'):
@@ -585,38 +571,3 @@ def check_image_folder_consistency(images, masks):
585
  f"image file {img_file} file type mismatch. Shoule be: {file_type_images}"
586
  assert mask_file.split('.')[-1].lower() == file_type_masks, \
587
  f"image file {mask_file} file type mismatch. Should be: {file_type_masks}"
588
-
589
-
590
- def k_fold(dataset, n_splits: int, seed: int, train_size: float):
591
- """Split dataset in subsets for cross-validation
592
-
593
- Args:
594
- dataset (class): dataset to split
595
 - n_splits (int): Number of re-shuffling & splitting iterations.
596
- seed (int): seed for k_fold splitting
597
- train_size (float): should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the train split.
598
- Returns:
599
 - idxs (list): indices for splitting the dataset. The list contains n_splits pairs of train/test indices.
600
- """
601
- if hasattr(dataset, 'labels'):
602
- x = dataset.images
603
- y = dataset.labels
604
- elif hasattr(dataset, 'masks'):
605
- x = dataset.images
606
- y = dataset.masks
607
-
608
- idxs = []
609
-
610
- if dataset.task == 'classification':
611
- sss = StratifiedShuffleSplit(n_splits=n_splits, train_size=train_size, random_state=seed)
612
-
613
- for idxs_train, idxs_test in sss.split(x, y):
614
- idxs.append((idxs_train.tolist(), idxs_test.tolist()))
615
-
616
- elif dataset.task == 'segmentation':
617
- for n in range(n_splits):
618
- split_idx = int(len(dataset) * train_size)
619
- indices = np.random.permutation(len(dataset))
620
- idxs.append((indices[:split_idx].tolist(), indices[split_idx:].tolist()))
621
-
622
- return idxs
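For reference, a small self-contained sketch (not part of the commit) of what the classification branch of k_fold above computes, written directly with sklearn's StratifiedShuffleSplit; the toy arrays x and y stand in for dataset.images and dataset.labels.

import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

x = np.arange(20)                          # stand-in for dataset.images
y = np.array([k % 2 for k in range(20)])   # stand-in for dataset.labels (2 balanced classes)

sss = StratifiedShuffleSplit(n_splits=3, train_size=0.8, random_state=42)
idxs = [(train.tolist(), test.tolist()) for train, test in sss.split(x, y)]

for n, (train_idx, test_idx) in enumerate(idxs):
    print(f"split {n}: {len(train_idx)} train / {len(test_idx)} test")   # 16 train / 4 test per split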
 
15
  if not os.path.exists('README.md'): # set pwd to root
16
  os.chdir('..')
17
 
18
+ from utils.dataset_utils import split_img, list_images_in_dir, load_image
19
  from utils.base import np2torch, torch2np, b2_download_folder
20
 
21
  IMAGE_FILE_TYPES = ['dng', 'png', 'tif', 'tiff']
 
41
  raise ValueError(name)
42
 
43
 
44
  class ImageFolderDataset(Dataset):
45
  """Creates a dataset of images in img_dir and corresponding masks in mask_dir.
46
  Corresponding mask files need to contain the filename of the image.
 
148
 
149
  return img, mask
150
 
151
+
152
  class MultiIntensity(Dataset):
153
  """Wrap datasets with different intesities
154
 
155
  Args:
156
  datasets (list): list of datasets to wrap
157
  """
158
+
159
  def __init__(self, datasets):
160
  self.dataset = datasets[0]
161
 
162
+ for d in range(1, len(datasets)):
163
+ self.dataset.images = self.dataset.images + datasets[d].images
164
+ self.dataset.labels = self.dataset.labels + datasets[d].labels
165
 
166
  def __len__(self):
167
  return len(self.dataset)
 
175
  x = self.transform(x)
176
  return x, y
177
 
178
+
179
  class Subset(Dataset):
180
  """Define a subset of a dataset by only selecting given indices.
181
 
 
213
  camera_parameters = black_level, white_balance, colour_matrix
214
 
215
  def __init__(self, I_ratio=1.0, transform=None, force_download=False, bits=16):
216
+
217
+ assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
218
 
219
  img_dir = f'data/drone/images_full/raw_scale{int(I_ratio*100):03d}'
220
  mask_dir = 'data/drone/masks_full'
 
396
 
397
 
398
  def unzip_microscopy_images():
399
+
400
  if os.path.isfile('data/microscopy/labels/.bzEmpty'):
401
  os.remove('data/microscopy/labels/.bzEmpty')
402
 
 
406
  zip.extractall('data/microscopy/images')
407
  os.remove(os.path.join('data/microscopy/images', file))
408
 
409
+
410
  def unzip_drone_images():
411
 
412
  if os.path.isfile('data/drone/masks_full/.bzEmpty'):
 
571
  f"image file {img_file} file type mismatch. Shoule be: {file_type_images}"
572
  assert mask_file.split('.')[-1].lower() == file_type_masks, \
573
  f"image file {mask_file} file type mismatch. Should be: {file_type_masks}"
 
ABtesting.py → figures/ABtesting.py RENAMED
@@ -8,11 +8,11 @@ from torch.utils.data import DataLoader
8
  from torchvision.transforms import Compose, Normalize
9
  import torch.nn.functional as F
10
 
11
- from utils.dataset import get_dataset, Subset
12
  from utils.base import get_mlflow_model_by_name, SmartFormatter
13
- from processingpipeline.pipeline import RawProcessingPipeline
14
 
15
- from utils.Cperturb import Distortions
16
 
17
  import segmentation_models_pytorch as smp
18
 
@@ -20,94 +20,104 @@ import matplotlib.pyplot as plt
20
 
21
  parser = argparse.ArgumentParser(description="AB testing, Show Results", formatter_class=SmartFormatter)
22
 
23
- #Select experiment
24
- parser.add_argument("--mode", type=str, default="ABShowImages", choices=('ABMakeTable', 'ABShowTable', 'ABShowImages', 'ABShowAllImages', 'CMakeTable', 'CShowTable', 'CShowImages', 'CShowAllImages'),
25
  help='R|Choose operation to compute. \n'
26
- 'A) Lens2Logit image generation: \n '
27
- 'ABMakeTable: Compute cross-validation metrics results \n '
28
- 'ABShowTable: Plot cross-validation results on a table \n '
29
- 'ABShowImages: Choose a training and testing image to compare different pipelines \n '
30
- 'ABShowAllImages: Plot all possible pipelines \n'
31
- 'B) Hendrycks Perturbations, C-type dataset: \n '
32
- 'CMakeTable: For each pipeline, it computes cross-validation metrics for different perturbations \n '
33
- 'CShowTable: Plot metrics for different pipelines and perturbations \n '
34
 - 'CShowImages: Plot an image with a selected pipeline and perturbation\n '
35
- 'CShowAllImages: Plot all possible perturbations for a fixed pipeline' )
36
-
37
- parser.add_argument("--dataset_name", type=str, default='Microscopy', choices=['Microscopy', 'Drone', 'DroneSegmentation'], help='Choose dataset')
38
- parser.add_argument("--augmentation", type=str, default='weak', choices=['none','weak','strong'], help='Choose augmentation')
 
 
39
  parser.add_argument("--N_runs", type=int, default=5, help='Number of k-fold splitting used in the training')
40
  parser.add_argument("--download_model", default=False, action='store_true', help='Download Models in cache')
41
 
42
- #Select pipelines
43
- parser.add_argument("--dm_train", type=str, default='bilinear', choices= ('bilinear', 'malvar2004', 'menon2007'), help='Choose demosaicing for training processing model')
44
- parser.add_argument("--s_train", type=str, default='sharpening_filter', choices= ('sharpening_filter', 'unsharp_masking'), help='Choose sharpening for training processing model')
45
- parser.add_argument("--dn_train", type=str, default='gaussian_denoising', choices= ('gaussian_denoising', 'median_denoising'), help='Choose denoising for training processing model')
46
- parser.add_argument("--dm_test", type=str, default='bilinear', choices= ('bilinear', 'malvar2004', 'menon2007'), help='Choose demosaicing for testing processing model')
47
- parser.add_argument("--s_test", type=str, default='sharpening_filter', choices= ('sharpening_filter', 'unsharp_masking'), help='Choose sharpening for testing processing model')
48
- parser.add_argument("--dn_test", type=str, default='gaussian_denoising', choices= ('gaussian_denoising', 'median_denoising'), help='Choose denoising for testing processing model')
49
-
50
- #Select Ctest parameters
51
- parser.add_argument("--transform", type=str, default='identity', choices= ('identity','gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
52
- 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform'), help='Choose transformation to show for Ctesting')
53
- parser.add_argument("--severity", type=int, default=1, choices= (1,2,3,4,5), help='Choose severity for Ctesting')
 
 
 
 
 
 
54
 
55
  args = parser.parse_args()
56
 
 
57
  class metrics:
58
  def __init__(self, confusion_matrix):
59
  self.cm = confusion_matrix
60
  self.N_classes = len(confusion_matrix)
61
 
62
  def accuracy(self):
63
- Tp = torch.diagonal(self.cm,0).sum()
64
  N_elements = torch.sum(self.cm)
65
- return Tp/N_elements
66
-
67
  def precision(self):
68
  Tp_Fp = torch.sum(self.cm, 1)
69
  Tp_Fp[Tp_Fp == 0] = 1
70
- return torch.diagonal(self.cm,0) / Tp_Fp
71
 
72
  def recall(self):
73
  Tp_Fn = torch.sum(self.cm, 0)
74
  Tp_Fn[Tp_Fn == 0] = 1
75
- return torch.diagonal(self.cm,0) / Tp_Fn
76
 
77
  def f1_score(self):
78
- prod = (self.precision()*self.recall())
79
  sum = (self.precision() + self.recall())
80
  sum[sum == 0.] = 1.
81
- return 2*( prod / sum )
82
 
83
  def over_N_runs(ms, N_runs):
84
- m, m2 = 0, 0
85
 
86
  for i in ms:
87
- m += i
88
- mu = m/N_runs
89
-
90
  for i in ms:
91
- m2 += (i-mu)**2
 
 
92
 
93
- sigma = torch.sqrt( m2 / (N_runs-1) )
94
-
95
  return mu.tolist(), sigma.tolist()
96
 
 
97
  class ABtesting:
98
- def __init__(self,
99
- dataset_name: str,
100
- augmentation: str,
101
- dm_train: str,
102
- s_train: str,
103
- dn_train: str,
104
- dm_test: str,
105
- s_test: str,
106
- dn_test: str,
107
- N_runs: int,
108
- severity=1,
109
- transform='identity',
110
- download_model=False):
111
  self.experiment_name = 'ABtesting'
112
  self.dataset_name = dataset_name
113
  self.augmentation = augmentation
@@ -129,12 +139,12 @@ class ABtesting:
129
  if sharpening == None:
130
  sharpening = self.s_test
131
  if denoising == None:
132
- denoising = self.dn_test
133
  if severity == None:
134
- severity = self.severity
135
  if transform == None:
136
  transform = self.transform
137
-
138
  dataset = get_dataset(self.dataset_name)
139
 
140
  if self.dataset_name == "Drone" or self.dataset_name == "DroneSegmentation":
@@ -146,40 +156,41 @@ class ABtesting:
146
 
147
  if not plot_mode:
148
  dataset.transform = Compose([RawProcessingPipeline(
149
- camera_parameters=dataset.camera_parameters,
150
- debayer=debayer,
151
- sharpening=sharpening,
152
- denoising=denoising,
153
- ), Distortions(severity=severity, transform=transform),
154
- Normalize(mean, std)])
155
  else:
156
- dataset.transform = Compose([RawProcessingPipeline(
157
- camera_parameters=dataset.camera_parameters,
158
- debayer=debayer,
159
- sharpening=sharpening,
160
- denoising=denoising,
161
- ), Distortions(severity=severity, transform=transform)])
162
 
163
  return dataset
164
 
165
  def ABclassification(self):
166
-
167
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
168
 
169
  parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
170
 
171
- print(f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
172
- print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')
 
173
 
174
- accuracies, precisions, recalls, f1_scores = [],[],[],[]
175
 
176
  os.system('rm -r /tmp/py*')
177
 
178
  for N_run in range(self.N_runs):
179
 
180
  print(f"Evaluating Run {N_run}")
181
-
182
- run_name = parent_run_name+'_'+str(N_run)
183
 
184
  state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
185
  download_model=self.download_model)
@@ -190,18 +201,18 @@ class ABtesting:
190
 
191
  model.eval()
192
 
193
- len_classes = len(dataset.classes)
194
  confusion_matrix = torch.zeros((len_classes, len_classes))
195
 
196
  for img, label in valid_loader:
197
-
198
  prediction = model(img.to(DEVICE)).detach().cpu()
199
- prediction = torch.argmax(prediction, dim=1)
200
- confusion_matrix[label,prediction] += 1 # Real value rows, Declared columns
201
 
202
  m = metrics(confusion_matrix)
203
 
204
- accuracies.append(m.accuracy())
205
  precisions.append(m.precision())
206
  recalls.append(m.recall())
207
  f1_scores.append(m.f1_score())
@@ -213,15 +224,16 @@ class ABtesting:
213
  recall = metrics.over_N_runs(recalls, self.N_runs)
214
  f1_score = metrics.over_N_runs(f1_scores, self.N_runs)
215
  return dataset.classes, accuracy, precision, recall, f1_score
216
-
217
  def ABsegmentation(self):
218
-
219
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
220
 
221
  parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
222
 
223
- print(f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
224
- print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')
 
225
 
226
  IoUs = []
227
 
@@ -230,34 +242,34 @@ class ABtesting:
230
  for N_run in range(self.N_runs):
231
 
232
  print(f"Evaluating Run {N_run}")
233
-
234
- run_name = parent_run_name+'_'+str(N_run)
235
 
236
- state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
237
- download_model=self.download_model)
 
 
238
 
239
  dataset = self.static_pip_val()
240
-
241
  valid_set = Subset(dataset, indices=state_dict['valid_indices'])
242
  valid_loader = DataLoader(valid_set, batch_size=1, num_workers=16, shuffle=False)
243
 
244
  model.eval()
245
 
246
- IoU=0
247
 
248
  for img, label in valid_loader:
249
-
250
  prediction = model(img.to(DEVICE)).detach().cpu()
251
  prediction = F.logsigmoid(prediction).exp().squeeze()
252
- IoU += smp.utils.metrics.IoU()(prediction,label)
253
 
254
- IoU = IoU/len(valid_loader)
255
  IoUs.append(IoU.item())
256
 
257
  os.system('rm -r /tmp/t*')
258
 
259
  IoU = metrics.over_N_runs(torch.tensor(IoUs), self.N_runs)
260
- return IoU
261
 
262
  def ABShowImages(self):
263
 
@@ -265,164 +277,169 @@ class ABtesting:
265
  if not os.path.exists(path):
266
  os.makedirs(path)
267
 
268
- path = os.path.join(path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.dm_test[:2]}{self.s_test[0]}{self.dn_test[:2]}')
 
269
 
270
  if not os.path.exists(path):
271
  os.makedirs(path)
272
 
273
- run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"+'_'+str(0)
 
274
 
275
  state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=self.download_model)
276
 
277
  model.augmentation = None
278
-
279
- for t in ([self.dm_train, self.s_train, self.dn_train, 'train_img'],
280
- [self.dm_test, self.s_test, self.dn_test, 'test_img']):
281
-
282
  debayer, sharpening, denoising, img_type = t[0], t[1], t[2], t[3]
283
 
284
  dataset = self.static_pip_val(debayer=debayer, sharpening=sharpening, denoising=denoising, plot_mode=True)
285
  valid_set = Subset(dataset, indices=state_dict['valid_indices'])
286
-
287
  img, _ = next(iter(valid_set))
288
 
289
  plt.figure()
290
- plt.imshow(img.permute(1,2,0))
291
  if img_type == 'train_img':
292
  plt.title('Train Image')
293
  plt.savefig(os.path.join(path, f'img_train.png'))
294
  imgA = img
295
  else:
296
  plt.title('Test Image')
297
- plt.savefig(os.path.join(path,f'img_test.png'))
298
-
299
- for c, color in enumerate(['Red','Green','Blue']):
300
- diff = torch.abs(imgA-img)
301
  plt.figure()
302
  # plt.imshow(diff.permute(1,2,0))
303
- plt.imshow(diff[c,50:200,50:200], cmap=f'{color}s')
304
  plt.title(f'|Train Image - Test Image| - {color}')
305
  plt.colorbar()
306
  plt.savefig(os.path.join(path, f'diff_{color}.png'))
307
  plt.figure()
308
- diff[diff == 0.]= 1e-5
309
  # plt.imshow(torch.log(diff.permute(1,2,0)))
310
  plt.imshow(torch.log(diff)[c])
311
  plt.title(f'log(|Train Image - Test Image|) - {color}')
312
  plt.colorbar()
313
  plt.savefig(os.path.join(path, f'logdiff_{color}.png'))
314
-
315
  if self.dataset_name == 'DroneSegmentation':
316
  plt.figure()
317
  plt.imshow(model(img[None].cuda()).detach().cpu().squeeze())
318
  if img_type == 'train_img':
319
  plt.savefig(os.path.join(path, f'mask_train.png'))
320
  else:
321
- plt.savefig(os.path.join(path,f'mask_test.png'))
322
 
323
  def ABShowAllImages(self):
324
  if not os.path.exists('results/ABtesting'):
325
  os.makedirs('results/ABtesting')
326
 
327
- demosaicings=['bilinear','malvar2004', 'menon2007']
328
- sharpenings=['sharpening_filter', 'unsharp_masking']
329
- denoisings=['median_denoising', 'gaussian_denoising']
330
 
331
  fig = plt.figure()
332
- columns=4
333
- rows=3
334
 
335
- i=1
336
 
337
  for dm in demosaicings:
338
  for s in sharpenings:
339
  for dn in denoisings:
340
-
341
- dataset = self.static_pip_val(self.dm_test, self.s_test,
342
- self.dn_test, plot_mode=True)
343
 
344
- img,_ = dataset[0]
345
-
 
 
 
346
  fig.add_subplot(rows, columns, i)
347
- plt.imshow(img.permute(1,2,0))
348
  plt.title(f'{dm}\n{s}\n{dn}', fontsize=8)
349
  plt.xticks([])
350
  plt.yticks([])
351
  plt.tight_layout()
352
 
353
- i+=1
354
 
355
  plt.show()
356
  plt.savefig(f'results/ABtesting/ABpipelines.png')
357
 
358
  def CShowImages(self):
359
-
360
  path = 'results/Ctesting/imgs/'
361
  if not os.path.exists(path):
362
  os.makedirs(path)
363
 
364
- run_name = f"{self.dataset_name}_{self.dm_test}_{self.s_test}_{self.dn_test}_{self.augmentation}"+'_'+str(0)
365
 
366
  state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=True)
367
 
368
  model.augmentation = None
369
-
370
- dataset = self.static_pip_val(self.dm_test, self.s_test, self.dn_test, self.severity, self.transform, plot_mode=True)
 
371
  valid_set = Subset(dataset, indices=state_dict['valid_indices'])
372
-
373
  img, _ = next(iter(valid_set))
374
 
375
  plt.figure()
376
- plt.imshow(img.permute(1,2,0))
377
- plt.savefig(os.path.join(path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.transform}_sev{self.severity}'))
378
-
 
379
  def CShowAllImages(self):
380
  if not os.path.exists('results/Cimages'):
381
  os.makedirs('results/Cimages')
382
 
383
- transforms = ['identity','gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
384
- 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
385
 
386
- for i,t in enumerate(transforms):
387
-
388
- fig = plt.figure(figsize=(10,6))
389
  columns = 5
390
  rows = 1
391
 
392
- for sev in range(1,6):
393
 
394
  dataset = self.static_pip_val(severity=sev, transform=t, plot_mode=True)
395
 
396
- img,_ = dataset[0]
397
-
398
  fig.add_subplot(rows, columns, sev)
399
- plt.imshow(img.permute(1,2,0))
400
  plt.title(f'Severity: {sev}')
401
  plt.xticks([])
402
  plt.yticks([])
403
  plt.tight_layout()
404
-
405
  if '_' in t:
406
- t=t.replace('_', ' ')
407
- t=t[0].upper()+t[1:]
408
 
409
  fig.suptitle(f'{t}', x=0.5, y=0.8, fontsize=24)
410
  plt.show()
411
  plt.savefig(f'results/Cimages/{i+1}_{t.lower()}.png')
412
 
413
- def ABMakeTable(dataset_name:str, augmentation: str,
414
- N_runs: int, download_model: bool):
415
 
416
- demosaicings=['bilinear','malvar2004', 'menon2007']
417
- sharpenings=['sharpening_filter', 'unsharp_masking']
418
- denoisings=['median_denoising', 'gaussian_denoising']
 
 
 
419
 
420
- path='results/ABtesting/tables'
421
  if not os.path.exists(path):
422
  os.makedirs(path)
423
 
424
- runs={}
425
- i=0
426
 
427
  for dm_train in demosaicings:
428
  for s_train in sharpenings:
@@ -431,28 +448,28 @@ def ABMakeTable(dataset_name:str, augmentation: str,
431
  for s_test in sharpenings:
432
  for dn_test in denoisings:
433
  train_pip = [dm_train, s_train, dn_train]
434
- test_pip = [dm_test, s_test, dn_test]
435
  runs[f'run{i}'] = {
436
- 'dataset': dataset_name,
437
- 'augmentation': augmentation,
438
- 'train_pip': train_pip,
439
- 'test_pip': test_pip,
440
- 'N_runs': N_runs
441
  }
442
  ABclass = ABtesting(
443
- dataset_name=dataset_name,
444
- augmentation=augmentation,
445
- dm_train = dm_train,
446
- s_train = s_train,
447
- dn_train = dn_train,
448
- dm_test = dm_test,
449
- s_test = s_test,
450
- dn_test = dn_test,
451
- N_runs=N_runs,
452
- download_model=download_model
453
- )
454
-
455
- if dataset_name == 'DroneSegmentation':
456
  IoU = ABclass.ABsegmentation()
457
  runs[f'run{i}']['IoU'] = IoU
458
  else:
@@ -462,15 +479,16 @@ def ABMakeTable(dataset_name:str, augmentation: str,
462
  runs[f'run{i}']['precision'] = precision
463
  runs[f'run{i}']['recall'] = recall
464
  runs[f'run{i}']['f1_score'] = f1_score
465
-
466
- with open(os.path.join(path,f'{dataset_name}_{augmentation}_runs.txt'), 'w') as outfile:
467
  json.dump(runs, outfile)
468
 
469
- i+=1
 
470
 
471
  def ABShowTable(dataset_name: str, augmentation: str):
472
-
473
- path='results/ABtesting/tables'
474
  assert os.path.exists(path), 'No tables to plot'
475
 
476
  json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt')
@@ -478,14 +496,14 @@ def ABShowTable(dataset_name: str, augmentation: str):
478
  with open(json_file, 'r') as run_file:
479
  runs = json.load(run_file)
480
 
481
- metrics=torch.zeros((2,12,12))
482
- classes=[]
483
 
484
- i,j=0,0
485
 
486
  for r in range(len(runs)):
487
-
488
- run = runs['run'+str(r)]
489
  if dataset_name == 'DroneSegmentation':
490
  acc = run['IoU']
491
  else:
@@ -494,64 +512,66 @@ def ABShowTable(dataset_name: str, augmentation: str):
494
  class_list = run['test_pip']
495
  class_name = f'{class_list[0][:2]},{class_list[1][:1]},{class_list[2][:2]}'
496
  classes.append(class_name)
497
- mu,sigma = round(acc[0],4),round(acc[1],4)
 
 
 
498
 
499
- metrics[0,j,i] = mu
500
- metrics[1,j,i] = sigma
501
-
502
- i+=1
503
 
504
  if i == 12:
505
- i=0
506
- j+=1
507
 
508
  differences = torch.zeros_like(metrics)
509
 
510
- diag_mu = torch.diagonal(metrics[0],0)
511
- diag_sigma = torch.diagonal(metrics[1],0)
512
-
513
  for r in range(len(metrics[0])):
514
- differences[0,r] = diag_mu[r] - metrics[0,r]
515
- differences[1,r] = torch.sqrt(metrics[1,r]**2 + diag_sigma[r]**2)
516
 
517
  # Plot with scatter
518
-
519
- for i,img in enumerate([metrics, differences]):
520
 
521
  x, y = torch.arange(12), torch.arange(12)
522
  x, y = torch.meshgrid(x, y)
523
 
524
  if i == 0:
525
- vmin = max(0.65, round(img[0].min().item(),2))
526
- vmax = round(img[0].max().item(),2)
527
  step = 0.02
528
  elif i == 1:
529
- vmin = round(img[0].min().item(),2)
530
  if augmentation == 'none':
531
- vmax = min(0.15, round(img[0].max().item(),2))
532
  if augmentation == 'weak':
533
- vmax = min(0.08, round(img[0].max().item(),2))
534
  if augmentation == 'strong':
535
- vmax = min(0.05, round(img[0].max().item(),2))
536
  step = 0.01
537
-
538
- vmin = int(vmin/step)*step
539
- vmax = int(vmax/step)*step
540
 
541
- fig = plt.figure(figsize=(10,6.2))
 
 
 
542
  ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
543
- marker_size=350
544
- plt.scatter(x, y, c=torch.rot90(img[1][x,y],-1,[0,1]), vmin = 0., vmax = img[1].max(), cmap='viridis', s=marker_size*2, marker='s')
545
- ticks = torch.arange(0.,img[1].max(),0.03).tolist()
546
- ticks = [round(tick,2) for tick in ticks]
 
547
  cba = plt.colorbar(pad=0.06)
548
  cba.set_ticks(ticks)
549
  cba.ax.set_yticklabels(ticks)
550
  # cmap = plt.cm.get_cmap('tab20c').reversed()
551
  cmap = plt.cm.get_cmap('Reds')
552
- plt.scatter(x,y, c=torch.rot90(img[0][x,y],-1,[0,1]), vmin = vmin, vmax = vmax, cmap=cmap, s=marker_size, marker='s')
 
553
  ticks = torch.arange(vmin, vmax, step).tolist()
554
- ticks = [round(tick,2) for tick in ticks]
555
  if ticks[-1] != vmax:
556
  ticks.append(vmax)
557
  cbb = plt.colorbar(pad=0.06)
@@ -563,26 +583,26 @@ def ABShowTable(dataset_name: str, augmentation: str):
563
  cbb.ax.set_yticklabels(ticks)
564
  for x in range(12):
565
  for y in range(12):
566
- txt = round(torch.rot90(img[0],-1,[0,1])[x,y].item(),2)
567
  if str(txt) == '-0.0':
568
  txt = '0.00'
569
  elif str(txt) == '0.0':
570
  txt = '0.00'
571
  elif len(str(txt)) == 3:
572
- txt = str(txt)+'0'
573
  else:
574
  txt = str(txt)
575
-
576
- plt.text(x-0.25,y-0.1,txt, color='black', fontsize='x-small')
577
 
578
- ax.set_xticks(torch.linspace(0,11,12))
 
 
579
  ax.set_xticklabels(classes)
580
- ax.set_yticks(torch.linspace(0,11,12))
581
  classes.reverse()
582
  ax.set_yticklabels(classes)
583
  classes.reverse()
584
- plt.xticks(rotation = 45)
585
- plt.yticks(rotation = 45)
586
  cba.set_label('Standard Deviation')
587
  plt.xlabel("Test pipelines")
588
  plt.ylabel("Train pipelines")
@@ -590,62 +610,63 @@ def ABShowTable(dataset_name: str, augmentation: str):
590
  if i == 0:
591
  if dataset_name == 'DroneSegmentation':
592
  cbb.set_label('IoU')
593
- plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_IoU.png"))
594
  else:
595
  cbb.set_label('Accuracy')
596
- plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_accuracies.png"))
597
  elif i == 1:
598
  if dataset_name == 'DroneSegmentation':
599
  cbb.set_label('IoU_d-IoU')
600
  else:
601
  cbb.set_label('Accuracy_d - Accuracy')
602
- plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_differences.png"))
 
603
 
604
  def CMakeTable(dataset_name: str, augmentation: str, severity: int, N_runs: int, download_model: bool):
605
-
606
- path='results/Ctesting/tables'
607
  if not os.path.exists(path):
608
  os.makedirs(path)
609
-
610
- demosaicings=['bilinear','malvar2004', 'menon2007']
611
- sharpenings=['sharpening_filter', 'unsharp_masking']
612
- denoisings=['median_denoising', 'gaussian_denoising']
613
 
614
- transformations = ['identity','gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
615
- 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
 
 
 
 
616
 
617
- runs={}
618
- i=0
619
 
620
  for dm in demosaicings:
621
  for s in sharpenings:
622
  for dn in denoisings:
623
  for t in transformations:
624
- pip = [dm,s,dn]
625
  runs[f'run{i}'] = {
626
- 'dataset': dataset_name,
627
- 'augmentation': augmentation,
628
- 'pipeline': pip,
629
- 'N_runs': N_runs,
630
- 'transform': t,
631
- 'severity': severity,
632
  }
633
  ABclass = ABtesting(
634
- dataset_name=dataset_name,
635
- augmentation=augmentation,
636
- dm_train = dm,
637
- s_train = s,
638
- dn_train = dn,
639
- dm_test = dm,
640
- s_test = s,
641
- dn_test = dn,
642
- severity=severity,
643
- transform=t,
644
- N_runs=N_runs,
645
- download_model=download_model
646
- )
647
-
648
- if dataset_name == 'DroneSegmentation':
649
  IoU = ABclass.ABsegmentation()
650
  runs[f'run{i}']['IoU'] = IoU
651
  else:
@@ -656,26 +677,27 @@ def CMakeTable(dataset_name: str, augmentation: str, severity: int, N_runs: int,
656
  runs[f'run{i}']['recall'] = recall
657
  runs[f'run{i}']['f1_score'] = f1_score
658
 
659
- with open(os.path.join(path,f'{dataset_name}_{augmentation}_runs.json'), 'w') as outfile:
660
  json.dump(runs, outfile)
661
 
662
- i+=1
 
663
 
664
  def CShowTable(dataset_name, augmentation):
665
 
666
- path='results/Ctesting/tables'
667
  assert os.path.exists(path), 'No tables to plot'
668
 
669
  json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt')
670
 
671
- transforms = ['identity','gauss_noise', 'shot', 'impulse', 'speckle',
672
- 'gauss_blur', 'zoom', 'contrast', 'brightness', 'saturate', 'elastic']
673
 
674
  pip = []
675
-
676
- demosaicings=['bilinear','malvar2004', 'menon2007']
677
- sharpenings=['sharpening_filter', 'unsharp_masking']
678
- denoisings=['median_denoising', 'gaussian_denoising']
679
 
680
  for dm in demosaicings:
681
  for s in sharpenings:
@@ -685,52 +707,54 @@ def CShowTable(dataset_name, augmentation):
685
  with open(json_file, 'r') as run_file:
686
  runs = json.load(run_file)
687
 
688
- metrics=torch.zeros((2,len(pip),len(transforms)))
689
 
690
- i,j=0,0
691
 
692
  for r in range(len(runs)):
693
-
694
- run = runs['run'+str(r)]
695
  if dataset_name == 'DroneSegmentation':
696
  acc = run['IoU']
697
  else:
698
  acc = run['accuracy']
699
- mu,sigma = round(acc[0],4),round(acc[1],4)
700
 
701
- metrics[0,j,i] = mu
702
- metrics[1,j,i] = sigma
703
-
704
- i+=1
705
 
706
  if i == len(transforms):
707
- i=0
708
- j+=1
709
 
710
  # Plot with scatter
711
 
712
  img = metrics
713
 
714
- vmin=0.
715
- vmax=1.
716
-
717
  x, y = torch.arange(12), torch.arange(11)
718
  x, y = torch.meshgrid(x, y)
719
 
720
- fig = plt.figure(figsize=(10,6.2))
721
  ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
722
- marker_size=350
723
- plt.scatter(x, y, c=torch.rot90(img[1][x,y],-1,[0,1]), vmin = 0., vmax = img[1].max(), cmap='viridis', s=marker_size*2, marker='s')
724
- ticks = torch.arange(0.,img[1].max(),0.03).tolist()
725
- ticks = [round(tick,2) for tick in ticks]
 
726
  cba = plt.colorbar(pad=0.06)
727
  cba.set_ticks(ticks)
728
  cba.ax.set_yticklabels(ticks)
729
  # cmap = plt.cm.get_cmap('tab20c').reversed()
730
  cmap = plt.cm.get_cmap('Reds')
731
- plt.scatter(x,y, c=torch.rot90(img[0][x,y],-1,[0,1]), vmin=vmin, vmax=vmax, cmap=cmap, s=marker_size, marker='s')
 
732
  ticks = torch.arange(vmin, vmax, step).tolist()
733
- ticks = [round(tick,2) for tick in ticks]
734
  if ticks[-1] != vmax:
735
  ticks.append(vmax)
736
  cbb = plt.colorbar(pad=0.06)
@@ -742,65 +766,66 @@ def CShowTable(dataset_name, augmentation):
742
  cbb.ax.set_yticklabels(ticks)
743
  for x in range(12):
744
  for y in range(12):
745
- txt = round(torch.rot90(img[0],-1,[0,1])[x,y].item(),2)
746
  if str(txt) == '-0.0':
747
  txt = '0.00'
748
  elif str(txt) == '0.0':
749
  txt = '0.00'
750
  elif len(str(txt)) == 3:
751
- txt = str(txt)+'0'
752
  else:
753
  txt = str(txt)
754
-
755
- plt.text(x-0.25,y-0.1,txt, color='black', fontsize='x-small')
756
 
757
- ax.set_xticks(torch.linspace(0,11,12))
 
 
758
  ax.set_xticklabels(transforms)
759
- ax.set_yticks(torch.linspace(0,11,12))
760
  pip.reverse()
761
  ax.set_yticklabels(pip)
762
  pip.reverse()
763
- plt.xticks(rotation = 45)
764
- plt.yticks(rotation = 45)
765
  cba.set_label('Standard Deviation')
766
  plt.xlabel("Pipelines")
767
  plt.ylabel("Distortions")
768
  if dataset_name == 'DroneSegmentation':
769
  cbb.set_label('IoU')
770
- plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_IoU.png"))
771
  else:
772
  cbb.set_label('Accuracy')
773
- plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_accuracies.png"))
 
774
 
775
  if __name__ == '__main__':
776
-
777
- if args.mode == 'ABMakeTable':
778
  ABMakeTable(args.dataset_name, args.augmentation, args.N_runs, args.download_model)
779
- elif args.mode == 'ABShowTable':
780
  ABShowTable(args.dataset_name, args.augmentation)
781
- elif args.mode == 'ABShowImages':
782
- ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
783
- args.s_train, args.dn_train, args.dm_test, args.s_test,
784
  args.dn_test, args.N_runs, download_model=args.download_model)
785
  ABclass.ABShowImages()
786
- elif args.mode == 'ABShowAllImages':
787
- ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
788
- args.s_train, args.dn_train, args.dm_test, args.s_test,
789
- args.dn_test, args.N_runs, download_model=args.download_model)
790
  ABclass.ABShowAllImages()
791
- elif args.mode == 'CMakeTable':
792
  CMakeTable(args.dataset_name, args.augmentation, args.severity, args.N_runs, args.download_model)
793
- elif args.mode == 'CShowTable': # TODO test it
794
  CShowTable(args.dataset_name, args.augmentation, args.severity)
795
- elif args.mode == 'CShowImages':
796
- ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
797
- args.s_train, args.dn_train, args.dm_test, args.s_test,
798
- args.dn_test, args.N_runs, args.severity, args.transform,
799
- download_model=args.download_model)
800
  ABclass.CShowImages()
801
- elif args.mode == 'CShowAllImages':
802
- ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
803
- args.s_train, args.dn_train, args.dm_test, args.s_test,
804
- args.dn_test, args.N_runs, args.severity, args.transform,
805
- download_model=args.download_model)
806
  ABclass.CShowAllImages()
 
8
  from torchvision.transforms import Compose, Normalize
9
  import torch.nn.functional as F
10
 
11
+ from dataset import get_dataset, Subset
12
  from utils.base import get_mlflow_model_by_name, SmartFormatter
13
+ from processing.pipeline_numpy import RawProcessingPipeline
14
 
15
+ from utils.hendrycks_robustness import Distortions
16
 
17
  import segmentation_models_pytorch as smp
18
 
 
20
 
21
  parser = argparse.ArgumentParser(description="AB testing, Show Results", formatter_class=SmartFormatter)
22
 
23
+ # Select experiment
24
+ parser.add_argument("--mode", type=str, default="ABShowImages", choices=('ABMakeTable', 'ABShowTable', 'ABShowImages', 'ABShowAllImages', 'CMakeTable', 'CShowTable', 'CShowImages', 'CShowAllImages'),
25
  help='R|Choose operation to compute. \n'
26
+ 'A) Lens2Logit image generation: \n '
27
+ 'ABMakeTable: Compute cross-validation metrics results \n '
28
+ 'ABShowTable: Plot cross-validation results on a table \n '
29
+ 'ABShowImages: Choose a training and testing image to compare different pipelines \n '
30
+ 'ABShowAllImages: Plot all possible pipelines \n'
31
+ 'B) Hendrycks Perturbations, C-type dataset: \n '
32
+ 'CMakeTable: For each pipeline, it computes cross-validation metrics for different perturbations \n '
33
+ 'CShowTable: Plot metrics for different pipelines and perturbations \n '
34
 + 'CShowImages: Plot an image with a selected pipeline and perturbation\n '
35
+ 'CShowAllImages: Plot all possible perturbations for a fixed pipeline')
36
+
37
+ parser.add_argument("--dataset_name", type=str, default='Microscopy',
38
+ choices=['Microscopy', 'Drone', 'DroneSegmentation'], help='Choose dataset')
39
+ parser.add_argument("--augmentation", type=str, default='weak',
40
+ choices=['none', 'weak', 'strong'], help='Choose augmentation')
41
  parser.add_argument("--N_runs", type=int, default=5, help='Number of k-fold splitting used in the training')
42
  parser.add_argument("--download_model", default=False, action='store_true', help='Download Models in cache')
43
 
44
+ # Select pipelines
45
+ parser.add_argument("--dm_train", type=str, default='bilinear', choices=('bilinear', 'malvar2004',
46
+ 'menon2007'), help='Choose demosaicing for training processing model')
47
+ parser.add_argument("--s_train", type=str, default='sharpening_filter', choices=('sharpening_filter',
48
+ 'unsharp_masking'), help='Choose sharpening for training processing model')
49
+ parser.add_argument("--dn_train", type=str, default='gaussian_denoising', choices=('gaussian_denoising',
50
+ 'median_denoising'), help='Choose denoising for training processing model')
51
+ parser.add_argument("--dm_test", type=str, default='bilinear', choices=('bilinear', 'malvar2004',
52
+ 'menon2007'), help='Choose demosaicing for testing processing model')
53
+ parser.add_argument("--s_test", type=str, default='sharpening_filter', choices=('sharpening_filter',
54
+ 'unsharp_masking'), help='Choose sharpening for testing processing model')
55
+ parser.add_argument("--dn_test", type=str, default='gaussian_denoising', choices=('gaussian_denoising',
56
+ 'median_denoising'), help='Choose denoising for testing processing model')
57
+
58
+ # Select Ctest parameters
59
+ parser.add_argument("--transform", type=str, default='identity', choices=('identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
60
+ 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform'), help='Choose transformation to show for Ctesting')
61
+ parser.add_argument("--severity", type=int, default=1, choices=(1, 2, 3, 4, 5), help='Choose severity for Ctesting')
62
 
63
  args = parser.parse_args()
64
 
65
+
66
  class metrics:
67
  def __init__(self, confusion_matrix):
68
  self.cm = confusion_matrix
69
  self.N_classes = len(confusion_matrix)
70
 
71
  def accuracy(self):
72
+ Tp = torch.diagonal(self.cm, 0).sum()
73
  N_elements = torch.sum(self.cm)
74
+ return Tp / N_elements
75
+
76
  def precision(self):
77
  Tp_Fp = torch.sum(self.cm, 1)
78
  Tp_Fp[Tp_Fp == 0] = 1
79
+ return torch.diagonal(self.cm, 0) / Tp_Fp
80
 
81
  def recall(self):
82
  Tp_Fn = torch.sum(self.cm, 0)
83
  Tp_Fn[Tp_Fn == 0] = 1
84
+ return torch.diagonal(self.cm, 0) / Tp_Fn
85
 
86
  def f1_score(self):
87
+ prod = (self.precision() * self.recall())
88
  sum = (self.precision() + self.recall())
89
  sum[sum == 0.] = 1.
90
+ return 2 * (prod / sum)
91
 
92
  def over_N_runs(ms, N_runs):
93
+ m, m2 = 0, 0
94
 
95
  for i in ms:
96
+ m += i
97
+ mu = m / N_runs
98
+
99
  for i in ms:
100
+ m2 += (i - mu)**2
101
+
102
+ sigma = torch.sqrt(m2 / (N_runs - 1))
103
 
 
 
104
  return mu.tolist(), sigma.tolist()
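For reference, a minimal self-contained sketch (not part of the commit) of the quantities the metrics class above computes, written out with plain torch on a toy 3-class confusion matrix; rows are true labels and columns are predictions, matching the "Real value rows, Declared columns" convention used in ABclassification. The numbers are illustrative.

import torch

cm = torch.tensor([[8., 1., 1.],
                   [0., 9., 1.],
                   [2., 0., 8.]])

tp = torch.diagonal(cm, 0)
accuracy = tp.sum() / cm.sum()        # trace / total = 25/30 ~ 0.833
precision = tp / torch.sum(cm, 1)     # diagonal over row sums (dim 1), as in metrics.precision
recall = tp / torch.sum(cm, 0)        # diagonal over column sums (dim 0), as in metrics.recall
f1 = 2 * (precision * recall) / (precision + recall)
print(accuracy, precision, recall, f1)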
105
 
106
+
107
  class ABtesting:
108
+ def __init__(self,
109
+ dataset_name: str,
110
+ augmentation: str,
111
+ dm_train: str,
112
+ s_train: str,
113
+ dn_train: str,
114
+ dm_test: str,
115
+ s_test: str,
116
+ dn_test: str,
117
+ N_runs: int,
118
+ severity=1,
119
+ transform='identity',
120
+ download_model=False):
121
  self.experiment_name = 'ABtesting'
122
  self.dataset_name = dataset_name
123
  self.augmentation = augmentation
 
139
  if sharpening == None:
140
  sharpening = self.s_test
141
  if denoising == None:
142
+ denoising = self.dn_test
143
  if severity == None:
144
+ severity = self.severity
145
  if transform == None:
146
  transform = self.transform
147
+
148
  dataset = get_dataset(self.dataset_name)
149
 
150
  if self.dataset_name == "Drone" or self.dataset_name == "DroneSegmentation":
 
156
 
157
  if not plot_mode:
158
  dataset.transform = Compose([RawProcessingPipeline(
159
+ camera_parameters=dataset.camera_parameters,
160
+ debayer=debayer,
161
+ sharpening=sharpening,
162
+ denoising=denoising,
163
+ ), Distortions(severity=severity, transform=transform),
164
+ Normalize(mean, std)])
165
  else:
166
+ dataset.transform = Compose([RawProcessingPipeline(
167
+ camera_parameters=dataset.camera_parameters,
168
+ debayer=debayer,
169
+ sharpening=sharpening,
170
+ denoising=denoising,
171
+ ), Distortions(severity=severity, transform=transform)])
172
 
173
  return dataset
174
 
175
  def ABclassification(self):
176
+
177
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
178
 
179
  parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
180
 
181
+ print(
182
+ f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
183
+ print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')
184
 
185
+ accuracies, precisions, recalls, f1_scores = [], [], [], []
186
 
187
  os.system('rm -r /tmp/py*')
188
 
189
  for N_run in range(self.N_runs):
190
 
191
  print(f"Evaluating Run {N_run}")
192
+
193
+ run_name = parent_run_name + '_' + str(N_run)
194
 
195
  state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
196
  download_model=self.download_model)
 
201
 
202
  model.eval()
203
 
204
+ len_classes = len(dataset.classes)
205
  confusion_matrix = torch.zeros((len_classes, len_classes))
206
 
207
  for img, label in valid_loader:
208
+
209
  prediction = model(img.to(DEVICE)).detach().cpu()
210
+ prediction = torch.argmax(prediction, dim=1)
211
+ confusion_matrix[label, prediction] += 1 # Real value rows, Declared columns
212
 
213
  m = metrics(confusion_matrix)
214
 
215
+ accuracies.append(m.accuracy())
216
  precisions.append(m.precision())
217
  recalls.append(m.recall())
218
  f1_scores.append(m.f1_score())
 
224
  recall = metrics.over_N_runs(recalls, self.N_runs)
225
  f1_score = metrics.over_N_runs(f1_scores, self.N_runs)
226
  return dataset.classes, accuracy, precision, recall, f1_score
227
+
228
  def ABsegmentation(self):
229
+
230
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
231
 
232
  parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
233
 
234
+ print(
235
+ f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
236
+ print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')
237
 
238
  IoUs = []
239
 
 
242
  for N_run in range(self.N_runs):
243
 
244
  print(f"Evaluating Run {N_run}")
 
 
245
 
246
+ run_name = parent_run_name + '_' + str(N_run)
247
+
248
+ state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
249
+ download_model=self.download_model)
250
 
251
  dataset = self.static_pip_val()
252
+
253
  valid_set = Subset(dataset, indices=state_dict['valid_indices'])
254
  valid_loader = DataLoader(valid_set, batch_size=1, num_workers=16, shuffle=False)
255
 
256
  model.eval()
257
 
258
+ IoU = 0
259
 
260
  for img, label in valid_loader:
261
+
262
  prediction = model(img.to(DEVICE)).detach().cpu()
263
  prediction = F.logsigmoid(prediction).exp().squeeze()
264
+ IoU += smp.utils.metrics.IoU()(prediction, label)
265
 
266
+ IoU = IoU / len(valid_loader)
267
  IoUs.append(IoU.item())
268
 
269
  os.system('rm -r /tmp/t*')
270
 
271
  IoU = metrics.over_N_runs(torch.tensor(IoUs), self.N_runs)
272
+ return IoU
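As a side note, a tiny self-contained sketch (not part of the commit) of the intersection-over-union quantity accumulated above; it illustrates the metric itself rather than smp.utils.metrics.IoU's exact implementation, and the toy masks are illustrative.

import torch

prediction = torch.tensor([[0.9, 0.1], [0.8, 0.2]])   # already passed through a sigmoid, as in ABsegmentation
label = torch.tensor([[1.0, 0.0], [1.0, 1.0]])

pred_bin = (prediction > 0.5).float()
intersection = (pred_bin * label).sum()
union = pred_bin.sum() + label.sum() - intersection
print(intersection / union)                            # 2/3 for this toy pair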
273
 
274
  def ABShowImages(self):
275
 
 
277
  if not os.path.exists(path):
278
  os.makedirs(path)
279
 
280
+ path = os.path.join(
281
+ path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.dm_test[:2]}{self.s_test[0]}{self.dn_test[:2]}')
282
 
283
  if not os.path.exists(path):
284
  os.makedirs(path)
285
 
286
+ run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}" + \
287
+ '_' + str(0)
288
 
289
  state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=self.download_model)
290
 
291
  model.augmentation = None
292
+
293
+ for t in ([self.dm_train, self.s_train, self.dn_train, 'train_img'],
294
+ [self.dm_test, self.s_test, self.dn_test, 'test_img']):
295
+
296
  debayer, sharpening, denoising, img_type = t[0], t[1], t[2], t[3]
297
 
298
  dataset = self.static_pip_val(debayer=debayer, sharpening=sharpening, denoising=denoising, plot_mode=True)
299
  valid_set = Subset(dataset, indices=state_dict['valid_indices'])
300
+
301
  img, _ = next(iter(valid_set))
302
 
303
  plt.figure()
304
+ plt.imshow(img.permute(1, 2, 0))
305
  if img_type == 'train_img':
306
  plt.title('Train Image')
307
  plt.savefig(os.path.join(path, f'img_train.png'))
308
  imgA = img
309
  else:
310
  plt.title('Test Image')
311
+ plt.savefig(os.path.join(path, f'img_test.png'))
312
+
313
+ for c, color in enumerate(['Red', 'Green', 'Blue']):
314
+ diff = torch.abs(imgA - img)
315
  plt.figure()
316
  # plt.imshow(diff.permute(1,2,0))
317
+ plt.imshow(diff[c, 50:200, 50:200], cmap=f'{color}s')
318
  plt.title(f'|Train Image - Test Image| - {color}')
319
  plt.colorbar()
320
  plt.savefig(os.path.join(path, f'diff_{color}.png'))
321
  plt.figure()
322
+ diff[diff == 0.] = 1e-5
323
  # plt.imshow(torch.log(diff.permute(1,2,0)))
324
  plt.imshow(torch.log(diff)[c])
325
  plt.title(f'log(|Train Image - Test Image|) - {color}')
326
  plt.colorbar()
327
  plt.savefig(os.path.join(path, f'logdiff_{color}.png'))
328
+
329
  if self.dataset_name == 'DroneSegmentation':
330
  plt.figure()
331
  plt.imshow(model(img[None].cuda()).detach().cpu().squeeze())
332
  if img_type == 'train_img':
333
  plt.savefig(os.path.join(path, f'mask_train.png'))
334
  else:
335
+ plt.savefig(os.path.join(path, f'mask_test.png'))
336
 
337
  def ABShowAllImages(self):
338
  if not os.path.exists('results/ABtesting'):
339
  os.makedirs('results/ABtesting')
340
 
341
+ demosaicings = ['bilinear', 'malvar2004', 'menon2007']
342
+ sharpenings = ['sharpening_filter', 'unsharp_masking']
343
+ denoisings = ['median_denoising', 'gaussian_denoising']
344
 
345
  fig = plt.figure()
346
+ columns = 4
347
+ rows = 3
348
 
349
+ i = 1
350
 
351
  for dm in demosaicings:
352
  for s in sharpenings:
353
  for dn in denoisings:
 
 
 
354
 
355
+ dataset = self.static_pip_val(self.dm_test, self.s_test,
356
+ self.dn_test, plot_mode=True)
357
+
358
+ img, _ = dataset[0]
359
+
360
  fig.add_subplot(rows, columns, i)
361
+ plt.imshow(img.permute(1, 2, 0))
362
  plt.title(f'{dm}\n{s}\n{dn}', fontsize=8)
363
  plt.xticks([])
364
  plt.yticks([])
365
  plt.tight_layout()
366
 
367
+ i += 1
368
 
369
  plt.show()
370
  plt.savefig(f'results/ABtesting/ABpipelines.png')
371
 
372
  def CShowImages(self):
373
+
374
  path = 'results/Ctesting/imgs/'
375
  if not os.path.exists(path):
376
  os.makedirs(path)
377
 
378
+ run_name = f"{self.dataset_name}_{self.dm_test}_{self.s_test}_{self.dn_test}_{self.augmentation}" + '_' + str(0)
379
 
380
  state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=True)
381
 
382
  model.augmentation = None
383
+
384
+ dataset = self.static_pip_val(self.dm_test, self.s_test, self.dn_test,
385
+ self.severity, self.transform, plot_mode=True)
386
  valid_set = Subset(dataset, indices=state_dict['valid_indices'])
387
+
388
  img, _ = next(iter(valid_set))
389
 
390
  plt.figure()
391
+ plt.imshow(img.permute(1, 2, 0))
392
+ plt.savefig(os.path.join(
393
+ path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.transform}_sev{self.severity}'))
394
+
395
  def CShowAllImages(self):
396
  if not os.path.exists('results/Cimages'):
397
  os.makedirs('results/Cimages')
398
 
399
+ transforms = ['identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
400
+ 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
401
 
402
+ for i, t in enumerate(transforms):
403
+
404
+ fig = plt.figure(figsize=(10, 6))
405
  columns = 5
406
  rows = 1
407
 
408
+ for sev in range(1, 6):
409
 
410
  dataset = self.static_pip_val(severity=sev, transform=t, plot_mode=True)
411
 
412
+ img, _ = dataset[0]
413
+
414
  fig.add_subplot(rows, columns, sev)
415
+ plt.imshow(img.permute(1, 2, 0))
416
  plt.title(f'Severity: {sev}')
417
  plt.xticks([])
418
  plt.yticks([])
419
  plt.tight_layout()
420
+
421
  if '_' in t:
422
+ t = t.replace('_', ' ')
423
+ t = t[0].upper() + t[1:]
424
 
425
  fig.suptitle(f'{t}', x=0.5, y=0.8, fontsize=24)
426
  plt.show()
427
  plt.savefig(f'results/Cimages/{i+1}_{t.lower()}.png')
428
 
 
 
429
 
430
+ def ABMakeTable(dataset_name: str, augmentation: str,
431
+ N_runs: int, download_model: bool):
432
+
433
+ demosaicings = ['bilinear', 'malvar2004', 'menon2007']
434
+ sharpenings = ['sharpening_filter', 'unsharp_masking']
435
+ denoisings = ['median_denoising', 'gaussian_denoising']
436
 
437
+ path = 'results/ABtesting/tables'
438
  if not os.path.exists(path):
439
  os.makedirs(path)
440
 
441
+ runs = {}
442
+ i = 0
443
 
444
  for dm_train in demosaicings:
445
  for s_train in sharpenings:
 
448
  for s_test in sharpenings:
449
  for dn_test in denoisings:
450
  train_pip = [dm_train, s_train, dn_train]
451
+ test_pip = [dm_test, s_test, dn_test]
452
  runs[f'run{i}'] = {
453
+ 'dataset': dataset_name,
454
+ 'augmentation': augmentation,
455
+ 'train_pip': train_pip,
456
+ 'test_pip': test_pip,
457
+ 'N_runs': N_runs
458
  }
459
  ABclass = ABtesting(
460
+ dataset_name=dataset_name,
461
+ augmentation=augmentation,
462
+ dm_train=dm_train,
463
+ s_train=s_train,
464
+ dn_train=dn_train,
465
+ dm_test=dm_test,
466
+ s_test=s_test,
467
+ dn_test=dn_test,
468
+ N_runs=N_runs,
469
+ download_model=download_model
470
+ )
471
+
472
+ if dataset_name == 'DroneSegmentation':
473
  IoU = ABclass.ABsegmentation()
474
  runs[f'run{i}']['IoU'] = IoU
475
  else:
 
479
  runs[f'run{i}']['precision'] = precision
480
  runs[f'run{i}']['recall'] = recall
481
  runs[f'run{i}']['f1_score'] = f1_score
482
+
483
+ with open(os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt'), 'w') as outfile:
484
  json.dump(runs, outfile)
485
 
486
+ i += 1
487
+
488
 
489
  def ABShowTable(dataset_name: str, augmentation: str):
490
+
491
+ path = 'results/ABtesting/tables'
492
  assert os.path.exists(path), 'No tables to plot'
493
 
494
  json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt')
 
496
  with open(json_file, 'r') as run_file:
497
  runs = json.load(run_file)
498
 
499
+ metrics = torch.zeros((2, 12, 12))
500
+ classes = []
501
 
502
+ i, j = 0, 0
503
 
504
  for r in range(len(runs)):
505
+
506
+ run = runs['run' + str(r)]
507
  if dataset_name == 'DroneSegmentation':
508
  acc = run['IoU']
509
  else:
 
512
  class_list = run['test_pip']
513
  class_name = f'{class_list[0][:2]},{class_list[1][:1]},{class_list[2][:2]}'
514
  classes.append(class_name)
515
+ mu, sigma = round(acc[0], 4), round(acc[1], 4)
516
+
517
+ metrics[0, j, i] = mu
518
+ metrics[1, j, i] = sigma
519
 
520
+ i += 1
 
 
 
521
 
522
  if i == 12:
523
+ i = 0
524
+ j += 1
525
 
526
  differences = torch.zeros_like(metrics)
527
 
528
+ diag_mu = torch.diagonal(metrics[0], 0)
529
+ diag_sigma = torch.diagonal(metrics[1], 0)
530
+
531
  for r in range(len(metrics[0])):
532
+ differences[0, r] = diag_mu[r] - metrics[0, r]
533
+ differences[1, r] = torch.sqrt(metrics[1, r]**2 + diag_sigma[r]**2)
534
 
535
  # Plot with scatter
536
+
537
+ for i, img in enumerate([metrics, differences]):
538
 
539
  x, y = torch.arange(12), torch.arange(12)
540
  x, y = torch.meshgrid(x, y)
541
 
542
  if i == 0:
543
+ vmin = max(0.65, round(img[0].min().item(), 2))
544
+ vmax = round(img[0].max().item(), 2)
545
  step = 0.02
546
  elif i == 1:
547
+ vmin = round(img[0].min().item(), 2)
548
  if augmentation == 'none':
549
+ vmax = min(0.15, round(img[0].max().item(), 2))
550
  if augmentation == 'weak':
551
+ vmax = min(0.08, round(img[0].max().item(), 2))
552
  if augmentation == 'strong':
553
+ vmax = min(0.05, round(img[0].max().item(), 2))
554
  step = 0.01
 
 
 
555
 
556
+ vmin = int(vmin / step) * step
557
+ vmax = int(vmax / step) * step
558
+
559
+ fig = plt.figure(figsize=(10, 6.2))
560
  ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
561
+ marker_size = 350
562
+ plt.scatter(x, y, c=torch.rot90(img[1][x, y], -1, [0, 1]), vmin=0.,
563
+ vmax=img[1].max(), cmap='viridis', s=marker_size * 2, marker='s')
564
+ ticks = torch.arange(0., img[1].max(), 0.03).tolist()
565
+ ticks = [round(tick, 2) for tick in ticks]
566
  cba = plt.colorbar(pad=0.06)
567
  cba.set_ticks(ticks)
568
  cba.ax.set_yticklabels(ticks)
569
  # cmap = plt.cm.get_cmap('tab20c').reversed()
570
  cmap = plt.cm.get_cmap('Reds')
571
+ plt.scatter(x, y, c=torch.rot90(img[0][x, y], -1, [0, 1]), vmin=vmin,
572
+ vmax=vmax, cmap=cmap, s=marker_size, marker='s')
573
  ticks = torch.arange(vmin, vmax, step).tolist()
574
+ ticks = [round(tick, 2) for tick in ticks]
575
  if ticks[-1] != vmax:
576
  ticks.append(vmax)
577
  cbb = plt.colorbar(pad=0.06)
 
583
  cbb.ax.set_yticklabels(ticks)
584
  for x in range(12):
585
  for y in range(12):
586
+ txt = round(torch.rot90(img[0], -1, [0, 1])[x, y].item(), 2)
587
  if str(txt) == '-0.0':
588
  txt = '0.00'
589
  elif str(txt) == '0.0':
590
  txt = '0.00'
591
  elif len(str(txt)) == 3:
592
+ txt = str(txt) + '0'
593
  else:
594
  txt = str(txt)
 
 
595
 
596
+ plt.text(x - 0.25, y - 0.1, txt, color='black', fontsize='x-small')
597
+
598
+ ax.set_xticks(torch.linspace(0, 11, 12))
599
  ax.set_xticklabels(classes)
600
+ ax.set_yticks(torch.linspace(0, 11, 12))
601
  classes.reverse()
602
  ax.set_yticklabels(classes)
603
  classes.reverse()
604
+ plt.xticks(rotation=45)
605
+ plt.yticks(rotation=45)
606
  cba.set_label('Standard Deviation')
607
  plt.xlabel("Test pipelines")
608
  plt.ylabel("Train pipelines")
 
610
  if i == 0:
611
  if dataset_name == 'DroneSegmentation':
612
  cbb.set_label('IoU')
613
+ plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_IoU.png"))
614
  else:
615
  cbb.set_label('Accuracy')
616
+ plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_accuracies.png"))
617
  elif i == 1:
618
  if dataset_name == 'DroneSegmentation':
619
  cbb.set_label('IoU_d-IoU')
620
  else:
621
  cbb.set_label('Accuracy_d - Accuracy')
622
+ plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_differences.png"))
623
+
624
 
625
  def CMakeTable(dataset_name: str, augmentation: str, severity: int, N_runs: int, download_model: bool):
626
+
627
+ path = 'results/Ctesting/tables'
628
  if not os.path.exists(path):
629
  os.makedirs(path)
 
 
 
 
630
 
631
+ demosaicings = ['bilinear', 'malvar2004', 'menon2007']
632
+ sharpenings = ['sharpening_filter', 'unsharp_masking']
633
+ denoisings = ['median_denoising', 'gaussian_denoising']
634
+
635
+ transformations = ['identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
636
+ 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
637
 
638
+ runs = {}
639
+ i = 0
640
 
641
  for dm in demosaicings:
642
  for s in sharpenings:
643
  for dn in denoisings:
644
  for t in transformations:
645
+ pip = [dm, s, dn]
646
  runs[f'run{i}'] = {
647
+ 'dataset': dataset_name,
648
+ 'augmentation': augmentation,
649
+ 'pipeline': pip,
650
+ 'N_runs': N_runs,
651
+ 'transform': t,
652
+ 'severity': severity,
653
  }
654
  ABclass = ABtesting(
655
+ dataset_name=dataset_name,
656
+ augmentation=augmentation,
657
+ dm_train=dm,
658
+ s_train=s,
659
+ dn_train=dn,
660
+ dm_test=dm,
661
+ s_test=s,
662
+ dn_test=dn,
663
+ severity=severity,
664
+ transform=t,
665
+ N_runs=N_runs,
666
+ download_model=download_model
667
+ )
668
+
669
+ if dataset_name == 'DroneSegmentation':
670
  IoU = ABclass.ABsegmentation()
671
  runs[f'run{i}']['IoU'] = IoU
672
  else:
 
677
  runs[f'run{i}']['recall'] = recall
678
  runs[f'run{i}']['f1_score'] = f1_score
679
 
680
+ with open(os.path.join(path, f'{dataset_name}_{augmentation}_runs.json'), 'w') as outfile:
681
  json.dump(runs, outfile)
682
 
683
+ i += 1
684
+
685
 
686
  def CShowTable(dataset_name, augmentation):
687
 
688
+ path = 'results/Ctesting/tables'
689
  assert os.path.exists(path), 'No tables to plot'
690
 
691
  json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.json')
692
 
693
+ transforms = ['identity', 'gauss_noise', 'shot', 'impulse', 'speckle',
694
+ 'gauss_blur', 'zoom', 'contrast', 'brightness', 'saturate', 'elastic']
695
 
696
  pip = []
697
+
698
+ demosaicings = ['bilinear', 'malvar2004', 'menon2007']
699
+ sharpenings = ['sharpening_filter', 'unsharp_masking']
700
+ denoisings = ['median_denoising', 'gaussian_denoising']
701
 
702
  for dm in demosaicings:
703
  for s in sharpenings:
 
707
  with open(json_file, 'r') as run_file:
708
  runs = json.load(run_file)
709
 
710
+ metrics = torch.zeros((2, len(pip), len(transforms)))
711
 
712
+ i, j = 0, 0
713
 
714
  for r in range(len(runs)):
715
+
716
+ run = runs['run' + str(r)]
717
  if dataset_name == 'DroneSegmentation':
718
  acc = run['IoU']
719
  else:
720
  acc = run['accuracy']
721
+ mu, sigma = round(acc[0], 4), round(acc[1], 4)
722
 
723
+ metrics[0, j, i] = mu
724
+ metrics[1, j, i] = sigma
725
+
726
+ i += 1
727
 
728
  if i == len(transforms):
729
+ i = 0
730
+ j += 1
731
 
732
  # Plot with scatter
733
 
734
  img = metrics
735
 
736
+ vmin = 0.
737
+ vmax = 1.
738
+ step = 0.01  # tick spacing for the accuracy colorbar below
+
739
  x, y = torch.arange(12), torch.arange(11)
740
  x, y = torch.meshgrid(x, y)
741
 
742
+ fig = plt.figure(figsize=(10, 6.2))
743
  ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
744
+ marker_size = 350
745
+ plt.scatter(x, y, c=torch.rot90(img[1][x, y], -1, [0, 1]), vmin=0.,
746
+ vmax=img[1].max(), cmap='viridis', s=marker_size * 2, marker='s')
747
+ ticks = torch.arange(0., img[1].max(), 0.03).tolist()
748
+ ticks = [round(tick, 2) for tick in ticks]
749
  cba = plt.colorbar(pad=0.06)
750
  cba.set_ticks(ticks)
751
  cba.ax.set_yticklabels(ticks)
752
  # cmap = plt.cm.get_cmap('tab20c').reversed()
753
  cmap = plt.cm.get_cmap('Reds')
754
+ plt.scatter(x, y, c=torch.rot90(img[0][x, y], -1, [0, 1]), vmin=vmin,
755
+ vmax=vmax, cmap=cmap, s=marker_size, marker='s')
756
  ticks = torch.arange(vmin, vmax, step).tolist()
757
+ ticks = [round(tick, 2) for tick in ticks]
758
  if ticks[-1] != vmax:
759
  ticks.append(vmax)
760
  cbb = plt.colorbar(pad=0.06)
 
766
  cbb.ax.set_yticklabels(ticks)
767
  for x in range(12):
768
  for y in range(12):
769
+ txt = round(torch.rot90(img[0], -1, [0, 1])[x, y].item(), 2)
770
  if str(txt) == '-0.0':
771
  txt = '0.00'
772
  elif str(txt) == '0.0':
773
  txt = '0.00'
774
  elif len(str(txt)) == 3:
775
+ txt = str(txt) + '0'
776
  else:
777
  txt = str(txt)
 
 
778
 
779
+ plt.text(x - 0.25, y - 0.1, txt, color='black', fontsize='x-small')
780
+
781
+ ax.set_xticks(torch.linspace(0, 11, 12))
782
  ax.set_xticklabels(transforms)
783
+ ax.set_yticks(torch.linspace(0, 11, 12))
784
  pip.reverse()
785
  ax.set_yticklabels(pip)
786
  pip.reverse()
787
+ plt.xticks(rotation=45)
788
+ plt.yticks(rotation=45)
789
  cba.set_label('Standard Deviation')
790
  plt.xlabel("Pipelines")
791
  plt.ylabel("Distortions")
792
  if dataset_name == 'DroneSegmentation':
793
  cbb.set_label('IoU')
794
+ plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_IoU.png"))
795
  else:
796
  cbb.set_label('Accuracy')
797
+ plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_accuracies.png"))
798
+
799
 
800
  if __name__ == '__main__':
801
+
802
+ if args.mode == 'ABMakeTable':
803
  ABMakeTable(args.dataset_name, args.augmentation, args.N_runs, args.download_model)
804
+ elif args.mode == 'ABShowTable':
805
  ABShowTable(args.dataset_name, args.augmentation)
806
+ elif args.mode == 'ABShowImages':
807
+ ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
808
+ args.s_train, args.dn_train, args.dm_test, args.s_test,
809
  args.dn_test, args.N_runs, download_model=args.download_model)
810
  ABclass.ABShowImages()
811
+ elif args.mode == 'ABShowAllImages':
812
+ ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
813
+ args.s_train, args.dn_train, args.dm_test, args.s_test,
814
+ args.dn_test, args.N_runs, download_model=args.download_model)
815
  ABclass.ABShowAllImages()
816
+ elif args.mode == 'CMakeTable':
817
  CMakeTable(args.dataset_name, args.augmentation, args.severity, args.N_runs, args.download_model)
818
+ elif args.mode == 'CShowTable': # TODO test it
819
  CShowTable(args.dataset_name, args.augmentation)
820
+ elif args.mode == 'CShowImages':
821
+ ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
822
+ args.s_train, args.dn_train, args.dm_test, args.s_test,
823
+ args.dn_test, args.N_runs, args.severity, args.transform,
824
+ download_model=args.download_model)
825
  ABclass.CShowImages()
826
+ elif args.mode == 'CShowAllImages':
827
+ ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
828
+ args.s_train, args.dn_train, args.dm_test, args.s_test,
829
+ args.dn_test, args.N_runs, args.severity, args.transform,
830
+ download_model=args.download_model)
831
  ABclass.CShowAllImages()
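
Note: CMakeTable above sweeps every demosaicing/sharpening/denoising combination against every transformation before dumping the results to the {dataset_name}_{augmentation}_runs.json file. A quick standalone sanity sketch (not part of the repo) of the grid size those nested loops produce:

from itertools import product

demosaicings = ['bilinear', 'malvar2004', 'menon2007']
sharpenings = ['sharpening_filter', 'unsharp_masking']
denoisings = ['median_denoising', 'gaussian_denoising']
transformations = ['identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
                   'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']

# one entry per (pipeline, transform) pair: 3 * 2 * 2 * 11 = 132 runs per table
print(len(list(product(demosaicings, sharpenings, denoisings, transformations))))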
figure1.sh β†’ figures/figure1.sh RENAMED
File without changes
figure2.sh β†’ figures/figure2.sh RENAMED
File without changes
figures.py β†’ figures/figures.py RENAMED
File without changes
{processingpipeline β†’ figures}/numpy_static_pipeline_show.ipynb RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:64edef77495ab24143430e7a5d880b6f211568371f37eab03e1b32fb2f5b8015
3
- size 1906586
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fd7f0e985d60a9d24fe74ef0faf0e8e10258416190be64d7b06908306f0e7fc
3
+ size 1906578
sanity_checks_and_statistics.ipynb β†’ figures/sanity_checks_and_statistics.ipynb RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:53f62c6ce9a6656a31c3e0ae1deded2e4f9818cd891381dbe1030dd5edc5f278
3
- size 6103871
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:446fa03594d28d8a23aaa4d721db22e0453724edd4199970418a631e2f4e7b18
3
+ size 6103863
show_classification_results.ipynb β†’ figures/show_classification_results.ipynb RENAMED
File without changes
{utils β†’ figures}/show_dataset.ipynb RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cd09ffb969b9a0a5b414b892614b5b9e48fa32721ea9a14e9e0951160e8f92e4
3
- size 2115545
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11568132f27aee6b772327127e7dac8b2bd81bc219a2b091e74dd8d6514e419d
3
+ size 2115521
show_results.sh β†’ figures/show_results.sh RENAMED
File without changes
train.sh β†’ figures/train.sh RENAMED
File without changes
models/classifier.py β†’ model.py RENAMED
@@ -29,6 +29,7 @@ class LitModel(pl.LightningModule):
29
  weight_decay=0,
30
  loss_aux=None,
31
  adv_training=False,
 
32
  metrics=None,
33
  processor=None,
34
  augmentation=None,
@@ -57,11 +58,19 @@ class LitModel(pl.LightningModule):
57
  self.freeze_classifier = freeze_classifier
58
  self.freeze_processor = freeze_processor
59
 
 
60
  if freeze_classifier:
61
  pl.LightningModule.freeze(self.classifier)
62
  if freeze_processor:
63
  pl.LightningModule.freeze(self.processor)
64
 
 
 
 
 
 
 
 
65
  def forward(self, x):
66
  x = self.processor(x)
67
  apply_augmentation_step = self.training or self.augmentation_on_eval
@@ -97,7 +106,6 @@ class LitModel(pl.LightningModule):
97
  y_hat = F.logsigmoid(logits).exp().squeeze()
98
  else:
99
  y_hat = torch.argmax(logits, dim=1)
100
-
101
 
102
  if self.metrics is not None:
103
  for metric in self.metrics:
@@ -105,11 +113,15 @@ class LitModel(pl.LightningModule):
105
  if metric_name == 'accuracy' or not self.training or self.metrics_on_training:
106
  m = metric(y_hat.cpu().detach(), y.cpu())
107
  self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
108
- prog_bar=self.training or metric_name == 'accuracy')
109
  if metric_name == 'iou_score' or not self.training or self.metrics_on_training:
110
  m = metric(y_hat.cpu().detach(), y.cpu())
111
  self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
112
  prog_bar=self.training or metric_name == 'iou_score')
 
 
 
 
113
 
114
  return loss
115
 
@@ -124,25 +136,15 @@ class LitModel(pl.LightningModule):
124
 
125
  def train(self, mode=True):
126
  self.training = mode
127
- # self.processor.train(False)
128
- self.processor.train(mode=mode and not self.freeze_processor)
 
129
  self.classifier.train(mode=mode and not self.freeze_classifier)
130
- if self.adv_training and self.processor.batch_norm is not None: # don't update batchnorm in adversarial training
131
- self.processor.batch_norm.track_running_stats = False
132
  return self
133
 
134
  def configure_optimizers(self):
135
  self.optimizer = torch.optim.Adam(self.parameters(), self.lr, weight_decay=self.weight_decay)
136
- # parameters = [self.processor.additive_layer]
137
- # self.optimizer = torch.optim.Adam(parameters, self.lr, weight_decay=self.weight_decay)
138
  return self.optimizer
139
- # self.scheduler = {
140
- # 'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(
141
- # self.optimizer, mode='min', factor=0.2, patience=2, min_lr=1e-6, verbose=True,
142
- # ),
143
- # 'monitor': 'val_loss',
144
- # }
145
- # return [self.optimizer], [self.scheduler]
146
 
147
  def get_progress_bar_dict(self):
148
  items = super().get_progress_bar_dict()
@@ -151,7 +153,7 @@ class LitModel(pl.LightningModule):
151
 
152
 
153
  class TrackImagesCallback(pl.callbacks.base.Callback):
154
- def __init__(self, data_loader, track_every_epoch=False, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True):
155
  super().__init__()
156
  self.data_loader = data_loader
157
 
@@ -162,9 +164,12 @@ class TrackImagesCallback(pl.callbacks.base.Callback):
162
  self.track_predictions = track_predictions
163
  self.save_tensors = save_tensors
164
 
165
- def callback_track_images(self, trainer, save_loc):
166
- track_images(trainer.model,
 
 
167
  self.data_loader,
 
168
  track_processing=self.track_processing,
169
  track_gradients=self.track_gradients,
170
  track_predictions=self.track_predictions,
@@ -175,12 +180,12 @@ class TrackImagesCallback(pl.callbacks.base.Callback):
175
  def on_fit_end(self, trainer, pl_module):
176
  if not self.track_every_epoch:
177
  save_loc = 'results'
178
- self.callback_track_images(trainer, save_loc)
179
 
180
  def on_train_epoch_end(self, trainer, pl_module, outputs):
181
  if self.track_every_epoch:
182
  save_loc = f'results/epoch_{trainer.current_epoch + 1:04d}'
183
- self.callback_track_images(trainer, save_loc)
184
 
185
 
186
  from utils.debug import debug
@@ -201,7 +206,7 @@ def log_tensor(batch, path, save_tensors=True, nrow=8):
201
  mlflow.log_artifact(img_path, os.path.dirname(path))
202
 
203
 
204
- def track_images(model, data_loader, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True, save_loc='results'):
205
 
206
  device = model.device
207
  processor = model.processor
@@ -219,6 +224,9 @@ def track_images(model, data_loader, track_processing=True, track_gradients=True
219
  logits_full = []
220
  stages_full = defaultdict(list)
221
  grads_full = defaultdict(list)
 
 
 
222
 
223
  for inputs, labels in data_loader:
224
 
@@ -227,6 +235,10 @@ def track_images(model, data_loader, track_processing=True, track_gradients=True
227
 
228
  processed_rgb = processor(inputs)
229
 
 
 
 
 
230
  if track_gradients or track_predictions:
231
  logits = classifier(processed_rgb)
232
 
@@ -241,6 +253,8 @@ def track_images(model, data_loader, track_processing=True, track_gradients=True
241
 
242
  for stage, batch in processor.stages.items():
243
  stages_full[stage].append(batch.cpu().detach())
 
 
244
  if track_gradients:
245
  grads_full[stage].append(batch.grad.cpu().detach())
246
 
@@ -248,19 +262,29 @@ def track_images(model, data_loader, track_processing=True, track_gradients=True
248
 
249
  stages = stages_full
250
  grads = grads_full
 
251
 
252
  if track_processing:
253
- for stage, batch in stages_full.items():
254
  stages[stage] = torch.cat(batch)
255
 
 
 
 
 
256
  if track_gradients:
257
- for stage, batch in grads_full.items():
258
  grads[stage] = torch.cat(batch)
259
 
260
  for stage_nr, stage_name in enumerate(stages):
261
  if track_processing:
262
  batch = stages[stage_name]
263
  log_tensor(batch, os.path.join(save_loc, f'processing_{stage_nr}_{stage_name}.pt'), save_tensors)
 
 
 
 
 
264
  if track_gradients:
265
  batch_grad = grads[stage_name]
266
  batch_grad = batch_grad.abs()
@@ -270,11 +294,11 @@ def track_images(model, data_loader, track_processing=True, track_gradients=True
270
 
271
  # inputs = torch.cat(inputs_full)
272
 
273
- if track_predictions: #and model.is_segmentation_task:
274
  labels = torch.cat(labels_full)
275
  logits = torch.cat(logits_full)
276
  masks = labels.unsqueeze(1)
277
- predictions = logits #torch.sigmoid(logits).unsqueeze(1)
278
  #mask_vis = torch.cat((masks, predictions, masks * predictions), dim=1)
279
  #log_tensor(mask_vis, os.path.join(save_loc, f'masks.pt'), save_tensors)
280
  log_tensor(masks, os.path.join(save_loc, f'targets.pt'), save_tensors)
 
29
  weight_decay=0,
30
  loss_aux=None,
31
  adv_training=False,
32
+ adv_parameters='all',
33
  metrics=None,
34
  processor=None,
35
  augmentation=None,
 
58
  self.freeze_classifier = freeze_classifier
59
  self.freeze_processor = freeze_processor
60
 
61
+ self.unfreeze()
62
  if freeze_classifier:
63
  pl.LightningModule.freeze(self.classifier)
64
  if freeze_processor:
65
  pl.LightningModule.freeze(self.processor)
66
 
67
+ if adv_training and adv_parameters != 'all':
68
+ if adv_parameters != 'all':
69
+ pl.LightningModule.freeze(self.processor)
70
+ for name, p in self.processor.named_parameters():
71
+ if adv_parameters in name:
72
+ p.requires_grad = True
73
+
74
  def forward(self, x):
75
  x = self.processor(x)
76
  apply_augmentation_step = self.training or self.augmentation_on_eval
 
106
  y_hat = F.logsigmoid(logits).exp().squeeze()
107
  else:
108
  y_hat = torch.argmax(logits, dim=1)
 
109
 
110
  if self.metrics is not None:
111
  for metric in self.metrics:
 
113
  if metric_name == 'accuracy' or not self.training or self.metrics_on_training:
114
  m = metric(y_hat.cpu().detach(), y.cpu())
115
  self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
116
+ prog_bar=self.training or metric_name == 'accuracy')
117
  if metric_name == 'iou_score' or not self.training or self.metrics_on_training:
118
  m = metric(y_hat.cpu().detach(), y.cpu())
119
  self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
120
  prog_bar=self.training or metric_name == 'iou_score')
121
+ elif metric_name == 'accuracy' or not self.training or self.metrics_on_training:
122
+ m = metric(y_hat.cpu().detach(), y.cpu())
123
+ self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
124
+ prog_bar=self.training or metric_name == 'accuracy')
125
 
126
  return loss
127
 
 
136
 
137
  def train(self, mode=True):
138
  self.training = mode
139
+
140
+ # don't update batchnorm in adversarial training
141
+ self.processor.train(mode=mode and not self.freeze_processor and not self.adv_training)
142
  self.classifier.train(mode=mode and not self.freeze_classifier)
 
 
143
  return self
144
 
145
  def configure_optimizers(self):
146
  self.optimizer = torch.optim.Adam(self.parameters(), self.lr, weight_decay=self.weight_decay)
 
 
147
  return self.optimizer
 
 
 
 
 
 
 
148
 
149
  def get_progress_bar_dict(self):
150
  items = super().get_progress_bar_dict()
 
153
 
154
 
155
  class TrackImagesCallback(pl.callbacks.base.Callback):
156
+ def __init__(self, data_loader, reference_processor=None, track_every_epoch=False, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True):
157
  super().__init__()
158
  self.data_loader = data_loader
159
 
 
164
  self.track_predictions = track_predictions
165
  self.save_tensors = save_tensors
166
 
167
+ self.reference_processor = reference_processor
168
+
169
+ def callback_track_images(self, model, save_loc):
170
+ track_images(model,
171
  self.data_loader,
172
+ reference_processor=self.reference_processor,
173
  track_processing=self.track_processing,
174
  track_gradients=self.track_gradients,
175
  track_predictions=self.track_predictions,
 
180
  def on_fit_end(self, trainer, pl_module):
181
  if not self.track_every_epoch:
182
  save_loc = 'results'
183
+ self.callback_track_images(trainer.model, save_loc)
184
 
185
  def on_train_epoch_end(self, trainer, pl_module, outputs):
186
  if self.track_every_epoch:
187
  save_loc = f'results/epoch_{trainer.current_epoch + 1:04d}'
188
+ self.callback_track_images(trainer.model, save_loc)
189
 
190
 
191
  from utils.debug import debug
 
206
  mlflow.log_artifact(img_path, os.path.dirname(path))
207
 
208
 
209
+ def track_images(model, data_loader, reference_processor=None, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True, save_loc='results'):
210
 
211
  device = model.device
212
  processor = model.processor
 
224
  logits_full = []
225
  stages_full = defaultdict(list)
226
  grads_full = defaultdict(list)
227
+ diffs_full = defaultdict(list)
228
+
229
+ track_differences = reference_processor is not None
230
 
231
  for inputs, labels in data_loader:
232
 
 
235
 
236
  processed_rgb = processor(inputs)
237
 
238
+ if track_differences:
239
+ # debug(processor)
240
+ processed_rgb_ref = reference_processor(inputs)
241
+
242
  if track_gradients or track_predictions:
243
  logits = classifier(processed_rgb)
244
 
 
253
 
254
  for stage, batch in processor.stages.items():
255
  stages_full[stage].append(batch.cpu().detach())
256
+ if track_differences:
257
+ diffs_full[stage].append((reference_processor.stages[stage] - batch).cpu().detach())
258
  if track_gradients:
259
  grads_full[stage].append(batch.grad.cpu().detach())
260
 
 
262
 
263
  stages = stages_full
264
  grads = grads_full
265
+ diffs = diffs_full
266
 
267
  if track_processing:
268
+ for stage, batch in stages.items():
269
  stages[stage] = torch.cat(batch)
270
 
271
+ if track_differences:
272
+ for stage, batch in diffs.items():
273
+ diffs[stage] = torch.cat(batch)
274
+
275
  if track_gradients:
276
+ for stage, batch in grads.items():
277
  grads[stage] = torch.cat(batch)
278
 
279
  for stage_nr, stage_name in enumerate(stages):
280
  if track_processing:
281
  batch = stages[stage_name]
282
  log_tensor(batch, os.path.join(save_loc, f'processing_{stage_nr}_{stage_name}.pt'), save_tensors)
283
+
284
+ if track_differences:
285
+ batch = diffs[stage_name]
286
+ log_tensor(batch, os.path.join(save_loc, f'diffs_{stage_nr}_{stage_name}.pt'), False)
287
+
288
  if track_gradients:
289
  batch_grad = grads[stage_name]
290
  batch_grad = batch_grad.abs()
 
294
 
295
  # inputs = torch.cat(inputs_full)
296
 
297
+ if track_predictions: # and model.is_segmentation_task:
298
  labels = torch.cat(labels_full)
299
  logits = torch.cat(logits_full)
300
  masks = labels.unsqueeze(1)
301
+ predictions = logits # torch.sigmoid(logits).unsqueeze(1)
302
  #mask_vis = torch.cat((masks, predictions, masks * predictions), dim=1)
303
  #log_tensor(mask_vis, os.path.join(save_loc, f'masks.pt'), save_tensors)
304
  log_tensor(masks, os.path.join(save_loc, f'targets.pt'), save_tensors)
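
Note on the new adv_parameters option in LitModel: when adversarial training targets a single processing stage, the whole processor is frozen first and only parameters whose name contains the given substring are re-enabled. A minimal standalone sketch of that selection (ToyProcessor is made up here; the parameter names only mimic the pipeline's):

import torch
import torch.nn as nn

class ToyProcessor(nn.Module):
    def __init__(self):
        super().__init__()
        self.gamma_correct = nn.Parameter(torch.tensor([2.2]))
        self.sharpening_filter = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)

processor = ToyProcessor()
adv_parameters = 'gamma_correct'            # e.g. the value passed via --adv_parameters

for p in processor.parameters():            # freeze everything, as LitModel does
    p.requires_grad = False
for name, p in processor.named_parameters():
    if adv_parameters in name:              # re-enable only the selected stage
        p.requires_grad = True

print([n for n, p in processor.named_parameters() if p.requires_grad])  # ['gamma_correct']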
processingpipeline/pipeline.py β†’ processing/pipeline_numpy.py RENAMED
@@ -23,7 +23,7 @@ from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear,
23
  import torch
24
  import numpy as np
25
 
26
- from utils.dataset import Subset
27
  from torch.utils.data import DataLoader
28
 
29
  from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear,
 
23
  import torch
24
  import numpy as np
25
 
26
+ from dataset import Subset
27
  from torch.utils.data import DataLoader
28
 
29
  from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear,
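
The only change to this file is the import path. Since the commit moves several modules, a small standalone helper sketch (not part of the repo) collecting the old-to-new mapping visible on this page can save time when updating downstream imports:

# mapping assembled from the renames in this commit
MOVED_MODULES = {
    'utils.dataset': 'dataset',
    'utils.splitting': 'utils.dataset_utils',
    'processingpipeline.pipeline': 'processing.pipeline_numpy',
    'processingpipeline.torch_pipeline': 'processing.pipeline_torch',
    'models.classifier': 'model',
    'utils.pytorch_ssim': 'utils.ssim',
}

def update_import(line: str) -> str:
    """Rewrite a 'from X import ...' line to the new module layout."""
    for old, new in MOVED_MODULES.items():
        line = line.replace(f'from {old} import', f'from {new} import')
    return line

print(update_import('from processingpipeline.torch_pipeline import raw2rgb'))
# -> from processing.pipeline_torch import raw2rgb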
processingpipeline/torch_pipeline.py β†’ processing/pipeline_torch.py RENAMED
@@ -5,7 +5,7 @@ import torch.nn as nn
5
  if not os.path.exists('README.md'):
6
  os.chdir('..')
7
 
8
- from processingpipeline.pipeline import processing as default_processing
9
  from utils.base import np2torch, torch2np
10
 
11
  import segmentation_models_pytorch as smp
@@ -83,7 +83,7 @@ class NNProcessing(nn.Module):
83
  in_channels=3,
84
  classes=3,
85
  )
86
- self.batch_norm = None if not batch_norm_output else nn.BatchNorm2d(3)
87
  self.normalize_mosaic = normalize_mosaic
88
 
89
  def forward(self, raw):
@@ -108,20 +108,23 @@ class NNProcessing(nn.Module):
108
  return rgb
109
 
110
 
 
 
 
 
 
111
  class ParametrizedProcessing(nn.Module):
112
- def __init__(self, camera_parameters, track_stages=False, batch_norm_output=True, noise_layer=False):
113
  super().__init__()
114
  self.stages = None
115
  self.buffer = None
116
  self.track_stages = track_stages
117
 
118
  black_level, white_balance, colour_matrix = camera_parameters
119
- self.register_buffer('black_level', torch.as_tensor(black_level))
120
- self.register_buffer('colour_correction',
121
- torch.as_tensor(white_balance).reshape(1, 3)
122
- * torch.as_tensor(colour_matrix).reshape(3, 3))
123
- self.register_buffer('M_RGB_2_YUV', M_RGB_2_YUV.clone())
124
- self.register_buffer('M_YUV_2_RGB', M_YUV_2_RGB.clone())
125
 
126
  self.gamma_correct = nn.Parameter(torch.Tensor([2.2]))
127
 
@@ -133,14 +136,12 @@ class ParametrizedProcessing(nn.Module):
133
  self.gaussian_blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, padding_mode='reflect', bias=False)
134
  self.gaussian_blur.weight.data[0][0] = K_BLUR.clone()
135
 
136
- self.batch_norm = nn.BatchNorm2d(3) if batch_norm_output else None
137
 
138
- # if noise_layer:
139
- # for param in self.parameters():
140
- # param.requires_grad = False
141
 
142
- self.additive_layer = nn.Parameter(0.001 * torch.randn((1, 3, 256, 256))
143
- ) if noise_layer else None # XXX: can this be 0?
144
 
145
  def forward(self, raw):
146
  assert raw.ndim == 3, f"needs dims (B, H, W), got {raw.shape}"
@@ -157,6 +158,7 @@ class ParametrizedProcessing(nn.Module):
157
  rgb = self.debayer(rgb)
158
  # self.stages['debayer'] = rgb
159
 
 
160
  rgb = torch.einsum('bchw,kc->bkhw', rgb, self.colour_correction).contiguous()
161
  self.stages['color_correct'] = rgb
162
 
@@ -179,7 +181,6 @@ class ParametrizedProcessing(nn.Module):
179
  self.stages['gamma_correct'] = rgb
180
 
181
  if self.additive_layer is not None:
182
- # rgb = rgb + 0 * self.additive_layer
183
  rgb = rgb + self.additive_layer
184
  self.stages['noise'] = rgb
185
 
@@ -259,11 +260,11 @@ if __name__ == "__main__":
259
  os.chdir('..')
260
 
261
  import matplotlib.pyplot as plt
262
- from utils.dataset import get_dataset
263
  from utils.base import np2torch, torch2np
264
 
265
  from utils.debug import debug
266
- from processingpipeline.pipeline import processing as default_processing
267
 
268
  raw_dataset = get_dataset('DS')
269
  loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1)
 
5
  if not os.path.exists('README.md'):
6
  os.chdir('..')
7
 
8
+ from processing.pipeline_numpy import processing as default_processing
9
  from utils.base import np2torch, torch2np
10
 
11
  import segmentation_models_pytorch as smp
 
83
  in_channels=3,
84
  classes=3,
85
  )
86
+ self.batch_norm = None if not batch_norm_output else nn.BatchNorm2d(3, affine=False)
87
  self.normalize_mosaic = normalize_mosaic
88
 
89
  def forward(self, raw):
 
108
  return rgb
109
 
110
 
111
+ def add_additive_layer(processor):
112
+ processor.additive_layer = nn.Parameter(torch.zeros((1, 3, 256, 256)))
113
+ # processor.additive_layer = nn.Parameter(0.001 * torch.randn((1, 3, 256, 256)))
114
+
115
+
116
  class ParametrizedProcessing(nn.Module):
117
+ def __init__(self, camera_parameters, track_stages=False, batch_norm_output=True):
118
  super().__init__()
119
  self.stages = None
120
  self.buffer = None
121
  self.track_stages = track_stages
122
 
123
  black_level, white_balance, colour_matrix = camera_parameters
124
+
125
+ self.black_level = nn.Parameter(torch.as_tensor(black_level))
126
+ self.white_balance = nn.Parameter(torch.as_tensor(white_balance).reshape(1, 3))
127
+ self.colour_correction = nn.Parameter(torch.as_tensor(colour_matrix).reshape(3, 3))
 
 
128
 
129
  self.gamma_correct = nn.Parameter(torch.Tensor([2.2]))
130
 
 
136
  self.gaussian_blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, padding_mode='reflect', bias=False)
137
  self.gaussian_blur.weight.data[0][0] = K_BLUR.clone()
138
 
139
+ self.batch_norm = nn.BatchNorm2d(3, affine=False) if batch_norm_output else None
140
 
141
+ self.register_buffer('M_RGB_2_YUV', M_RGB_2_YUV.clone())
142
+ self.register_buffer('M_YUV_2_RGB', M_YUV_2_RGB.clone())
 
143
 
144
+ self.additive_layer = None # this can be added in later
 
145
 
146
  def forward(self, raw):
147
  assert raw.ndim == 3, f"needs dims (B, H, W), got {raw.shape}"
 
158
  rgb = self.debayer(rgb)
159
  # self.stages['debayer'] = rgb
160
 
161
+ rgb = torch.einsum('bchw,kc->bchw', rgb, self.white_balance).contiguous()
162
  rgb = torch.einsum('bchw,kc->bkhw', rgb, self.colour_correction).contiguous()
163
  self.stages['color_correct'] = rgb
164
 
 
181
  self.stages['gamma_correct'] = rgb
182
 
183
  if self.additive_layer is not None:
 
184
  rgb = rgb + self.additive_layer
185
  self.stages['noise'] = rgb
186
 
 
260
  os.chdir('..')
261
 
262
  import matplotlib.pyplot as plt
263
+ from dataset import get_dataset
264
  from utils.base import np2torch, torch2np
265
 
266
  from utils.debug import debug
267
+ from processing.pipeline_numpy import processing as default_processing
268
 
269
  raw_dataset = get_dataset('DS')
270
  loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1)
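
Note on the reworked ParametrizedProcessing: black level, white balance and colour matrix are now nn.Parameters, and white balance is applied as its own per-channel scaling before the 3x3 colour correction. The einsum used for that step is easy to misread; a standalone check on random data (not repo code) shows it is just a channel-wise multiply when white_balance has shape (1, 3):

import torch

rgb = torch.rand(2, 3, 4, 4)                     # (B, C, H, W)
white_balance = torch.tensor([[2.0, 1.0, 1.5]])  # (1, 3), as registered in the module

out_einsum = torch.einsum('bchw,kc->bchw', rgb, white_balance)
out_direct = rgb * white_balance.reshape(1, 3, 1, 1)

print(torch.allclose(out_einsum, out_direct))    # True: k (size 1) is summed out, c scales each channel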
readme/Slice 8.png DELETED
Binary file (181 kB)
 
readme/init.md DELETED
@@ -1 +0,0 @@
1
-
 
 
readme/mlflow (1).png DELETED
Binary file (206 kB)
 
train.py CHANGED
@@ -4,6 +4,7 @@ import copy
4
  import argparse
5
 
6
  import torch
 
7
  import torch.nn as nn
8
 
9
  import mlflow.pytorch
@@ -16,17 +17,18 @@ from pytorch_lightning.callbacks import ModelCheckpoint
16
 
17
  from utils.base import display_mlflow_run_info, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
18
  from utils.debug import debug
 
19
  from utils.augmentation import get_augmentation
20
- from utils.dataset import Subset, get_dataset, k_fold
21
 
22
- from processingpipeline.pipeline import RawProcessingPipeline
23
- from processingpipeline.torch_pipeline import raw2rgb, RawToRGB, ParametrizedProcessing, NNProcessing
24
 
25
- from models.classifier import log_tensor, resnet_model, LitModel, TrackImagesCallback
26
 
27
  import segmentation_models_pytorch as smp
28
 
29
- from utils.pytorch_ssim import SSIM
30
 
31
  # args to set up task
32
  parser = argparse.ArgumentParser(description="classification_task")
@@ -106,41 +108,33 @@ parser.add_argument("--adv_training", action='store_true', help="Enable adversar
106
  parser.add_argument("--adv_aux_weight", type=float, default=1, help="Weighting of the adversarial auxilliary loss")
107
  parser.add_argument("--adv_aux_loss", type=str, default='ssim', choices=['l2', 'ssim'],
108
  help="Type of adversarial auxilliary regularization loss")
 
 
 
 
109
 
110
  parser.add_argument("--cache_downloaded_models", type=str2bool, default=True)
111
 
112
  parser.add_argument('--test_run', action='store_true')
113
 
114
- if 'ipykernel_launcher' in sys.argv[0]:
115
- args = parser.parse_args([
116
- '--dataset=Microscopy',
117
- '--epochs=100',
118
- '--augmentation=strong',
119
- '--lr=1e-5',
120
- '--freeze_processor',
121
- # '--track_processing',
122
- # '--test_run',
123
- # '--track_predictions',
124
- # '--track_every_epoch',
125
- # '--adv_training',
126
- # '--adv_aux_weight=100',
127
- # '--adv_aux_loss=l2',
128
- # '--log_model=',
129
- ])
130
- else:
131
- args = parser.parse_args()
132
 
133
 
134
  def run_train(args):
135
 
 
 
136
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
137
  training_mode = 'adversarial' if args.adv_training else 'default'
138
 
139
  # set tracking uri, this is the address of the mlflow server where light experimental data will be stored
140
  mlflow.set_tracking_uri(args.tracking_uri)
141
  mlflow.set_experiment(args.experiment_name)
142
- os.environ["AWS_ACCESS_KEY_ID"] = #TODO: add your AWS access key if you want to write your results to our collaborative lab server
143
- os.environ["AWS_SECRET_ACCESS_KEY"] = #TODO: add your AWS seceret access key if you want to write your results to our collaborative lab server
144
 
145
  # dataset
146
 
@@ -197,8 +191,9 @@ def run_train(args):
197
  if args.processing_mode == 'parametrized':
198
  processor = ParametrizedProcessing(
199
  camera_parameters=dataset.camera_parameters, track_stages=track_stages, batch_norm_output=True,
200
- noise_layer=args.adv_training, # XXX: Remove?
201
  )
 
202
  elif args.processing_mode == 'neural_network':
203
  processor = NNProcessing(track_stages=track_stages,
204
  normalize_mosaic=normalize_mosaic, batch_norm_output=True)
@@ -252,13 +247,19 @@ def run_train(args):
252
  assert not args.freeze_processor, "Processor should not be frozen for adversarial training"
253
 
254
  processor_default = copy.deepcopy(processor)
255
- processor_default.track_stages = False
256
  processor_default.eval()
257
  processor_default.to(DEVICE)
258
  # debug(processor_default)
 
 
 
 
 
259
 
260
  def l2_regularization(x, y):
261
- return (x - y).norm()
 
262
 
263
  if args.adv_aux_loss == 'l2':
264
  regularization = l2_regularization
@@ -274,7 +275,8 @@ def run_train(args):
274
  self.weight = weight
275
 
276
  def forward(self, x):
277
- x_reference = processor_default(x)
 
278
  x_processed = processor.buffer['processed_rgb']
279
  return self.weight * self.loss_aux(x_reference, x_processed)
280
 
@@ -290,7 +292,7 @@ def run_train(args):
290
  def __repr__(self):
291
  return f'{self.weight} * {get_name(self.loss)}'
292
 
293
- loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=-1)
294
  # loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=0)
295
  loss_aux = AuxLoss(
296
  loss_aux=regularization,
@@ -303,8 +305,10 @@ def run_train(args):
303
  classifier=classifier,
304
  processor=processor,
305
  loss=loss,
 
306
  loss_aux=loss_aux,
307
  adv_training=args.adv_training,
 
308
  metrics=metrics,
309
  augmentation=augmentation,
310
  is_segmentation_task=dataset.task == 'segmentation',
@@ -346,7 +350,7 @@ def run_train(args):
346
 
347
  with mlflow.start_run(run_name=f"{args.run_name}_{k_iter}", nested=True) as child_run:
348
 
349
- #mlflow.pytorch.autolog(silent=True)
350
 
351
  if k_iter == 0:
352
  display_mlflow_run_info(child_run)
@@ -373,19 +377,22 @@ def run_train(args):
373
  tracking_uri=args.tracking_uri,)
374
  mlf_logger._run_id = child_run.info.run_id
375
 
 
 
376
  callbacks = []
377
  if args.track_processing:
378
  callbacks += [TrackImagesCallback(track_loader,
 
379
  track_every_epoch=args.track_every_epoch,
380
  track_processing=args.track_processing,
381
  track_gradients=args.track_processing_gradients,
382
  track_predictions=args.track_predictions,
383
  save_tensors=args.track_save_tensors)]
384
 
385
- #if True: #args.save_best:
386
  # if dataset.task == 'classification':
387
- #checkpoint_callback = ModelCheckpoint(pathmonitor="val_accuracy", mode='max')
388
- # checkpoint_callback = ModelCheckpoint(dirpath=args.tracking_uri, save_top_k=1, verbose=True, monitor="val_accuracy", mode="max") #dirpath=args.tracking_uri,
389
  # else:
390
  # checkpoint_callback = ModelCheckpoint(monitor="val_iou_score")
391
  #callbacks += [checkpoint_callback]
@@ -397,7 +404,7 @@ def run_train(args):
397
  logger=mlf_logger,
398
  callbacks=callbacks,
399
  check_val_every_n_epoch=args.check_val_every_n_epoch,
400
- #checkpoint_callback=True,
401
  )
402
 
403
  if args.log_model:
@@ -410,9 +417,8 @@ def run_train(args):
410
  val_dataloaders=valid_loader,
411
  )
412
 
413
- # if args.adv_training:
414
- # for (name, p1), p2 in zip(processor.named_parameters(), processor_default.cpu().parameters()):
415
- # print(f"param '{name}' diff: {p2 - p1}, l2: {(p2-p1).norm().item()}")
416
  return model
417
 
418
 
 
4
  import argparse
5
 
6
  import torch
7
+ from torch import optim
8
  import torch.nn as nn
9
 
10
  import mlflow.pytorch
 
17
 
18
  from utils.base import display_mlflow_run_info, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
19
  from utils.debug import debug
20
+ from utils.dataset_utils import k_fold
21
  from utils.augmentation import get_augmentation
22
+ from dataset import Subset, get_dataset
23
 
24
+ from processing.pipeline_numpy import RawProcessingPipeline
25
+ from processing.pipeline_torch import add_additive_layer, raw2rgb, RawToRGB, ParametrizedProcessing, NNProcessing
26
 
27
+ from model import log_tensor, resnet_model, LitModel, TrackImagesCallback
28
 
29
  import segmentation_models_pytorch as smp
30
 
31
+ from utils.ssim import SSIM
32
 
33
  # args to set up task
34
  parser = argparse.ArgumentParser(description="classification_task")
 
108
  parser.add_argument("--adv_aux_weight", type=float, default=1, help="Weighting of the adversarial auxilliary loss")
109
  parser.add_argument("--adv_aux_loss", type=str, default='ssim', choices=['l2', 'ssim'],
110
  help="Type of adversarial auxilliary regularization loss")
111
+ parser.add_argument("--adv_noise_layer", action='store_true', help="Adds an additive layer to Parametrized Processing")
112
+ parser.add_argument("--adv_track_differences", action='store_true', help='Save difference to default pipeline')
113
+ parser.add_argument('--adv_parameters', choices=['all', 'black_level', 'white_balance',
114
+ 'colour_correction', 'gamma_correct', 'sharpening_filter', 'gaussian_blur', 'additive_layer'])
115
 
116
  parser.add_argument("--cache_downloaded_models", type=str2bool, default=True)
117
 
118
  parser.add_argument('--test_run', action='store_true')
119
 
120
+
121
+ args = parser.parse_args()
122
+
123
+ os.makedirs('results', exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
 
126
  def run_train(args):
127
 
128
+ print(args)
129
+
130
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
131
  training_mode = 'adversarial' if args.adv_training else 'default'
132
 
133
  # set tracking uri, this is the address of the mlflow server where light experimental data will be stored
134
  mlflow.set_tracking_uri(args.tracking_uri)
135
  mlflow.set_experiment(args.experiment_name)
136
+ os.environ["AWS_ACCESS_KEY_ID"] = "AKIAYYIYHFHKBIJHOJPA"
137
+ os.environ["AWS_SECRET_ACCESS_KEY"] = "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"
138
 
139
  # dataset
140
 
 
191
  if args.processing_mode == 'parametrized':
192
  processor = ParametrizedProcessing(
193
  camera_parameters=dataset.camera_parameters, track_stages=track_stages, batch_norm_output=True,
194
+ # noise_layer=args.adv_noise_layer, # this has to be added manually afterwards for when a model is loaded that doesn't have one yet
195
  )
196
+
197
  elif args.processing_mode == 'neural_network':
198
  processor = NNProcessing(track_stages=track_stages,
199
  normalize_mosaic=normalize_mosaic, batch_norm_output=True)
 
247
  assert not args.freeze_processor, "Processor should not be frozen for adversarial training"
248
 
249
  processor_default = copy.deepcopy(processor)
250
+ processor_default.track_stages = args.track_processing
251
  processor_default.eval()
252
  processor_default.to(DEVICE)
253
  # debug(processor_default)
254
+ for p in processor_default.parameters():
255
+ p.requires_grad = False
256
+
257
+ if args.adv_noise_layer:
258
+ add_additive_layer(processor)
259
 
260
  def l2_regularization(x, y):
261
+ return ((x - y) ** 2).sum()
262
+ # return (x - y).norm()
263
 
264
  if args.adv_aux_loss == 'l2':
265
  regularization = l2_regularization
 
275
  self.weight = weight
276
 
277
  def forward(self, x):
278
+ with torch.no_grad():
279
+ x_reference = processor_default(x)
280
  x_processed = processor.buffer['processed_rgb']
281
  return self.weight * self.loss_aux(x_reference, x_processed)
282
 
 
292
  def __repr__(self):
293
  return f'{self.weight} * {get_name(self.loss)}'
294
 
295
+ loss = WeightedLoss(loss=loss, weight=-1)
296
  # loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=0)
297
  loss_aux = AuxLoss(
298
  loss_aux=regularization,
 
305
  classifier=classifier,
306
  processor=processor,
307
  loss=loss,
308
+ lr=args.lr,
309
  loss_aux=loss_aux,
310
  adv_training=args.adv_training,
311
+ adv_parameters=args.adv_parameters,
312
  metrics=metrics,
313
  augmentation=augmentation,
314
  is_segmentation_task=dataset.task == 'segmentation',
 
350
 
351
  with mlflow.start_run(run_name=f"{args.run_name}_{k_iter}", nested=True) as child_run:
352
 
353
+ # mlflow.pytorch.autolog(silent=True)
354
 
355
  if k_iter == 0:
356
  display_mlflow_run_info(child_run)
 
377
  tracking_uri=args.tracking_uri,)
378
  mlf_logger._run_id = child_run.info.run_id
379
 
380
+ reference_processor = processor_default if args.adv_training and args.adv_track_differences else None
381
+
382
  callbacks = []
383
  if args.track_processing:
384
  callbacks += [TrackImagesCallback(track_loader,
385
+ reference_processor,
386
  track_every_epoch=args.track_every_epoch,
387
  track_processing=args.track_processing,
388
  track_gradients=args.track_processing_gradients,
389
  track_predictions=args.track_predictions,
390
  save_tensors=args.track_save_tensors)]
391
 
392
+ # if True: #args.save_best:
393
  # if dataset.task == 'classification':
394
+ #checkpoint_callback = ModelCheckpoint(pathmonitor="val_accuracy", mode='max')
395
+ # checkpoint_callback = ModelCheckpoint(dirpath=args.tracking_uri, save_top_k=1, verbose=True, monitor="val_accuracy", mode="max") #dirpath=args.tracking_uri,
396
  # else:
397
  # checkpoint_callback = ModelCheckpoint(monitor="val_iou_score")
398
  #callbacks += [checkpoint_callback]
 
404
  logger=mlf_logger,
405
  callbacks=callbacks,
406
  check_val_every_n_epoch=args.check_val_every_n_epoch,
407
+ # checkpoint_callback=True,
408
  )
409
 
410
  if args.log_model:
 
417
  val_dataloaders=valid_loader,
418
  )
419
 
420
+ globals().update(locals()) # for convenient access
421
+
 
422
  return model
423
 
424
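
For adversarial training, the objective assembled in run_train flips the task loss via WeightedLoss(weight=-1) and adds an auxiliary term that keeps the processed image close to the frozen default pipeline's output (l2 or SSIM). A toy standalone sketch of that combined objective (random tensors stand in for the classifier logits and pipeline outputs):

import torch
import torch.nn as nn

def l2_regularization(x, y):                  # same form as in the diff above
    return ((x - y) ** 2).sum()

task_loss = nn.CrossEntropyLoss()
adv_aux_weight = 1.0                          # --adv_aux_weight

logits = torch.randn(4, 10, requires_grad=True)
targets = torch.randint(0, 10, (4,))
x_processed = torch.rand(4, 3, 8, 8, requires_grad=True)   # learned pipeline output
x_reference = torch.rand(4, 3, 8, 8)                       # frozen default pipeline output

loss = -1 * task_loss(logits, targets) + adv_aux_weight * l2_regularization(x_reference, x_processed)
loss.backward()   # pushes the task loss up while penalising deviation from the reference image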
 
utils/augmentation.py CHANGED
@@ -99,7 +99,7 @@ if __name__ == '__main__':
99
  os.chdir('..')
100
 
101
  # from utils.debug import debug
102
- from utils.dataset import get_dataset
103
  import matplotlib.pyplot as plt
104
 
105
  dataset = get_dataset('DS') # drone segmentation
 
99
  os.chdir('..')
100
 
101
  # from utils.debug import debug
102
+ from dataset import get_dataset
103
  import matplotlib.pyplot as plt
104
 
105
  dataset = get_dataset('DS') # drone segmentation
utils/{splitting.py β†’ dataset_utils.py} RENAMED
@@ -1,6 +1,3 @@
1
- """
2
- Split images in blocks and vice versa
3
- """
4
 
5
  import random
6
  import numpy as np
@@ -10,127 +7,183 @@ import torch
10
  from skimage.util.shape import view_as_windows
11
 
12
 
13
- def split_img(imgs, ROIs = (3,3) , step= (1,1)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """Split the imgs in regions of size ROIs.
15
 
16
  Args:
17
  imgs (ndarray): images which you want to split
18
  ROIs (tuple): size of sub-regions splitted (ROIs=region of interests)
19
  step (tuple): step path from one sub-region to the next one (in the x,y axis)
20
-
21
  Returns:
22
  ndarray: splitted subimages.
23
  The size is (x_num_subROIs*y_num_subROIs, **) where:
24
  x_num_subROIs = ( imgs.shape[1]-int(ROIs[1]/2)*2 )/step[1]
25
  y_num_subROIs = ( imgs.shape[0]-int(ROIs[0]/2)*2 )/step[0]
26
-
27
  Example:
28
  >>> from dataset_generator import split
29
  >>> imgs_splitted = split(imgs, ROI_size = (5,5), step=(2,3))
30
  """
31
-
32
  if len(ROIs) > 2:
33
  return print("ROIs is a 2 element list")
34
-
35
  if len(step) > 2:
36
  return print("step is a 2 element list")
37
-
38
  if type(imgs) != type(np.array(1)):
39
  return print("imgs should be a ndarray")
40
-
41
- if len(imgs.shape) == 2: # Single image with one channel (HxW)
42
- splitted = view_as_windows(imgs, (ROIs[0],ROIs[1]), (step[0], step[1]))
43
  return splitted.reshape((-1, ROIs[0], ROIs[1]))
44
-
45
- if len(imgs.shape) == 3:
46
  _, _, channels = imgs.shape
47
- if channels <= 3: # Single image more channels (HxWxC)
48
- splitted = view_as_windows(imgs, (ROIs[0],ROIs[1], channels), (step[0], step[1], channels))
49
  return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
50
- else: # More images with 1 channel
51
- splitted = view_as_windows(imgs, (1, ROIs[0],ROIs[1]), (1, step[0], step[1]))
52
  return splitted.reshape((-1, ROIs[0], ROIs[1]))
53
-
54
- if len(imgs.shape) == 4: # More images with more channels(BxHxWxC)
55
  _, _, _, channels = imgs.shape
56
- splitted = view_as_windows(imgs, (1, ROIs[0],ROIs[1], channels), (1, step[0], step[1], channels))
57
  return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
58
 
 
59
  def join_blocks(splitted, final_shape):
60
  """Join blocks to reobtain a splitted image
61
-
62
  Attribute:
63
  splitted (tensor) = image splitted in blocks, size = (N_blocks, Channels, Height, Width)
64
  final_shape (tuple) = size of the final image reconstructed (Height, Width)
65
  Return:
66
  tensor: image restored from blocks. size = (Channels, Height, Width)
67
-
68
  """
69
  n_blocks, channels, ROI_height, ROI_width = splitted.shape
70
-
71
  rows = final_shape[0] // ROI_height
72
  columns = final_shape[1] // ROI_width
73
 
74
- final_img = torch.empty(rows, channels, ROI_height, ROI_width*columns)
75
- for r in range(rows):
76
- stackblocks = splitted[r*columns]
77
  for c in range(1, columns):
78
- stackblocks = torch.cat((stackblocks, splitted[r*columns+c]), axis=2)
79
  final_img[r] = stackblocks
80
-
81
  joined_img = final_img[0]
82
-
83
- for i in np.arange(1, len(final_img)):
84
- joined_img = torch.cat((joined_img,final_img[i]), axis = 1)
85
-
86
  return joined_img
87
 
88
- def random_ROI(X, Y, ROIs = (512,512)):
 
89
  """ Return a random region for each input/target pair images of the dataset
90
  Args:
91
  Y (ndarray): target of your dataset --> size: (BxHxWxC)
92
  X (ndarray): input of your dataset --> size: (BxHxWxC)
93
  ROIs (tuple): size of random region (ROIs=region of interests)
94
-
95
  Returns:
96
  For each pair images (input/target) of the dataset, return respectively random ROIs
97
  Y_cut (ndarray): target of your dataset --> size: (Batch,Channels,ROIs[0],ROIs[1])
98
  X_cut (ndarray): input of your dataset --> size: (Batch,Channels,ROIs[0],ROIs[1])
99
-
100
  Example:
101
  >>> from dataset_generator import random_ROI
102
  >>> X,Y = random_ROI(X,Y, ROIs = (10,10))
103
- """
104
-
105
  batch, channels, height, width = X.shape
106
-
107
- X_cut=np.empty((batch, ROIs[0], ROIs[1], channels))
108
- Y_cut=np.empty((batch, ROIs[0], ROIs[1], channels))
109
-
110
  for i in np.arange(len(X)):
111
- x_size=int(random.random()*(height-(ROIs[0]+1)))
112
- y_size=int(random.random()*(width-(ROIs[1]+1)))
113
- X_cut[i]=X[i, x_size:x_size+ROIs[0],y_size:y_size+ROIs[1], :]
114
- Y_cut[i]=Y[i, x_size:x_size+ROIs[0],y_size:y_size+ROIs[1], :]
115
  return X_cut, Y_cut
116
 
117
- def one2many_random_ROI(X, Y, datasize=1000, ROIs = (512,512)):
 
118
  """ Return a dataset of N subimages obtained from random regions of the same image
119
  Args:
120
  Y (ndarray): target of your dataset --> size: (1,H,W,C)
121
  X (ndarray): input of your dataset --> size: (1,H,W,C)
122
  datasize = number of random ROIs to generate
123
  ROIs (tuple): size of random region (ROIs=region of interests)
124
-
125
  Returns:
126
  Y_cut (ndarray): target of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
127
  X_cut (ndarray): input of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
128
- """
129
 
130
  batch, channels, height, width = X.shape
131
-
132
- X_cut=np.empty((datasize, ROIs[0], ROIs[1], channels))
133
- Y_cut=np.empty((datasize, ROIs[0], ROIs[1], channels))
134
 
135
  for i in np.arange(datasize):
136
  X_cut[i], Y_cut[i] = random_ROI(X, Y, ROIs)
 
 
 
 
1
 
2
  import random
3
  import numpy as np
 
7
  from skimage.util.shape import view_as_windows
8
 
9
 
10
+ def load_image(path):
11
+ file_type = path.split('.')[-1].lower()
12
+ if file_type == 'dng':
13
+ img = rawpy.imread(path).raw_image_visible
14
+ elif file_type == 'tiff' or file_type == 'tif':
15
+ img = np.array(tiff.imread(path), dtype=np.float32)
16
+ else:
17
+ img = np.array(Image.open(path), dtype=np.float32)
18
+ return img
19
+
20
+
21
+ def list_images_in_dir(path):
22
+ image_list = [os.path.join(path, img_name)
23
+ for img_name in sorted(os.listdir(path))
24
+ if img_name.split('.')[-1].lower() in IMAGE_FILE_TYPES]
25
+ return image_list
26
+
27
+
28
+ def k_fold(dataset, n_splits: int, seed: int, train_size: float):
29
+ """Split dataset in subsets for cross-validation
30
+
31
+ Args:
32
+ dataset (class): dataset to split
33
+ n_splits (int): Number of re-shuffling & splitting iterations.
35
+ seed (int): seed for k_fold splitting
36
+ train_size (float): should be between 0.0 and 1.0 and represents the proportion of the dataset to include in the train split.
37
+ Returns:
38
+ idxs (list): indices for splitting the dataset. The list contains n_splits pairs of train/test indices.
38
+ """
39
+ if hasattr(dataset, 'labels'):
40
+ x = dataset.images
41
+ y = dataset.labels
42
+ elif hasattr(dataset, 'masks'):
43
+ x = dataset.images
44
+ y = dataset.masks
45
+
46
+ idxs = []
47
+
48
+ if dataset.task == 'classification':
49
+ sss = StratifiedShuffleSplit(n_splits=n_splits, train_size=train_size, random_state=seed)
50
+
51
+ for idxs_train, idxs_test in sss.split(x, y):
52
+ idxs.append((idxs_train.tolist(), idxs_test.tolist()))
53
+
54
+ elif dataset.task == 'segmentation':
55
+ for n in range(n_splits):
56
+ split_idx = int(len(dataset) * train_size)
57
+ indices = np.random.permutation(len(dataset))
58
+ idxs.append((indices[:split_idx].tolist(), indices[split_idx:].tolist()))
59
+
60
+ return idxs
61
+
62
+
63
+ def split_img(imgs, ROIs=(3, 3), step=(1, 1)):
64
  """Split the imgs in regions of size ROIs.
65
 
66
  Args:
67
  imgs (ndarray): images which you want to split
68
  ROIs (tuple): size of sub-regions splitted (ROIs=region of interests)
69
  step (tuple): step path from one sub-region to the next one (in the x,y axis)
70
+
71
  Returns:
72
  ndarray: splitted subimages.
73
  The size is (x_num_subROIs*y_num_subROIs, **) where:
74
  x_num_subROIs = ( imgs.shape[1]-int(ROIs[1]/2)*2 )/step[1]
75
  y_num_subROIs = ( imgs.shape[0]-int(ROIs[0]/2)*2 )/step[0]
76
+
77
  Example:
78
  >>> from dataset_generator import split
79
  >>> imgs_splitted = split(imgs, ROI_size = (5,5), step=(2,3))
80
  """
81
+
82
  if len(ROIs) > 2:
83
  return print("ROIs is a 2 element list")
84
+
85
  if len(step) > 2:
86
  return print("step is a 2 element list")
87
+
88
  if type(imgs) != type(np.array(1)):
89
  return print("imgs should be a ndarray")
90
+
91
+ if len(imgs.shape) == 2: # Single image with one channel (HxW)
92
+ splitted = view_as_windows(imgs, (ROIs[0], ROIs[1]), (step[0], step[1]))
93
  return splitted.reshape((-1, ROIs[0], ROIs[1]))
94
+
95
+ if len(imgs.shape) == 3:
96
  _, _, channels = imgs.shape
97
+ if channels <= 3: # Single image more channels (HxWxC)
98
+ splitted = view_as_windows(imgs, (ROIs[0], ROIs[1], channels), (step[0], step[1], channels))
99
  return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
100
+ else: # More images with 1 channel
101
+ splitted = view_as_windows(imgs, (1, ROIs[0], ROIs[1]), (1, step[0], step[1]))
102
  return splitted.reshape((-1, ROIs[0], ROIs[1]))
103
+
104
+ if len(imgs.shape) == 4: # More images with more channels(BxHxWxC)
105
  _, _, _, channels = imgs.shape
106
+ splitted = view_as_windows(imgs, (1, ROIs[0], ROIs[1], channels), (1, step[0], step[1], channels))
107
  return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
108
 
109
+
110
  def join_blocks(splitted, final_shape):
111
  """Join blocks to reobtain a splitted image
112
+
113
  Attribute:
114
  splitted (tensor) = image splitted in blocks, size = (N_blocks, Channels, Height, Width)
115
  final_shape (tuple) = size of the final image reconstructed (Height, Width)
116
  Return:
117
  tensor: image restored from blocks. size = (Channels, Height, Width)
118
+
119
  """
120
  n_blocks, channels, ROI_height, ROI_width = splitted.shape
121
+
122
  rows = final_shape[0] // ROI_height
123
  columns = final_shape[1] // ROI_width
124
 
125
+ final_img = torch.empty(rows, channels, ROI_height, ROI_width * columns)
126
+ for r in range(rows):
127
+ stackblocks = splitted[r * columns]
128
  for c in range(1, columns):
129
+ stackblocks = torch.cat((stackblocks, splitted[r * columns + c]), axis=2)
130
  final_img[r] = stackblocks
131
+
132
  joined_img = final_img[0]
133
+
134
+ for i in np.arange(1, len(final_img)):
135
+ joined_img = torch.cat((joined_img, final_img[i]), axis=1)
136
+
137
  return joined_img
138
 
139
+
140
+ def random_ROI(X, Y, ROIs=(512, 512)):
141
  """ Return a random region for each input/target pair images of the dataset
142
  Args:
143
  Y (ndarray): target of your dataset --> size: (BxHxWxC)
144
  X (ndarray): input of your dataset --> size: (BxHxWxC)
145
  ROIs (tuple): size of random region (ROIs=region of interests)
146
+
147
  Returns:
148
  For each pair images (input/target) of the dataset, return respectively random ROIs
149
  Y_cut (ndarray): target of your dataset --> size: (Batch,Channels,ROIs[0],ROIs[1])
150
  X_cut (ndarray): input of your dataset --> size: (Batch,Channels,ROIs[0],ROIs[1])
151
+
152
  Example:
153
  >>> from dataset_generator import random_ROI
154
  >>> X,Y = random_ROI(X,Y, ROIs = (10,10))
155
+ """
156
+
157
  batch, channels, height, width = X.shape
158
+
159
+ X_cut = np.empty((batch, ROIs[0], ROIs[1], channels))
160
+ Y_cut = np.empty((batch, ROIs[0], ROIs[1], channels))
161
+
162
  for i in np.arange(len(X)):
163
+ x_size = int(random.random() * (height - (ROIs[0] + 1)))
164
+ y_size = int(random.random() * (width - (ROIs[1] + 1)))
165
+ X_cut[i] = X[i, x_size:x_size + ROIs[0], y_size:y_size + ROIs[1], :]
166
+ Y_cut[i] = Y[i, x_size:x_size + ROIs[0], y_size:y_size + ROIs[1], :]
167
  return X_cut, Y_cut
168
 
169
+
170
+ def one2many_random_ROI(X, Y, datasize=1000, ROIs=(512, 512)):
171
  """ Return a dataset of N subimages obtained from random regions of the same image
172
  Args:
173
  Y (ndarray): target of your dataset --> size: (1,H,W,C)
174
  X (ndarray): input of your dataset --> size: (1,H,W,C)
175
  datasize = number of random ROIs to generate
176
  ROIs (tuple): size of random region (ROIs=region of interests)
177
+
178
  Returns:
179
  Y_cut (ndarray): target of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
180
  X_cut (ndarray): input of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
181
+ """
182
 
183
  batch, channels, height, width = X.shape
184
+
185
+ X_cut = np.empty((datasize, ROIs[0], ROIs[1], channels))
186
+ Y_cut = np.empty((datasize, ROIs[0], ROIs[1], channels))
187
 
188
  for i in np.arange(datasize):
189
  X_cut[i], Y_cut[i] = random_ROI(X, Y, ROIs)
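
utils/dataset_utils.py now collects k_fold, load_image, list_images_in_dir and the block-splitting helpers. A standalone usage sketch for the single-channel branch of split_img (which wraps view_as_windows); the numbers are chosen so the patch count matches the docstring's shape formula:

import numpy as np
from skimage.util.shape import view_as_windows

img = np.arange(10 * 12, dtype=np.float32).reshape(10, 12)
patches = view_as_windows(img, (3, 3), (1, 1)).reshape(-1, 3, 3)
print(patches.shape)   # (80, 3, 3): (10 - 2) * (12 - 2) patches of size 3x3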
utils/debug.py DELETED
@@ -1,371 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import inspect
4
- from functools import reduce, wraps
5
- from collections.abc import Iterable
6
- from IPython import embed
7
-
8
- try:
9
- get_ipython() # pylint: disable=undefined-variable
10
- interactive_notebook = True
11
- except:
12
- interactive_notebook = False
13
-
14
- _NONE = "__UNSET_VARIABLE__"
15
-
16
-
17
- def debug_init():
18
- debug.disable = False
19
- debug.silent = False
20
- debug.verbose = 2
21
- debug.expand_ignore = ["DataLoader", "Dataset", "Subset"]
22
- debug.max_expand = 10
23
- debug.show_tensor = False
24
- debug.raise_exception = True
25
- debug.full_stack = True
26
- debug.restore_defaults_on_exception = not interactive_notebook
27
- debug._indent = 0
28
- debug._stack = ""
29
-
30
- debug.embed = embed
31
- debug.show = debug_show
32
- debug.pause = debug_pause
33
-
34
-
35
- def debug_pause():
36
- input("Press Enter to continue...")
37
-
38
-
39
- def debug(*args, assert_true=False):
40
- """Decorator for debugging functions and tensors.
41
-    Will throw an exception as soon as a nan is encountered.
-    If used on iterables, these will be expanded and also searched for nans.
-    Usage:
-        debug(x)
-    Or:
-        @debug
-        def function():
-            ...
-    If used as a function wrapper, all arguments will be searched and printed.
-    """
-
-    single_arg = len(args) == 1
-
-    if debug.disable:
-        return args[0] if single_arg else None
-
-    try:
-        call_line = ''.join(inspect.stack()[1][4]).strip()
-    except:
-        call_line = '...'
-    used_as_wrapper = 'def ' == call_line[:4]
-    expect_return_arg = single_arg and 'debug' in call_line and call_line.split('debug')[0].strip() != ''
-    is_func = single_arg and hasattr(args[0], '__call__')
-
-    if is_func and (used_as_wrapper or expect_return_arg):
-        func = args[0]
-        sig_parameters = inspect.signature(func).parameters
-        sig_argnames = [p.name for p in sig_parameters.values()]
-        sig_defaults = {
-            k: v.default
-            for k, v in sig_parameters.items()
-            if v.default is not inspect.Parameter.empty
-        }
-
-        @wraps(func)
-        def _func(*args, **kwargs):
-            if debug.disable:
-                return func(*args, **kwargs)
-
-            if debug._indent == 0:
-                debug._stack = ""
-            stack_before = debug._stack
-            indent = ' ' * 4 * debug._indent
-            debug._indent += 1
-
-            args_kw = dict(zip(sig_argnames, args))
-            defaults = {k: v for k, v in sig_defaults.items()
-                        if k not in kwargs
-                        if k not in args_kw}
-            all_args = {**args_kw, **kwargs, **defaults}
-
-            func_name = None
-            if hasattr(func, '__name__'):
-                func_name = func.__name__
-            elif hasattr(func, '__class__'):
-                func_name = func.__class__.__name__
-
-            if func_name is None:
-                func_name = '... ' + call_line + '...'
-            else:
-                func_name = '@' + func_name + '()'
-
-            _debug_log('', indent=indent)
-            _debug_log(func_name, indent=indent)
-
-            debug._last_call = func
-            debug._last_args = all_args
-            debug._last_args_sig = sig_argnames
-
-            for argtype, params in [("args", args_kw.items()),
-                                    ("kwargs", kwargs.items()),
-                                    ("defaults", defaults.items())]:
-                if params:
-                    _debug_log(f"{argtype}:", indent=indent + ' ' * 6)
-                    for argname, arg in params:
-                        if argname == 'self':
-                            # _debug_log(f"- self: ...", indent=indent + ' ' * 8)
-                            pass
-                        else:
-                            _debug_log(f"- {argname}: ", arg, indent + ' ' * 8, assert_true)
-            try:
-                out = func(*args, **kwargs)
-            except:
-                _debug_crash_save()
-                debug._stack = ""
-                debug._indent = 0
-                raise
-            debug.out = out
-            _debug_log("returned: ", out, indent, assert_true)
-            _debug_log('', indent=indent)
-            debug._indent -= 1
-            if not debug.full_stack:
-                debug._stack = stack_before
-            return out
-        return _func
-    else:
-        if debug._indent == 0:
-            debug._stack = ""
-        argname = ')'.join('('.join(call_line.split('(')[1:]).split(')')[:-1])
-        if assert_true:
-            argname = ','.join(argname.split(',')[:-1])
-            _debug_log(f"assert{{{argname}}} ", args[0], ' ' * 4 * debug._indent, assert_true)
-        else:
-            for arg in args:
-                _debug_log(f"{{{argname}}} = ", arg, ' ' * 4 * debug._indent, assert_true)
-        if expect_return_arg:
-            return args[0]
-        return
-
-
-def is_iterable(x):
-    return isinstance(x, Iterable) or hasattr(x, '__getitem__') and not isinstance(x, str)
-
-
-def ndarray_repr(t, assert_all=False):
-    exception_encountered = False
-    info = []
-    shape = tuple(t.shape)
-    single_entry = shape == () or shape == (1,)
-    if single_entry:
-        info.append(f"[{t.item():.4f}]")
-    else:
-        info.append(f"({', '.join(map(repr, shape))})")
-    invalid_sum = (~np.isfinite(t)).sum().item()
-    if invalid_sum:
-        info.append(
-            f"{invalid_sum} INVALID ENTR{'Y' if invalid_sum == 1 else 'IES'}")
-        exception_encountered = True
-    if debug.verbose > 1:
-        if not invalid_sum and not single_entry:
-            info.append(f"|x|={np.linalg.norm(t):.1f}")
-        if t.size:
-            info.append(f"x in [{t.min():.1f}, {t.max():.1f}]")
-    if debug.verbose and t.dtype != np.float64:
-        info.append(f"dtype={str(t.dtype)}".replace("'", ''))
-    if assert_all:
-        assert_val = t.all()
-        if not assert_val:
-            exception_encountered = True
-    if assert_all and not exception_encountered:
-        output = "passed"
-    else:
-        if assert_all and not assert_val:
-            output = f"ndarray({info[0]})"
-        else:
-            output = f"ndarray({', '.join(info)})"
-    if exception_encountered and (not hasattr(debug, 'raise_exception') or debug.raise_exception):
-        if debug.restore_defaults_on_exception:
-            debug.raise_exception = False
-            debug.silent = False
-        debug.x = t
-        msg = output
-        debug._stack += output
-        if debug._stack and '\n' in debug._stack:
-            msg += '\nSTACK: ' + debug._stack
-        if assert_all:
-            assert assert_val, "Assert did not pass on " + msg
-        raise Exception("Invalid entries encountered in " + msg)
-    return output
-
-
-def tensor_repr(t, assert_all=False):
-    exception_encountered = False
-    info = []
-    shape = tuple(t.shape)
-    single_entry = shape == () or shape == (1,)
-    if single_entry:
-        info.append(f"[{t.item():.3f}]")
-    else:
-        info.append(f"({', '.join(map(repr, shape))})")
-    invalid_sum = (~torch.isfinite(t)).sum().item()
-    if invalid_sum:
-        info.append(
-            f"{invalid_sum} INVALID ENTR{'Y' if invalid_sum == 1 else 'IES'}")
-        exception_encountered = True
-    if debug.verbose and t.requires_grad:
-        info.append('req_grad')
-    if debug.verbose > 2:
-        if t.is_leaf:
-            info.append('leaf')
-        if hasattr(t, 'retains_grad') and t.retains_grad:
-            info.append('retains_grad')
-    has_grad = (t.is_leaf or hasattr(t, 'retains_grad') and t.retains_grad) and t.grad is not None
-    if has_grad:
-        grad_invalid_sum = (~torch.isfinite(t.grad)).sum().item()
-        if grad_invalid_sum:
-            info.append(
-                f"GRAD {grad_invalid_sum} INVALID ENTR{'Y' if grad_invalid_sum == 1 else 'IES'}")
-            exception_encountered = True
-    if debug.verbose > 1:
-        if not invalid_sum and not single_entry:
-            info.append(f"|x|={t.float().norm():.1f}")
-        if t.numel():
-            info.append(f"x in [{t.min():.2f}, {t.max():.2f}]")
-        if has_grad and not grad_invalid_sum:
-            if single_entry:
-                info.append(f"grad={t.grad.float().item():.3f}")
-            else:
-                info.append(f"|grad|={t.grad.float().norm():.1f}")
-    if debug.verbose and t.dtype != torch.float:
-        info.append(f"dtype={str(t.dtype).split('.')[-1]}")
-    if debug.verbose and t.device.type != 'cpu':
-        info.append(f"device={t.device.type}")
-    if assert_all:
-        assert_val = t.all()
-        if not assert_val:
-            exception_encountered = True
-    if assert_all and not exception_encountered:
-        output = "passed"
-    else:
-        if assert_all and not assert_val:
-            output = f"tensor({info[0]})"
-        else:
-            output = f"tensor({', '.join(info)})"
-    if exception_encountered and (not hasattr(debug, 'raise_exception') or debug.raise_exception):
-        if debug.restore_defaults_on_exception:
-            debug.raise_exception = False
-            debug.silent = False
-        debug.x = t
-        msg = output
-        debug._stack += output
-        if debug._stack and '\n' in debug._stack:
-            msg += '\nSTACK: ' + debug._stack
-        if assert_all:
-            assert assert_val, "Assert did not pass on " + msg
-        raise Exception("Invalid entries encountered in " + msg)
-    return output
-
-
-def _debug_crash_save():
-    if debug._indent:
-        debug.args = debug._last_args
-        debug.func = debug._last_call
-
-        @wraps(debug.func)
-        def _recall(*args, **kwargs):
-            call_args = {**debug.args, **kwargs, **dict(zip(debug._last_args_sig, args))}
-            return debug(debug.func)(**call_args)
-
-        def print_stack(stack=debug._stack):
-            print('\nSTACK: ' + stack)
-        debug.stack = print_stack
-
-        debug.recall = _recall
-    debug._indent = 0
-
-
-def _debug_log(output, var=_NONE, indent='', assert_true=False, expand=True):
-    debug._stack += indent + output
-    if not debug.silent:
-        print(indent + output, end='')
-    if var is not _NONE:
-        type_str = type(var).__name__.lower()
-        if var is None:
-            _debug_log('None')
-        elif isinstance(var, str):
-            _debug_log(f"'{var}'")
-        elif type_str == 'ndarray':
-            _debug_log(ndarray_repr(var, assert_true))
-            if debug.show_tensor:
-                _debug_show_print(var, indent=indent + 4 * ' ')
-        # elif type_str in ('tensor', 'parameter'):
-        elif type_str == 'tensor':
-            _debug_log(tensor_repr(var, assert_true))
-            if debug.show_tensor:
-                _debug_show_print(var, indent=indent + 4 * ' ')
-        elif hasattr(var, 'named_parameters'):
-            _debug_log(type_str)
-            params = list(var.named_parameters())
-            _debug_log(f"{type_str}[{len(params)}] {{")
-            for k, v in params:
-                _debug_log(f"'{k}': ", v, indent + 6 * ' ')
-            _debug_log(indent + 4 * ' ' + '}')
-        elif is_iterable(var):
-            expand = debug.expand_ignore != '*' and expand
-            if expand:
-                if isinstance(debug.expand_ignore, str):
-                    if type_str == str(debug.expand_ignore).lower():
-                        expand = False
-                elif is_iterable(debug.expand_ignore):
-                    for ignore in debug.expand_ignore:
-                        if type_str == ignore.lower():
-                            expand = False
-            if hasattr(var, '__len__'):
-                length = len(var)
-            else:
-                var = list(var)
-                length = len(var)
-            if expand and length > 0:
-                _debug_log(f"{type_str}[{length}] {{")
-                if isinstance(var, dict):
-                    for k, v in var.items():
-                        _debug_log(f"'{k}': ", v, indent + 6 * ' ', assert_true)
-                else:
-                    i = 0
-                    for k, i in zip(var, range(debug.max_expand)):
-                        _debug_log('- ', k, indent + 6 * ' ', assert_true)
-                    if i < length - 1:
-                        _debug_log('- ' + ' ' * 6 + '...', indent=indent + 6 * ' ')
-                _debug_log(indent + 4 * ' ' + '}')
-            else:
-                _debug_log(f"{type_str}[{length}]")
-        else:
-            _debug_log(str(var))
-    else:
-        debug._stack += '\n'
-        if not debug.silent:
-            print()
-
-
-def debug_show(x):
-    assert is_iterable(x)
-    debug(x)
-    _debug_show_print(x, indent=' ' * 4 * debug._indent)
-
-
-def _debug_show_print(x, indent=''):
-    is_tensor = type(x).__name__ in ('Tensor', 'ndarray')
-    if is_tensor:
-        x = x.flatten()
-        if type(x).__name__ == 'Tensor' and x.dim() == 0:
-            return
-    n_samples = min(10, len(x))
-    di = len(x) // n_samples
-    var = list(x[i * di] for i in range(n_samples))
-    if is_tensor or type(var[0]) == float:
-        var = [round(float(v), 4) for v in var]
-    _debug_log('--> ', str(var), indent, expand=False)
-
-
-debug_init()
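
For reference, the helper removed above was meant to be used either as a direct call or as a function decorator, exactly as its docstring describes. Below is a minimal usage sketch; the import path utils.debug and the example function scale are assumptions for illustration only, not names confirmed by this diff, and whether the nan actually raises depends on the flags that debug_init() sets (raise_exception, silent, and so on).

import numpy as np
from utils.debug import debug  # assumed module path for the removed helper

x = np.array([1.0, 2.0, np.nan])

# Direct call: logs a summary such as "{x} = ndarray(...)" and, if
# raise_exception is enabled, raises because the array contains a nan.
try:
    debug(x)
except Exception as err:
    print(err)

# Decorator form: every argument and the return value are searched and logged.
@debug
def scale(a, factor=2.0):  # hypothetical example function
    return a * factor

scale(np.ones(3))
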
utils/{Cperturb.py → hendrycks_robustness.py} RENAMED
File without changes
utils/mutual_entropy.py DELETED
@@ -1,193 +0,0 @@
-import numpy as np
-from PIL import Image
-import matplotlib.pyplot as plt
-from scipy.signal import convolve2d
-
-def mse(x,y):
-    return ((x-y)**2).mean()
-
-def gaussian_noise_entropies(t1, bins=20):
-    all_MI = []
-    all_mse = []
-    for sigma in np.linspace(0,100,201):
-        t2 = np.random.normal(t1.copy(), scale=sigma, size=t1.shape)
-        hist_2d, x_edges, y_edges = np.histogram2d(
-            t1.ravel(),
-            t2.ravel(),
-            bins=bins)
-        all_mse.append(mse(t1,t2))
-        MI = mutual_information(hist_2d)
-        all_MI.append(MI)
-
-    return np.array(all_MI), np.array(all_mse)
-
-def shifts_entropies(t1, bins=20):
-    all_MI = []
-    all_mse = []
-    for N in np.linspace(1,50,50):
-        N = int(N)
-        temp_t2 = t1[:-N].copy()
-        temp_t1 = t1[N:].copy()
-        hist_2d, x_edges, y_edges = np.histogram2d(
-            temp_t1.ravel(),
-            temp_t2.ravel(),
-            bins=bins)
-        MI = mutual_information(hist_2d)
-
-        all_mse.append(mse(temp_t1,temp_t2))
-        all_MI.append(MI)
-
-    return np.array(all_MI), np.array(all_mse)
-
-def mutual_information(hgram):
-    """ Mutual information for joint histogram
-    """
-    # Convert bin counts to probability values
-    pxy = hgram / float(np.sum(hgram))
-    px = np.sum(pxy, axis=1)  # marginal for x over y
-    py = np.sum(pxy, axis=0)  # marginal for y over x
-    px_py = px[:, None] * py[None, :]  # Broadcast to multiply marginals
-    # Now we can do the calculation using the pxy, px_py 2D arrays
-    nzs = pxy > 0  # Only non-zero pxy values contribute to the sum
-    return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs]))
-
-def entropy(image, bins=20):
-    image = image.ravel()
-    hist, bin_edges = np.histogram(image, bins=bins)
-    hist = hist/hist.sum()
-    entropy_term = np.where(hist != 0, hist*np.log(hist), 0)
-    entropy = - np.sum(entropy_term)
-
-    return entropy
-
-# Gray Image
-# t1 = np.array(Image.open("img.png"))[:,:,0].astype(float)
-
-# Colour Image
-t1 = np.array(Image.open("img.png").resize((255,255)))
-
-perturb = "gauss"
-show_image = True
-bins = 20
-
-print(perturb)
-
-# Identity
-if perturb == "identity":
-    t2 = t1
-    title = "Identity"
-    image1 = "Clean"
-    image2 = "Clean"
-
-# Poisson Noise on t2
-if perturb == "poisson":
-    t2 = np.random.poisson(t1)
-    title = "Poisson Noise"
-    image1 = "Clean"
-    image2 = "Noisy"
-
-# Gaussian Noise on t2
-if perturb == "gauss":
-    print(np.shape(t1))
-    sigma = 50.0
-    t2 = np.random.normal(t1.copy(), scale=sigma, size=t1.shape)
-    if "grad" in locals():
-        title = f"Gaussian Noise, grad= True, sigma = {sigma:.2f}"
-    else:
-        title = f"Gaussian Noise, sigma = {sigma:.2f}"
-    image1 = "Clean"
-    image2 = "Noisy"
-
-if perturb == "box":
-    sigma = 50.0
-    mean = np.mean(t1)
-    print(np.shape(t1))
-    t2 = t1.copy()
-    t2[30:220,50:120,:] = mean
-    title = "Box with mean pixels"
-    image1 = "Clean"
-    image2 = "Noisy"
-
-
-# Shift t2 on y axis
-if perturb == "shift":
-    N = 30
-    t2 = t1[:-N]
-    t1 = t1[N:]
-    title = "y shift"
-    image1 = "Clean"
-    image2 = "Shifted"
-
-t2 = np.clip(t2,0,255).astype(int)
-
-print("Correlation Coefficient: ", np.corrcoef(t1.ravel(), t2.ravel())[0, 1])
-
-# 2D Histogram
-hist_2d, x_edges, y_edges = np.histogram2d(
-    t1.ravel(),
-    t2.ravel(),
-    bins=bins)
-
-MI = mutual_information(hist_2d)
-
-print("Mutual Information", MI)
-print("Mean squared error:", mse(t1,t2))
-
-if show_image == True:
-    plt.figure()
-    plt.imshow(np.hstack((t2, t1)))
-    plt.title(title)
-
-    plt.figure()
-
-    plt.plot(t1.ravel(), t2.ravel(), '.')
-    plt.xlabel(image1)
-    plt.ylabel(image2)
-    plt.title('I1 vs I2')
-
-    plt.figure()
-    plt.imshow((hist_2d.T)/hist_2d.max(), origin='lower')
-    plt.xlabel(image1)
-    plt.ylabel(image2)
-    plt.xticks(ticks=np.linspace(0,bins-1,10), labels=np.linspace(x_edges.min(),x_edges.max(),10).astype(int))
-    plt.yticks(ticks=np.linspace(0,bins-1,10), labels=np.linspace(y_edges.min(),y_edges.max(),10).astype(int))
-    plt.title('p(x,y)')
-    plt.colorbar()
-
-    # Show log histogram, avoiding divide by 0
-    plt.figure(figsize=(4,4))
-    hist_2d_log = np.zeros(hist_2d.shape)
-    non_zeros = hist_2d != 0
-    hist_2d_log[non_zeros] = np.log(hist_2d[non_zeros])
-    plt.imshow((hist_2d_log.T)/hist_2d_log.max(), origin='lower')
-    plt.xlabel(image1)
-    plt.ylabel(image2)
-    plt.xticks(ticks=np.linspace(0,bins-1,10), labels=np.linspace(x_edges.min(),x_edges.max(),10).astype(int))
-    plt.yticks(ticks=np.linspace(0,bins-1,10), labels=np.linspace(y_edges.min(),y_edges.max(),10).astype(int))
-    plt.title('log(p(x,y))')
-    plt.colorbar()
-    plt.show()
-
-if perturb == "shift":
-    mi_array, mse_array = shifts_entropies(t1, bins=bins)
-    plt.figure()
-    plt.plot(np.linspace(0,50,50), mi_array)
-    plt.xlabel("y shift")
-    plt.ylabel("Mutual Information")
-    plt.figure()
-    plt.plot(np.linspace(0,50,50), mse_array)
-    plt.xlabel("y shift")
-    plt.ylabel("Mean Squared Error")
-    plt.show()
-
-if perturb == "gauss":
-    mi_array, mse_array = gaussian_noise_entropies(t1, bins=bins)
-    plt.figure()
-    plt.plot(np.linspace(0,100,201), mi_array)
-    plt.xlabel("sigma")
-    plt.ylabel("Mutual Information")
-    plt.figure()
-    plt.plot(np.linspace(0,100,201), mse_array)
-    plt.xlabel("sigma")
-    plt.ylabel("Mean Squared Error")
-    plt.show()
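
The removed script estimated the mutual information between a clean image and a perturbed copy from their joint intensity histogram. A self-contained sketch of the same estimate is shown below; the synthetic arrays, the noise level, and the bin count are illustrative choices, not values prescribed by the script.

import numpy as np

def mutual_information(hgram):
    # Normalise the joint histogram to a joint probability table, form the
    # marginals, and sum p(x,y) * log(p(x,y) / (p(x) p(y))) over non-zero cells.
    pxy = hgram / hgram.sum()
    px = pxy.sum(axis=1)
    py = pxy.sum(axis=0)
    px_py = px[:, None] * py[None, :]
    nz = pxy > 0  # only non-zero cells contribute to the sum
    return np.sum(pxy[nz] * np.log(pxy[nz] / px_py[nz]))

rng = np.random.default_rng(0)
t1 = rng.uniform(0, 255, size=(64, 64))          # stand-in "clean" image
t2 = t1 + rng.normal(scale=25.0, size=t1.shape)  # noisy copy

hist_2d, _, _ = np.histogram2d(t1.ravel(), t2.ravel(), bins=20)
print("Mutual Information (nats):", mutual_information(hist_2d))

As in the deleted file, the estimate shrinks toward zero as the noise level grows, which is what the gaussian_noise_entropies sweep plotted against sigma.
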
utils/{pytorch_ssim.py → ssim.py} RENAMED
File without changes