willis
commited on
Commit
Β·
0220054
1
Parent(s):
839dc8c
reorganize
Browse files- .DS_Store +0 -0
- utils/dataset.py β dataset.py +11 -60
- ABtesting.py β figures/ABtesting.py +343 -318
- figure1.sh β figures/figure1.sh +0 -0
- figure2.sh β figures/figure2.sh +0 -0
- figures.py β figures/figures.py +0 -0
- {processingpipeline β figures}/numpy_static_pipeline_show.ipynb +2 -2
- sanity_checks_and_statistics.ipynb β figures/sanity_checks_and_statistics.ipynb +2 -2
- show_classification_results.ipynb β figures/show_classification_results.ipynb +0 -0
- {utils β figures}/show_dataset.ipynb +2 -2
- show_results.sh β figures/show_results.sh +0 -0
- train.sh β figures/train.sh +0 -0
- models/classifier.py β model.py +49 -25
- processingpipeline/pipeline.py β processing/pipeline_numpy.py +1 -1
- processingpipeline/torch_pipeline.py β processing/pipeline_torch.py +19 -18
- readme/Slice 8.png +0 -0
- readme/init.md +0 -1
- readme/mlflow (1).png +0 -0
- train.py +44 -38
- utils/augmentation.py +1 -1
- utils/{splitting.py β dataset_utils.py} +105 -52
- utils/debug.py +0 -371
- utils/{Cperturb.py β hendrycks_robustness.py} +0 -0
- utils/mutual_entropy.py +0 -193
- utils/{pytorch_ssim.py β ssim.py} +0 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
utils/dataset.py β dataset.py
RENAMED
@@ -15,7 +15,7 @@ from sklearn.model_selection import StratifiedShuffleSplit
|
|
15 |
if not os.path.exists('README.md'): # set pwd to root
|
16 |
os.chdir('..')
|
17 |
|
18 |
-
from utils.
|
19 |
from utils.base import np2torch, torch2np, b2_download_folder
|
20 |
|
21 |
IMAGE_FILE_TYPES = ['dng', 'png', 'tif', 'tiff']
|
@@ -41,24 +41,6 @@ def get_dataset(name, I_ratio=1.0):
|
|
41 |
raise ValueError(name)
|
42 |
|
43 |
|
44 |
-
def load_image(path):
|
45 |
-
file_type = path.split('.')[-1].lower()
|
46 |
-
if file_type == 'dng':
|
47 |
-
img = rawpy.imread(path).raw_image_visible
|
48 |
-
elif file_type == 'tiff' or file_type == 'tif':
|
49 |
-
img = np.array(tiff.imread(path), dtype=np.float32)
|
50 |
-
else:
|
51 |
-
img = np.array(Image.open(path), dtype=np.float32)
|
52 |
-
return img
|
53 |
-
|
54 |
-
|
55 |
-
def list_images_in_dir(path):
|
56 |
-
image_list = [os.path.join(path, img_name)
|
57 |
-
for img_name in sorted(os.listdir(path))
|
58 |
-
if img_name.split('.')[-1].lower() in IMAGE_FILE_TYPES]
|
59 |
-
return image_list
|
60 |
-
|
61 |
-
|
62 |
class ImageFolderDataset(Dataset):
|
63 |
"""Creates a dataset of images in img_dir and corresponding masks in mask_dir.
|
64 |
Corresponding mask files need to contain the filename of the image.
|
@@ -166,18 +148,20 @@ class ImageFolderDatasetSegmentation(Dataset):
|
|
166 |
|
167 |
return img, mask
|
168 |
|
|
|
169 |
class MultiIntensity(Dataset):
|
170 |
"""Wrap datasets with different intesities
|
171 |
|
172 |
Args:
|
173 |
datasets (list): list of datasets to wrap
|
174 |
"""
|
|
|
175 |
def __init__(self, datasets):
|
176 |
self.dataset = datasets[0]
|
177 |
|
178 |
-
for d in range(1,len(datasets)):
|
179 |
-
self.dataset.images = self.dataset.images+datasets[d].images
|
180 |
-
self.dataset.labels = self.dataset.labels+datasets[d].labels
|
181 |
|
182 |
def __len__(self):
|
183 |
return len(self.dataset)
|
@@ -191,6 +175,7 @@ class MultiIntensity(Dataset):
|
|
191 |
x = self.transform(x)
|
192 |
return x, y
|
193 |
|
|
|
194 |
class Subset(Dataset):
|
195 |
"""Define a subset of a dataset by only selecting given indices.
|
196 |
|
@@ -228,8 +213,8 @@ class DroneDatasetSegmentationFull(ImageFolderDatasetSegmentation):
|
|
228 |
camera_parameters = black_level, white_balance, colour_matrix
|
229 |
|
230 |
def __init__(self, I_ratio=1.0, transform=None, force_download=False, bits=16):
|
231 |
-
|
232 |
-
assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
|
233 |
|
234 |
img_dir = f'data/drone/images_full/raw_scale{int(I_ratio*100):03d}'
|
235 |
mask_dir = 'data/drone/masks_full'
|
@@ -411,7 +396,7 @@ def download_microscopy_dataset(force_download):
|
|
411 |
|
412 |
|
413 |
def unzip_microscopy_images():
|
414 |
-
|
415 |
if os.path.isfile('data/microscopy/labels/.bzEmpty'):
|
416 |
os.remove('data/microscopy/labels/.bzEmpty')
|
417 |
|
@@ -421,6 +406,7 @@ def unzip_microscopy_images():
|
|
421 |
zip.extractall('data/microscopy/images')
|
422 |
os.remove(os.path.join('data/microscopy/images', file))
|
423 |
|
|
|
424 |
def unzip_drone_images():
|
425 |
|
426 |
if os.path.isfile('data/drone/masks_full/.bzEmpty'):
|
@@ -585,38 +571,3 @@ def check_image_folder_consistency(images, masks):
|
|
585 |
f"image file {img_file} file type mismatch. Shoule be: {file_type_images}"
|
586 |
assert mask_file.split('.')[-1].lower() == file_type_masks, \
|
587 |
f"image file {mask_file} file type mismatch. Should be: {file_type_masks}"
|
588 |
-
|
589 |
-
|
590 |
-
def k_fold(dataset, n_splits: int, seed: int, train_size: float):
|
591 |
-
"""Split dataset in subsets for cross-validation
|
592 |
-
|
593 |
-
Args:
|
594 |
-
dataset (class): dataset to split
|
595 |
-
n_split (int): Number of re-shuffling & splitting iterations.
|
596 |
-
seed (int): seed for k_fold splitting
|
597 |
-
train_size (float): should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the train split.
|
598 |
-
Returns:
|
599 |
-
idxs (list): indeces for splitting the dataset. The list contain n_split pair of train/test indeces.
|
600 |
-
"""
|
601 |
-
if hasattr(dataset, 'labels'):
|
602 |
-
x = dataset.images
|
603 |
-
y = dataset.labels
|
604 |
-
elif hasattr(dataset, 'masks'):
|
605 |
-
x = dataset.images
|
606 |
-
y = dataset.masks
|
607 |
-
|
608 |
-
idxs = []
|
609 |
-
|
610 |
-
if dataset.task == 'classification':
|
611 |
-
sss = StratifiedShuffleSplit(n_splits=n_splits, train_size=train_size, random_state=seed)
|
612 |
-
|
613 |
-
for idxs_train, idxs_test in sss.split(x, y):
|
614 |
-
idxs.append((idxs_train.tolist(), idxs_test.tolist()))
|
615 |
-
|
616 |
-
elif dataset.task == 'segmentation':
|
617 |
-
for n in range(n_splits):
|
618 |
-
split_idx = int(len(dataset) * train_size)
|
619 |
-
indices = np.random.permutation(len(dataset))
|
620 |
-
idxs.append((indices[:split_idx].tolist(), indices[split_idx:].tolist()))
|
621 |
-
|
622 |
-
return idxs
|
|
|
15 |
if not os.path.exists('README.md'): # set pwd to root
|
16 |
os.chdir('..')
|
17 |
|
18 |
+
from utils.dataset_utils import split_img, list_images_in_dir, load_image
|
19 |
from utils.base import np2torch, torch2np, b2_download_folder
|
20 |
|
21 |
IMAGE_FILE_TYPES = ['dng', 'png', 'tif', 'tiff']
|
|
|
41 |
raise ValueError(name)
|
42 |
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
class ImageFolderDataset(Dataset):
|
45 |
"""Creates a dataset of images in img_dir and corresponding masks in mask_dir.
|
46 |
Corresponding mask files need to contain the filename of the image.
|
|
|
148 |
|
149 |
return img, mask
|
150 |
|
151 |
+
|
152 |
class MultiIntensity(Dataset):
|
153 |
"""Wrap datasets with different intesities
|
154 |
|
155 |
Args:
|
156 |
datasets (list): list of datasets to wrap
|
157 |
"""
|
158 |
+
|
159 |
def __init__(self, datasets):
|
160 |
self.dataset = datasets[0]
|
161 |
|
162 |
+
for d in range(1, len(datasets)):
|
163 |
+
self.dataset.images = self.dataset.images + datasets[d].images
|
164 |
+
self.dataset.labels = self.dataset.labels + datasets[d].labels
|
165 |
|
166 |
def __len__(self):
|
167 |
return len(self.dataset)
|
|
|
175 |
x = self.transform(x)
|
176 |
return x, y
|
177 |
|
178 |
+
|
179 |
class Subset(Dataset):
|
180 |
"""Define a subset of a dataset by only selecting given indices.
|
181 |
|
|
|
213 |
camera_parameters = black_level, white_balance, colour_matrix
|
214 |
|
215 |
def __init__(self, I_ratio=1.0, transform=None, force_download=False, bits=16):
|
216 |
+
|
217 |
+
assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
|
218 |
|
219 |
img_dir = f'data/drone/images_full/raw_scale{int(I_ratio*100):03d}'
|
220 |
mask_dir = 'data/drone/masks_full'
|
|
|
396 |
|
397 |
|
398 |
def unzip_microscopy_images():
|
399 |
+
|
400 |
if os.path.isfile('data/microscopy/labels/.bzEmpty'):
|
401 |
os.remove('data/microscopy/labels/.bzEmpty')
|
402 |
|
|
|
406 |
zip.extractall('data/microscopy/images')
|
407 |
os.remove(os.path.join('data/microscopy/images', file))
|
408 |
|
409 |
+
|
410 |
def unzip_drone_images():
|
411 |
|
412 |
if os.path.isfile('data/drone/masks_full/.bzEmpty'):
|
|
|
571 |
f"image file {img_file} file type mismatch. Shoule be: {file_type_images}"
|
572 |
assert mask_file.split('.')[-1].lower() == file_type_masks, \
|
573 |
f"image file {mask_file} file type mismatch. Should be: {file_type_masks}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ABtesting.py β figures/ABtesting.py
RENAMED
@@ -8,11 +8,11 @@ from torch.utils.data import DataLoader
|
|
8 |
from torchvision.transforms import Compose, Normalize
|
9 |
import torch.nn.functional as F
|
10 |
|
11 |
-
from
|
12 |
from utils.base import get_mlflow_model_by_name, SmartFormatter
|
13 |
-
from
|
14 |
|
15 |
-
from utils.
|
16 |
|
17 |
import segmentation_models_pytorch as smp
|
18 |
|
@@ -20,94 +20,104 @@ import matplotlib.pyplot as plt
|
|
20 |
|
21 |
parser = argparse.ArgumentParser(description="AB testing, Show Results", formatter_class=SmartFormatter)
|
22 |
|
23 |
-
#Select experiment
|
24 |
-
parser.add_argument("--mode", type=str, default="ABShowImages", choices=('ABMakeTable', 'ABShowTable', 'ABShowImages', 'ABShowAllImages', 'CMakeTable', 'CShowTable', 'CShowImages', 'CShowAllImages'),
|
25 |
help='R|Choose operation to compute. \n'
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
parser.add_argument("--dataset_name", type=str, default='Microscopy',
|
38 |
-
|
|
|
|
|
39 |
parser.add_argument("--N_runs", type=int, default=5, help='Number of k-fold splitting used in the training')
|
40 |
parser.add_argument("--download_model", default=False, action='store_true', help='Download Models in cache')
|
41 |
|
42 |
-
#Select pipelines
|
43 |
-
parser.add_argument("--dm_train", type=str, default='bilinear', choices=
|
44 |
-
|
45 |
-
parser.add_argument("--
|
46 |
-
|
47 |
-
parser.add_argument("--
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
parser.add_argument("--
|
52 |
-
|
53 |
-
parser.add_argument("--
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
args = parser.parse_args()
|
56 |
|
|
|
57 |
class metrics:
|
58 |
def __init__(self, confusion_matrix):
|
59 |
self.cm = confusion_matrix
|
60 |
self.N_classes = len(confusion_matrix)
|
61 |
|
62 |
def accuracy(self):
|
63 |
-
Tp = torch.diagonal(self.cm,0).sum()
|
64 |
N_elements = torch.sum(self.cm)
|
65 |
-
return Tp/N_elements
|
66 |
-
|
67 |
def precision(self):
|
68 |
Tp_Fp = torch.sum(self.cm, 1)
|
69 |
Tp_Fp[Tp_Fp == 0] = 1
|
70 |
-
return torch.diagonal(self.cm,0) / Tp_Fp
|
71 |
|
72 |
def recall(self):
|
73 |
Tp_Fn = torch.sum(self.cm, 0)
|
74 |
Tp_Fn[Tp_Fn == 0] = 1
|
75 |
-
return torch.diagonal(self.cm,0) / Tp_Fn
|
76 |
|
77 |
def f1_score(self):
|
78 |
-
prod = (self.precision()*self.recall())
|
79 |
sum = (self.precision() + self.recall())
|
80 |
sum[sum == 0.] = 1.
|
81 |
-
return 2*(
|
82 |
|
83 |
def over_N_runs(ms, N_runs):
|
84 |
-
m, m2 = 0, 0
|
85 |
|
86 |
for i in ms:
|
87 |
-
m += i
|
88 |
-
mu = m/N_runs
|
89 |
-
|
90 |
for i in ms:
|
91 |
-
m2 += (i-mu)**2
|
|
|
|
|
92 |
|
93 |
-
sigma = torch.sqrt( m2 / (N_runs-1) )
|
94 |
-
|
95 |
return mu.tolist(), sigma.tolist()
|
96 |
|
|
|
97 |
class ABtesting:
|
98 |
-
def __init__(self,
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
self.experiment_name = 'ABtesting'
|
112 |
self.dataset_name = dataset_name
|
113 |
self.augmentation = augmentation
|
@@ -129,12 +139,12 @@ class ABtesting:
|
|
129 |
if sharpening == None:
|
130 |
sharpening = self.s_test
|
131 |
if denoising == None:
|
132 |
-
denoising = self.dn_test
|
133 |
if severity == None:
|
134 |
-
severity = self.severity
|
135 |
if transform == None:
|
136 |
transform = self.transform
|
137 |
-
|
138 |
dataset = get_dataset(self.dataset_name)
|
139 |
|
140 |
if self.dataset_name == "Drone" or self.dataset_name == "DroneSegmentation":
|
@@ -146,40 +156,41 @@ class ABtesting:
|
|
146 |
|
147 |
if not plot_mode:
|
148 |
dataset.transform = Compose([RawProcessingPipeline(
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
else:
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
|
163 |
return dataset
|
164 |
|
165 |
def ABclassification(self):
|
166 |
-
|
167 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
168 |
|
169 |
parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
|
170 |
|
171 |
-
print(
|
172 |
-
|
|
|
173 |
|
174 |
-
accuracies, precisions, recalls, f1_scores = [],[],[],[]
|
175 |
|
176 |
os.system('rm -r /tmp/py*')
|
177 |
|
178 |
for N_run in range(self.N_runs):
|
179 |
|
180 |
print(f"Evaluating Run {N_run}")
|
181 |
-
|
182 |
-
run_name = parent_run_name+'_'+str(N_run)
|
183 |
|
184 |
state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
|
185 |
download_model=self.download_model)
|
@@ -190,18 +201,18 @@ class ABtesting:
|
|
190 |
|
191 |
model.eval()
|
192 |
|
193 |
-
len_classes = len(dataset.classes)
|
194 |
confusion_matrix = torch.zeros((len_classes, len_classes))
|
195 |
|
196 |
for img, label in valid_loader:
|
197 |
-
|
198 |
prediction = model(img.to(DEVICE)).detach().cpu()
|
199 |
-
prediction
|
200 |
-
confusion_matrix[label,prediction] += 1
|
201 |
|
202 |
m = metrics(confusion_matrix)
|
203 |
|
204 |
-
accuracies.append(m.accuracy())
|
205 |
precisions.append(m.precision())
|
206 |
recalls.append(m.recall())
|
207 |
f1_scores.append(m.f1_score())
|
@@ -213,15 +224,16 @@ class ABtesting:
|
|
213 |
recall = metrics.over_N_runs(recalls, self.N_runs)
|
214 |
f1_score = metrics.over_N_runs(f1_scores, self.N_runs)
|
215 |
return dataset.classes, accuracy, precision, recall, f1_score
|
216 |
-
|
217 |
def ABsegmentation(self):
|
218 |
-
|
219 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
220 |
|
221 |
parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
|
222 |
|
223 |
-
print(
|
224 |
-
|
|
|
225 |
|
226 |
IoUs = []
|
227 |
|
@@ -230,34 +242,34 @@ class ABtesting:
|
|
230 |
for N_run in range(self.N_runs):
|
231 |
|
232 |
print(f"Evaluating Run {N_run}")
|
233 |
-
|
234 |
-
run_name = parent_run_name+'_'+str(N_run)
|
235 |
|
236 |
-
|
237 |
-
|
|
|
|
|
238 |
|
239 |
dataset = self.static_pip_val()
|
240 |
-
|
241 |
valid_set = Subset(dataset, indices=state_dict['valid_indices'])
|
242 |
valid_loader = DataLoader(valid_set, batch_size=1, num_workers=16, shuffle=False)
|
243 |
|
244 |
model.eval()
|
245 |
|
246 |
-
IoU=0
|
247 |
|
248 |
for img, label in valid_loader:
|
249 |
-
|
250 |
prediction = model(img.to(DEVICE)).detach().cpu()
|
251 |
prediction = F.logsigmoid(prediction).exp().squeeze()
|
252 |
-
IoU += smp.utils.metrics.IoU()(prediction,label)
|
253 |
|
254 |
-
IoU = IoU/len(valid_loader)
|
255 |
IoUs.append(IoU.item())
|
256 |
|
257 |
os.system('rm -r /tmp/t*')
|
258 |
|
259 |
IoU = metrics.over_N_runs(torch.tensor(IoUs), self.N_runs)
|
260 |
-
return IoU
|
261 |
|
262 |
def ABShowImages(self):
|
263 |
|
@@ -265,164 +277,169 @@ class ABtesting:
|
|
265 |
if not os.path.exists(path):
|
266 |
os.makedirs(path)
|
267 |
|
268 |
-
path = os.path.join(
|
|
|
269 |
|
270 |
if not os.path.exists(path):
|
271 |
os.makedirs(path)
|
272 |
|
273 |
-
run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"+
|
|
|
274 |
|
275 |
state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=self.download_model)
|
276 |
|
277 |
model.augmentation = None
|
278 |
-
|
279 |
-
for t in ([self.dm_train, self.s_train, self.dn_train, 'train_img'],
|
280 |
-
|
281 |
-
|
282 |
debayer, sharpening, denoising, img_type = t[0], t[1], t[2], t[3]
|
283 |
|
284 |
dataset = self.static_pip_val(debayer=debayer, sharpening=sharpening, denoising=denoising, plot_mode=True)
|
285 |
valid_set = Subset(dataset, indices=state_dict['valid_indices'])
|
286 |
-
|
287 |
img, _ = next(iter(valid_set))
|
288 |
|
289 |
plt.figure()
|
290 |
-
plt.imshow(img.permute(1,2,0))
|
291 |
if img_type == 'train_img':
|
292 |
plt.title('Train Image')
|
293 |
plt.savefig(os.path.join(path, f'img_train.png'))
|
294 |
imgA = img
|
295 |
else:
|
296 |
plt.title('Test Image')
|
297 |
-
plt.savefig(os.path.join(path,f'img_test.png'))
|
298 |
-
|
299 |
-
for c, color in enumerate(['Red','Green','Blue']):
|
300 |
-
diff = torch.abs(imgA-img)
|
301 |
plt.figure()
|
302 |
# plt.imshow(diff.permute(1,2,0))
|
303 |
-
plt.imshow(diff[c,50:200,50:200], cmap=f'{color}s')
|
304 |
plt.title(f'|Train Image - Test Image| - {color}')
|
305 |
plt.colorbar()
|
306 |
plt.savefig(os.path.join(path, f'diff_{color}.png'))
|
307 |
plt.figure()
|
308 |
-
diff[diff == 0.]= 1e-5
|
309 |
# plt.imshow(torch.log(diff.permute(1,2,0)))
|
310 |
plt.imshow(torch.log(diff)[c])
|
311 |
plt.title(f'log(|Train Image - Test Image|) - color')
|
312 |
plt.colorbar()
|
313 |
plt.savefig(os.path.join(path, f'logdiff_{color}.png'))
|
314 |
-
|
315 |
if self.dataset_name == 'DroneSegmentation':
|
316 |
plt.figure()
|
317 |
plt.imshow(model(img[None].cuda()).detach().cpu().squeeze())
|
318 |
if img_type == 'train_img':
|
319 |
plt.savefig(os.path.join(path, f'mask_train.png'))
|
320 |
else:
|
321 |
-
plt.savefig(os.path.join(path,f'mask_test.png'))
|
322 |
|
323 |
def ABShowAllImages(self):
|
324 |
if not os.path.exists('results/ABtesting'):
|
325 |
os.makedirs('results/ABtesting')
|
326 |
|
327 |
-
demosaicings=['bilinear','malvar2004', 'menon2007']
|
328 |
-
sharpenings=['sharpening_filter', 'unsharp_masking']
|
329 |
-
denoisings=['median_denoising', 'gaussian_denoising']
|
330 |
|
331 |
fig = plt.figure()
|
332 |
-
columns=4
|
333 |
-
rows=3
|
334 |
|
335 |
-
i=1
|
336 |
|
337 |
for dm in demosaicings:
|
338 |
for s in sharpenings:
|
339 |
for dn in denoisings:
|
340 |
-
|
341 |
-
dataset = self.static_pip_val(self.dm_test, self.s_test,
|
342 |
-
self.dn_test, plot_mode=True)
|
343 |
|
344 |
-
|
345 |
-
|
|
|
|
|
|
|
346 |
fig.add_subplot(rows, columns, i)
|
347 |
-
plt.imshow(img.permute(1,2,0))
|
348 |
plt.title(f'{dm}\n{s}\n{dn}', fontsize=8)
|
349 |
plt.xticks([])
|
350 |
plt.yticks([])
|
351 |
plt.tight_layout()
|
352 |
|
353 |
-
i+=1
|
354 |
|
355 |
plt.show()
|
356 |
plt.savefig(f'results/ABtesting/ABpipelines.png')
|
357 |
|
358 |
def CShowImages(self):
|
359 |
-
|
360 |
path = 'results/Ctesting/imgs/'
|
361 |
if not os.path.exists(path):
|
362 |
os.makedirs(path)
|
363 |
|
364 |
-
run_name = f"{self.dataset_name}_{self.dm_test}_{self.s_test}_{self.dn_test}_{self.augmentation}"+'_'+str(0)
|
365 |
|
366 |
state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=True)
|
367 |
|
368 |
model.augmentation = None
|
369 |
-
|
370 |
-
dataset = self.static_pip_val(self.dm_test, self.s_test, self.dn_test,
|
|
|
371 |
valid_set = Subset(dataset, indices=state_dict['valid_indices'])
|
372 |
-
|
373 |
img, _ = next(iter(valid_set))
|
374 |
|
375 |
plt.figure()
|
376 |
-
plt.imshow(img.permute(1,2,0))
|
377 |
-
plt.savefig(os.path.join(
|
378 |
-
|
|
|
379 |
def CShowAllImages(self):
|
380 |
if not os.path.exists('results/Cimages'):
|
381 |
os.makedirs('results/Cimages')
|
382 |
|
383 |
-
transforms = ['identity','gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
|
384 |
-
|
385 |
|
386 |
-
for i,t in enumerate(transforms):
|
387 |
-
|
388 |
-
fig = plt.figure(figsize=(10,6))
|
389 |
columns = 5
|
390 |
rows = 1
|
391 |
|
392 |
-
for sev in range(1,6):
|
393 |
|
394 |
dataset = self.static_pip_val(severity=sev, transform=t, plot_mode=True)
|
395 |
|
396 |
-
img,_ = dataset[0]
|
397 |
-
|
398 |
fig.add_subplot(rows, columns, sev)
|
399 |
-
plt.imshow(img.permute(1,2,0))
|
400 |
plt.title(f'Severity: {sev}')
|
401 |
plt.xticks([])
|
402 |
plt.yticks([])
|
403 |
plt.tight_layout()
|
404 |
-
|
405 |
if '_' in t:
|
406 |
-
t=t.replace('_', ' ')
|
407 |
-
t=t[0].upper()+t[1:]
|
408 |
|
409 |
fig.suptitle(f'{t}', x=0.5, y=0.8, fontsize=24)
|
410 |
plt.show()
|
411 |
plt.savefig(f'results/Cimages/{i+1}_{t.lower()}.png')
|
412 |
|
413 |
-
def ABMakeTable(dataset_name:str, augmentation: str,
|
414 |
-
N_runs: int, download_model: bool):
|
415 |
|
416 |
-
|
417 |
-
|
418 |
-
|
|
|
|
|
|
|
419 |
|
420 |
-
path='results/ABtesting/tables'
|
421 |
if not os.path.exists(path):
|
422 |
os.makedirs(path)
|
423 |
|
424 |
-
runs={}
|
425 |
-
i=0
|
426 |
|
427 |
for dm_train in demosaicings:
|
428 |
for s_train in sharpenings:
|
@@ -431,28 +448,28 @@ def ABMakeTable(dataset_name:str, augmentation: str,
|
|
431 |
for s_test in sharpenings:
|
432 |
for dn_test in denoisings:
|
433 |
train_pip = [dm_train, s_train, dn_train]
|
434 |
-
test_pip = [dm_test, s_test, dn_test]
|
435 |
runs[f'run{i}'] = {
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
}
|
442 |
ABclass = ABtesting(
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
if dataset_name == 'DroneSegmentation':
|
456 |
IoU = ABclass.ABsegmentation()
|
457 |
runs[f'run{i}']['IoU'] = IoU
|
458 |
else:
|
@@ -462,15 +479,16 @@ def ABMakeTable(dataset_name:str, augmentation: str,
|
|
462 |
runs[f'run{i}']['precision'] = precision
|
463 |
runs[f'run{i}']['recall'] = recall
|
464 |
runs[f'run{i}']['f1_score'] = f1_score
|
465 |
-
|
466 |
-
with open(os.path.join(path,f'{dataset_name}_{augmentation}_runs.txt'), 'w') as outfile:
|
467 |
json.dump(runs, outfile)
|
468 |
|
469 |
-
i+=1
|
|
|
470 |
|
471 |
def ABShowTable(dataset_name: str, augmentation: str):
|
472 |
-
|
473 |
-
path='results/ABtesting/tables'
|
474 |
assert os.path.exists(path), 'No tables to plot'
|
475 |
|
476 |
json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt')
|
@@ -478,14 +496,14 @@ def ABShowTable(dataset_name: str, augmentation: str):
|
|
478 |
with open(json_file, 'r') as run_file:
|
479 |
runs = json.load(run_file)
|
480 |
|
481 |
-
metrics=torch.zeros((2,12,12))
|
482 |
-
classes=[]
|
483 |
|
484 |
-
i,j=0,0
|
485 |
|
486 |
for r in range(len(runs)):
|
487 |
-
|
488 |
-
run = runs['run'+str(r)]
|
489 |
if dataset_name == 'DroneSegmentation':
|
490 |
acc = run['IoU']
|
491 |
else:
|
@@ -494,64 +512,66 @@ def ABShowTable(dataset_name: str, augmentation: str):
|
|
494 |
class_list = run['test_pip']
|
495 |
class_name = f'{class_list[0][:2]},{class_list[1][:1]},{class_list[2][:2]}'
|
496 |
classes.append(class_name)
|
497 |
-
mu,sigma = round(acc[0],4),round(acc[1],4)
|
|
|
|
|
|
|
498 |
|
499 |
-
|
500 |
-
metrics[1,j,i] = sigma
|
501 |
-
|
502 |
-
i+=1
|
503 |
|
504 |
if i == 12:
|
505 |
-
i=0
|
506 |
-
j+=1
|
507 |
|
508 |
differences = torch.zeros_like(metrics)
|
509 |
|
510 |
-
diag_mu = torch.diagonal(metrics[0],0)
|
511 |
-
diag_sigma = torch.diagonal(metrics[1],0)
|
512 |
-
|
513 |
for r in range(len(metrics[0])):
|
514 |
-
differences[0,r] = diag_mu[r] - metrics[0,r]
|
515 |
-
differences[1,r] = torch.sqrt(metrics[1,r]**2 + diag_sigma[r]**2)
|
516 |
|
517 |
# Plot with scatter
|
518 |
-
|
519 |
-
for i,img in enumerate([metrics, differences]):
|
520 |
|
521 |
x, y = torch.arange(12), torch.arange(12)
|
522 |
x, y = torch.meshgrid(x, y)
|
523 |
|
524 |
if i == 0:
|
525 |
-
vmin = max(0.65, round(img[0].min().item(),2))
|
526 |
-
vmax = round(img[0].max().item(),2)
|
527 |
step = 0.02
|
528 |
elif i == 1:
|
529 |
-
vmin = round(img[0].min().item(),2)
|
530 |
if augmentation == 'none':
|
531 |
-
vmax = min(0.15, round(img[0].max().item(),2))
|
532 |
if augmentation == 'weak':
|
533 |
-
vmax = min(0.08, round(img[0].max().item(),2))
|
534 |
if augmentation == 'strong':
|
535 |
-
vmax = min(0.05, round(img[0].max().item(),2))
|
536 |
step = 0.01
|
537 |
-
|
538 |
-
vmin = int(vmin/step)*step
|
539 |
-
vmax = int(vmax/step)*step
|
540 |
|
541 |
-
|
|
|
|
|
|
|
542 |
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
|
543 |
-
marker_size=350
|
544 |
-
plt.scatter(x, y, c=torch.rot90(img[1][x,y]
|
545 |
-
|
546 |
-
ticks = [
|
|
|
547 |
cba = plt.colorbar(pad=0.06)
|
548 |
cba.set_ticks(ticks)
|
549 |
cba.ax.set_yticklabels(ticks)
|
550 |
# cmap = plt.cm.get_cmap('tab20c').reversed()
|
551 |
cmap = plt.cm.get_cmap('Reds')
|
552 |
-
plt.scatter(x,y, c=torch.rot90(img[0][x,y]
|
|
|
553 |
ticks = torch.arange(vmin, vmax, step).tolist()
|
554 |
-
ticks = [round(tick,2) for tick in ticks]
|
555 |
if ticks[-1] != vmax:
|
556 |
ticks.append(vmax)
|
557 |
cbb = plt.colorbar(pad=0.06)
|
@@ -563,26 +583,26 @@ def ABShowTable(dataset_name: str, augmentation: str):
|
|
563 |
cbb.ax.set_yticklabels(ticks)
|
564 |
for x in range(12):
|
565 |
for y in range(12):
|
566 |
-
txt = round(torch.rot90(img[0]
|
567 |
if str(txt) == '-0.0':
|
568 |
txt = '0.00'
|
569 |
elif str(txt) == '0.0':
|
570 |
txt = '0.00'
|
571 |
elif len(str(txt)) == 3:
|
572 |
-
txt = str(txt)+'0'
|
573 |
else:
|
574 |
txt = str(txt)
|
575 |
-
|
576 |
-
plt.text(x-0.25,y-0.1,txt, color='black', fontsize='x-small')
|
577 |
|
578 |
-
|
|
|
|
|
579 |
ax.set_xticklabels(classes)
|
580 |
-
ax.set_yticks(torch.linspace(0,11,12))
|
581 |
classes.reverse()
|
582 |
ax.set_yticklabels(classes)
|
583 |
classes.reverse()
|
584 |
-
plt.xticks(rotation
|
585 |
-
plt.yticks(rotation
|
586 |
cba.set_label('Standard Deviation')
|
587 |
plt.xlabel("Test pipelines")
|
588 |
plt.ylabel("Train pipelines")
|
@@ -590,62 +610,63 @@ def ABShowTable(dataset_name: str, augmentation: str):
|
|
590 |
if i == 0:
|
591 |
if dataset_name == 'DroneSegmentation':
|
592 |
cbb.set_label('IoU')
|
593 |
-
plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_IoU.png"))
|
594 |
else:
|
595 |
cbb.set_label('Accuracy')
|
596 |
-
plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_accuracies.png"))
|
597 |
elif i == 1:
|
598 |
if dataset_name == 'DroneSegmentation':
|
599 |
cbb.set_label('IoU_d-IoU')
|
600 |
else:
|
601 |
cbb.set_label('Accuracy_d - Accuracy')
|
602 |
-
plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_differences.png"))
|
|
|
603 |
|
604 |
def CMakeTable(dataset_name: str, augmentation: str, severity: int, N_runs: int, download_model: bool):
|
605 |
-
|
606 |
-
path='results/Ctesting/tables'
|
607 |
if not os.path.exists(path):
|
608 |
os.makedirs(path)
|
609 |
-
|
610 |
-
demosaicings=['bilinear','malvar2004', 'menon2007']
|
611 |
-
sharpenings=['sharpening_filter', 'unsharp_masking']
|
612 |
-
denoisings=['median_denoising', 'gaussian_denoising']
|
613 |
|
614 |
-
|
615 |
-
|
|
|
|
|
|
|
|
|
616 |
|
617 |
-
runs={}
|
618 |
-
i=0
|
619 |
|
620 |
for dm in demosaicings:
|
621 |
for s in sharpenings:
|
622 |
for dn in denoisings:
|
623 |
for t in transformations:
|
624 |
-
pip = [dm,s,dn]
|
625 |
runs[f'run{i}'] = {
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
}
|
633 |
ABclass = ABtesting(
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
if dataset_name == 'DroneSegmentation':
|
649 |
IoU = ABclass.ABsegmentation()
|
650 |
runs[f'run{i}']['IoU'] = IoU
|
651 |
else:
|
@@ -656,26 +677,27 @@ def CMakeTable(dataset_name: str, augmentation: str, severity: int, N_runs: int,
|
|
656 |
runs[f'run{i}']['recall'] = recall
|
657 |
runs[f'run{i}']['f1_score'] = f1_score
|
658 |
|
659 |
-
with open(os.path.join(path,f'{dataset_name}_{augmentation}_runs.json'), 'w') as outfile:
|
660 |
json.dump(runs, outfile)
|
661 |
|
662 |
-
i+=1
|
|
|
663 |
|
664 |
def CShowTable(dataset_name, augmentation):
|
665 |
|
666 |
-
path='results/Ctesting/tables'
|
667 |
assert os.path.exists(path), 'No tables to plot'
|
668 |
|
669 |
json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt')
|
670 |
|
671 |
-
transforms = ['identity','gauss_noise', 'shot', 'impulse', 'speckle',
|
672 |
-
|
673 |
|
674 |
pip = []
|
675 |
-
|
676 |
-
demosaicings=['bilinear','malvar2004', 'menon2007']
|
677 |
-
sharpenings=['sharpening_filter', 'unsharp_masking']
|
678 |
-
denoisings=['median_denoising', 'gaussian_denoising']
|
679 |
|
680 |
for dm in demosaicings:
|
681 |
for s in sharpenings:
|
@@ -685,52 +707,54 @@ def CShowTable(dataset_name, augmentation):
|
|
685 |
with open(json_file, 'r') as run_file:
|
686 |
runs = json.load(run_file)
|
687 |
|
688 |
-
metrics=torch.zeros((2,len(pip),len(transforms)))
|
689 |
|
690 |
-
i,j=0,0
|
691 |
|
692 |
for r in range(len(runs)):
|
693 |
-
|
694 |
-
run = runs['run'+str(r)]
|
695 |
if dataset_name == 'DroneSegmentation':
|
696 |
acc = run['IoU']
|
697 |
else:
|
698 |
acc = run['accuracy']
|
699 |
-
mu,sigma = round(acc[0],4),round(acc[1],4)
|
700 |
|
701 |
-
metrics[0,j,i] = mu
|
702 |
-
metrics[1,j,i] = sigma
|
703 |
-
|
704 |
-
i+=1
|
705 |
|
706 |
if i == len(transforms):
|
707 |
-
i=0
|
708 |
-
j+=1
|
709 |
|
710 |
# Plot with scatter
|
711 |
|
712 |
img = metrics
|
713 |
|
714 |
-
vmin=0.
|
715 |
-
vmax=1.
|
716 |
-
|
717 |
x, y = torch.arange(12), torch.arange(11)
|
718 |
x, y = torch.meshgrid(x, y)
|
719 |
|
720 |
-
fig = plt.figure(figsize=(10,6.2))
|
721 |
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
|
722 |
-
marker_size=350
|
723 |
-
plt.scatter(x, y, c=torch.rot90(img[1][x,y]
|
724 |
-
|
725 |
-
ticks = [
|
|
|
726 |
cba = plt.colorbar(pad=0.06)
|
727 |
cba.set_ticks(ticks)
|
728 |
cba.ax.set_yticklabels(ticks)
|
729 |
# cmap = plt.cm.get_cmap('tab20c').reversed()
|
730 |
cmap = plt.cm.get_cmap('Reds')
|
731 |
-
plt.scatter(x,y, c=torch.rot90(img[0][x,y]
|
|
|
732 |
ticks = torch.arange(vmin, vmax, step).tolist()
|
733 |
-
ticks = [round(tick,2) for tick in ticks]
|
734 |
if ticks[-1] != vmax:
|
735 |
ticks.append(vmax)
|
736 |
cbb = plt.colorbar(pad=0.06)
|
@@ -742,65 +766,66 @@ def CShowTable(dataset_name, augmentation):
|
|
742 |
cbb.ax.set_yticklabels(ticks)
|
743 |
for x in range(12):
|
744 |
for y in range(12):
|
745 |
-
txt = round(torch.rot90(img[0]
|
746 |
if str(txt) == '-0.0':
|
747 |
txt = '0.00'
|
748 |
elif str(txt) == '0.0':
|
749 |
txt = '0.00'
|
750 |
elif len(str(txt)) == 3:
|
751 |
-
txt = str(txt)+'0'
|
752 |
else:
|
753 |
txt = str(txt)
|
754 |
-
|
755 |
-
plt.text(x-0.25,y-0.1,txt, color='black', fontsize='x-small')
|
756 |
|
757 |
-
|
|
|
|
|
758 |
ax.set_xticklabels(transforms)
|
759 |
-
ax.set_yticks(torch.linspace(0,11,12))
|
760 |
pip.reverse()
|
761 |
ax.set_yticklabels(pip)
|
762 |
pip.reverse()
|
763 |
-
plt.xticks(rotation
|
764 |
-
plt.yticks(rotation
|
765 |
cba.set_label('Standard Deviation')
|
766 |
plt.xlabel("Pipelines")
|
767 |
plt.ylabel("Distortions")
|
768 |
if dataset_name == 'DroneSegmentation':
|
769 |
cbb.set_label('IoU')
|
770 |
-
plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_IoU.png"))
|
771 |
else:
|
772 |
cbb.set_label('Accuracy')
|
773 |
-
plt.savefig(os.path.join(path,f"{dataset_name}_{augmentation}_accuracies.png"))
|
|
|
774 |
|
775 |
if __name__ == '__main__':
|
776 |
-
|
777 |
-
if args.mode == 'ABMakeTable':
|
778 |
ABMakeTable(args.dataset_name, args.augmentation, args.N_runs, args.download_model)
|
779 |
-
elif args.mode == 'ABShowTable':
|
780 |
ABShowTable(args.dataset_name, args.augmentation)
|
781 |
-
elif args.mode == 'ABShowImages':
|
782 |
-
ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
|
783 |
-
args.s_train, args.dn_train, args.dm_test, args.s_test,
|
784 |
args.dn_test, args.N_runs, download_model=args.download_model)
|
785 |
ABclass.ABShowImages()
|
786 |
-
elif args.mode == 'ABShowAllImages':
|
787 |
-
ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
|
788 |
-
|
789 |
-
|
790 |
ABclass.ABShowAllImages()
|
791 |
-
elif args.mode == 'CMakeTable':
|
792 |
CMakeTable(args.dataset_name, args.augmentation, args.severity, args.N_runs, args.download_model)
|
793 |
-
elif args.mode == 'CShowTable':
|
794 |
CShowTable(args.dataset_name, args.augmentation, args.severity)
|
795 |
-
elif args.mode == 'CShowImages':
|
796 |
-
ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
|
797 |
-
|
798 |
-
|
799 |
-
|
800 |
ABclass.CShowImages()
|
801 |
-
elif args.mode == 'CShowAllImages':
|
802 |
-
ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
|
803 |
-
|
804 |
-
|
805 |
-
|
806 |
ABclass.CShowAllImages()
|
|
|
8 |
from torchvision.transforms import Compose, Normalize
|
9 |
import torch.nn.functional as F
|
10 |
|
11 |
+
from dataset import get_dataset, Subset
|
12 |
from utils.base import get_mlflow_model_by_name, SmartFormatter
|
13 |
+
from processing.pipeline_numpy import RawProcessingPipeline
|
14 |
|
15 |
+
from utils.hendrycks_robustness import Distortions
|
16 |
|
17 |
import segmentation_models_pytorch as smp
|
18 |
|
|
|
20 |
|
21 |
parser = argparse.ArgumentParser(description="AB testing, Show Results", formatter_class=SmartFormatter)
|
22 |
|
23 |
+
# Select experiment
|
24 |
+
parser.add_argument("--mode", type=str, default="ABShowImages", choices=('ABMakeTable', 'ABShowTable', 'ABShowImages', 'ABShowAllImages', 'CMakeTable', 'CShowTable', 'CShowImages', 'CShowAllImages'),
|
25 |
help='R|Choose operation to compute. \n'
|
26 |
+
'A) Lens2Logit image generation: \n '
|
27 |
+
'ABMakeTable: Compute cross-validation metrics results \n '
|
28 |
+
'ABShowTable: Plot cross-validation results on a table \n '
|
29 |
+
'ABShowImages: Choose a training and testing image to compare different pipelines \n '
|
30 |
+
'ABShowAllImages: Plot all possible pipelines \n'
|
31 |
+
'B) Hendrycks Perturbations, C-type dataset: \n '
|
32 |
+
'CMakeTable: For each pipeline, it computes cross-validation metrics for different perturbations \n '
|
33 |
+
'CShowTable: Plot metrics for different pipelines and perturbations \n '
|
34 |
+
'CShowImages: Plot an image with a selected a pipeline and perturbation\n '
|
35 |
+
'CShowAllImages: Plot all possible perturbations for a fixed pipeline')
|
36 |
+
|
37 |
+
parser.add_argument("--dataset_name", type=str, default='Microscopy',
|
38 |
+
choices=['Microscopy', 'Drone', 'DroneSegmentation'], help='Choose dataset')
|
39 |
+
parser.add_argument("--augmentation", type=str, default='weak',
|
40 |
+
choices=['none', 'weak', 'strong'], help='Choose augmentation')
|
41 |
parser.add_argument("--N_runs", type=int, default=5, help='Number of k-fold splitting used in the training')
|
42 |
parser.add_argument("--download_model", default=False, action='store_true', help='Download Models in cache')
|
43 |
|
44 |
+
# Select pipelines
|
45 |
+
parser.add_argument("--dm_train", type=str, default='bilinear', choices=('bilinear', 'malvar2004',
|
46 |
+
'menon2007'), help='Choose demosaicing for training processing model')
|
47 |
+
parser.add_argument("--s_train", type=str, default='sharpening_filter', choices=('sharpening_filter',
|
48 |
+
'unsharp_masking'), help='Choose sharpening for training processing model')
|
49 |
+
parser.add_argument("--dn_train", type=str, default='gaussian_denoising', choices=('gaussian_denoising',
|
50 |
+
'median_denoising'), help='Choose denoising for training processing model')
|
51 |
+
parser.add_argument("--dm_test", type=str, default='bilinear', choices=('bilinear', 'malvar2004',
|
52 |
+
'menon2007'), help='Choose demosaicing for testing processing model')
|
53 |
+
parser.add_argument("--s_test", type=str, default='sharpening_filter', choices=('sharpening_filter',
|
54 |
+
'unsharp_masking'), help='Choose sharpening for testing processing model')
|
55 |
+
parser.add_argument("--dn_test", type=str, default='gaussian_denoising', choices=('gaussian_denoising',
|
56 |
+
'median_denoising'), help='Choose denoising for testing processing model')
|
57 |
+
|
58 |
+
# Select Ctest parameters
|
59 |
+
parser.add_argument("--transform", type=str, default='identity', choices=('identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
|
60 |
+
'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform'), help='Choose transformation to show for Ctesting')
|
61 |
+
parser.add_argument("--severity", type=int, default=1, choices=(1, 2, 3, 4, 5), help='Choose severity for Ctesting')
|
62 |
|
63 |
args = parser.parse_args()
|
64 |
|
65 |
+
|
66 |
class metrics:
|
67 |
def __init__(self, confusion_matrix):
|
68 |
self.cm = confusion_matrix
|
69 |
self.N_classes = len(confusion_matrix)
|
70 |
|
71 |
def accuracy(self):
|
72 |
+
Tp = torch.diagonal(self.cm, 0).sum()
|
73 |
N_elements = torch.sum(self.cm)
|
74 |
+
return Tp / N_elements
|
75 |
+
|
76 |
def precision(self):
|
77 |
Tp_Fp = torch.sum(self.cm, 1)
|
78 |
Tp_Fp[Tp_Fp == 0] = 1
|
79 |
+
return torch.diagonal(self.cm, 0) / Tp_Fp
|
80 |
|
81 |
def recall(self):
|
82 |
Tp_Fn = torch.sum(self.cm, 0)
|
83 |
Tp_Fn[Tp_Fn == 0] = 1
|
84 |
+
return torch.diagonal(self.cm, 0) / Tp_Fn
|
85 |
|
86 |
def f1_score(self):
|
87 |
+
prod = (self.precision() * self.recall())
|
88 |
sum = (self.precision() + self.recall())
|
89 |
sum[sum == 0.] = 1.
|
90 |
+
return 2 * (prod / sum)
|
91 |
|
92 |
def over_N_runs(ms, N_runs):
|
93 |
+
m, m2 = 0, 0
|
94 |
|
95 |
for i in ms:
|
96 |
+
m += i
|
97 |
+
mu = m / N_runs
|
98 |
+
|
99 |
for i in ms:
|
100 |
+
m2 += (i - mu)**2
|
101 |
+
|
102 |
+
sigma = torch.sqrt(m2 / (N_runs - 1))
|
103 |
|
|
|
|
|
104 |
return mu.tolist(), sigma.tolist()
|
105 |
|
106 |
+
|
107 |
class ABtesting:
|
108 |
+
def __init__(self,
|
109 |
+
dataset_name: str,
|
110 |
+
augmentation: str,
|
111 |
+
dm_train: str,
|
112 |
+
s_train: str,
|
113 |
+
dn_train: str,
|
114 |
+
dm_test: str,
|
115 |
+
s_test: str,
|
116 |
+
dn_test: str,
|
117 |
+
N_runs: int,
|
118 |
+
severity=1,
|
119 |
+
transform='identity',
|
120 |
+
download_model=False):
|
121 |
self.experiment_name = 'ABtesting'
|
122 |
self.dataset_name = dataset_name
|
123 |
self.augmentation = augmentation
|
|
|
139 |
if sharpening == None:
|
140 |
sharpening = self.s_test
|
141 |
if denoising == None:
|
142 |
+
denoising = self.dn_test
|
143 |
if severity == None:
|
144 |
+
severity = self.severity
|
145 |
if transform == None:
|
146 |
transform = self.transform
|
147 |
+
|
148 |
dataset = get_dataset(self.dataset_name)
|
149 |
|
150 |
if self.dataset_name == "Drone" or self.dataset_name == "DroneSegmentation":
|
|
|
156 |
|
157 |
if not plot_mode:
|
158 |
dataset.transform = Compose([RawProcessingPipeline(
|
159 |
+
camera_parameters=dataset.camera_parameters,
|
160 |
+
debayer=debayer,
|
161 |
+
sharpening=sharpening,
|
162 |
+
denoising=denoising,
|
163 |
+
), Distortions(severity=severity, transform=transform),
|
164 |
+
Normalize(mean, std)])
|
165 |
else:
|
166 |
+
dataset.transform = Compose([RawProcessingPipeline(
|
167 |
+
camera_parameters=dataset.camera_parameters,
|
168 |
+
debayer=debayer,
|
169 |
+
sharpening=sharpening,
|
170 |
+
denoising=denoising,
|
171 |
+
), Distortions(severity=severity, transform=transform)])
|
172 |
|
173 |
return dataset
|
174 |
|
175 |
def ABclassification(self):
|
176 |
+
|
177 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
178 |
|
179 |
parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
|
180 |
|
181 |
+
print(
|
182 |
+
f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
|
183 |
+
print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')
|
184 |
|
185 |
+
accuracies, precisions, recalls, f1_scores = [], [], [], []
|
186 |
|
187 |
os.system('rm -r /tmp/py*')
|
188 |
|
189 |
for N_run in range(self.N_runs):
|
190 |
|
191 |
print(f"Evaluating Run {N_run}")
|
192 |
+
|
193 |
+
run_name = parent_run_name + '_' + str(N_run)
|
194 |
|
195 |
state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
|
196 |
download_model=self.download_model)
|
|
|
201 |
|
202 |
model.eval()
|
203 |
|
204 |
+
len_classes = len(dataset.classes)
|
205 |
confusion_matrix = torch.zeros((len_classes, len_classes))
|
206 |
|
207 |
for img, label in valid_loader:
|
208 |
+
|
209 |
prediction = model(img.to(DEVICE)).detach().cpu()
|
210 |
+
prediction = torch.argmax(prediction, dim=1)
|
211 |
+
confusion_matrix[label, prediction] += 1 # Real value rows, Declared columns
|
212 |
|
213 |
m = metrics(confusion_matrix)
|
214 |
|
215 |
+
accuracies.append(m.accuracy())
|
216 |
precisions.append(m.precision())
|
217 |
recalls.append(m.recall())
|
218 |
f1_scores.append(m.f1_score())
|
|
|
224 |
recall = metrics.over_N_runs(recalls, self.N_runs)
|
225 |
f1_score = metrics.over_N_runs(f1_scores, self.N_runs)
|
226 |
return dataset.classes, accuracy, precision, recall, f1_score
|
227 |
+
|
228 |
def ABsegmentation(self):
|
229 |
+
|
230 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
231 |
|
232 |
parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
|
233 |
|
234 |
+
print(
|
235 |
+
f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
|
236 |
+
print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')
|
237 |
|
238 |
IoUs = []
|
239 |
|
|
|
242 |
for N_run in range(self.N_runs):
|
243 |
|
244 |
print(f"Evaluating Run {N_run}")
|
|
|
|
|
245 |
|
246 |
+
run_name = parent_run_name + '_' + str(N_run)
|
247 |
+
|
248 |
+
state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
|
249 |
+
download_model=self.download_model)
|
250 |
|
251 |
dataset = self.static_pip_val()
|
252 |
+
|
253 |
valid_set = Subset(dataset, indices=state_dict['valid_indices'])
|
254 |
valid_loader = DataLoader(valid_set, batch_size=1, num_workers=16, shuffle=False)
|
255 |
|
256 |
model.eval()
|
257 |
|
258 |
+
IoU = 0
|
259 |
|
260 |
for img, label in valid_loader:
|
261 |
+
|
262 |
prediction = model(img.to(DEVICE)).detach().cpu()
|
263 |
prediction = F.logsigmoid(prediction).exp().squeeze()
|
264 |
+
IoU += smp.utils.metrics.IoU()(prediction, label)
|
265 |
|
266 |
+
IoU = IoU / len(valid_loader)
|
267 |
IoUs.append(IoU.item())
|
268 |
|
269 |
os.system('rm -r /tmp/t*')
|
270 |
|
271 |
IoU = metrics.over_N_runs(torch.tensor(IoUs), self.N_runs)
|
272 |
+
return IoU
|
273 |
|
274 |
def ABShowImages(self):
|
275 |
|
|
|
277 |
if not os.path.exists(path):
|
278 |
os.makedirs(path)
|
279 |
|
280 |
+
path = os.path.join(
|
281 |
+
path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.dm_test[:2]}{self.s_test[0]}{self.dn_test[:2]}')
|
282 |
|
283 |
if not os.path.exists(path):
|
284 |
os.makedirs(path)
|
285 |
|
286 |
+
run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}" + \
|
287 |
+
'_' + str(0)
|
288 |
|
289 |
state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=self.download_model)
|
290 |
|
291 |
model.augmentation = None
|
292 |
+
|
293 |
+
for t in ([self.dm_train, self.s_train, self.dn_train, 'train_img'],
|
294 |
+
[self.dm_test, self.s_test, self.dn_test, 'test_img']):
|
295 |
+
|
296 |
debayer, sharpening, denoising, img_type = t[0], t[1], t[2], t[3]
|
297 |
|
298 |
dataset = self.static_pip_val(debayer=debayer, sharpening=sharpening, denoising=denoising, plot_mode=True)
|
299 |
valid_set = Subset(dataset, indices=state_dict['valid_indices'])
|
300 |
+
|
301 |
img, _ = next(iter(valid_set))
|
302 |
|
303 |
plt.figure()
|
304 |
+
plt.imshow(img.permute(1, 2, 0))
|
305 |
if img_type == 'train_img':
|
306 |
plt.title('Train Image')
|
307 |
plt.savefig(os.path.join(path, f'img_train.png'))
|
308 |
imgA = img
|
309 |
else:
|
310 |
plt.title('Test Image')
|
311 |
+
plt.savefig(os.path.join(path, f'img_test.png'))
|
312 |
+
|
313 |
+
for c, color in enumerate(['Red', 'Green', 'Blue']):
|
314 |
+
diff = torch.abs(imgA - img)
|
315 |
plt.figure()
|
316 |
# plt.imshow(diff.permute(1,2,0))
|
317 |
+
plt.imshow(diff[c, 50:200, 50:200], cmap=f'{color}s')
|
318 |
plt.title(f'|Train Image - Test Image| - {color}')
|
319 |
plt.colorbar()
|
320 |
plt.savefig(os.path.join(path, f'diff_{color}.png'))
|
321 |
plt.figure()
|
322 |
+
diff[diff == 0.] = 1e-5
|
323 |
# plt.imshow(torch.log(diff.permute(1,2,0)))
|
324 |
plt.imshow(torch.log(diff)[c])
|
325 |
plt.title(f'log(|Train Image - Test Image|) - color')
|
326 |
plt.colorbar()
|
327 |
plt.savefig(os.path.join(path, f'logdiff_{color}.png'))
|
328 |
+
|
329 |
if self.dataset_name == 'DroneSegmentation':
|
330 |
plt.figure()
|
331 |
plt.imshow(model(img[None].cuda()).detach().cpu().squeeze())
|
332 |
if img_type == 'train_img':
|
333 |
plt.savefig(os.path.join(path, f'mask_train.png'))
|
334 |
else:
|
335 |
+
plt.savefig(os.path.join(path, f'mask_test.png'))
|
336 |
|
337 |
def ABShowAllImages(self):
|
338 |
if not os.path.exists('results/ABtesting'):
|
339 |
os.makedirs('results/ABtesting')
|
340 |
|
341 |
+
demosaicings = ['bilinear', 'malvar2004', 'menon2007']
|
342 |
+
sharpenings = ['sharpening_filter', 'unsharp_masking']
|
343 |
+
denoisings = ['median_denoising', 'gaussian_denoising']
|
344 |
|
345 |
fig = plt.figure()
|
346 |
+
columns = 4
|
347 |
+
rows = 3
|
348 |
|
349 |
+
i = 1
|
350 |
|
351 |
for dm in demosaicings:
|
352 |
for s in sharpenings:
|
353 |
for dn in denoisings:
|
|
|
|
|
|
|
354 |
|
355 |
+
dataset = self.static_pip_val(self.dm_test, self.s_test,
|
356 |
+
self.dn_test, plot_mode=True)
|
357 |
+
|
358 |
+
img, _ = dataset[0]
|
359 |
+
|
360 |
fig.add_subplot(rows, columns, i)
|
361 |
+
plt.imshow(img.permute(1, 2, 0))
|
362 |
plt.title(f'{dm}\n{s}\n{dn}', fontsize=8)
|
363 |
plt.xticks([])
|
364 |
plt.yticks([])
|
365 |
plt.tight_layout()
|
366 |
|
367 |
+
i += 1
|
368 |
|
369 |
plt.show()
|
370 |
plt.savefig(f'results/ABtesting/ABpipelines.png')
|
371 |
|
372 |
def CShowImages(self):
|
373 |
+
|
374 |
path = 'results/Ctesting/imgs/'
|
375 |
if not os.path.exists(path):
|
376 |
os.makedirs(path)
|
377 |
|
378 |
+
run_name = f"{self.dataset_name}_{self.dm_test}_{self.s_test}_{self.dn_test}_{self.augmentation}" + '_' + str(0)
|
379 |
|
380 |
state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=True)
|
381 |
|
382 |
model.augmentation = None
|
383 |
+
|
384 |
+
dataset = self.static_pip_val(self.dm_test, self.s_test, self.dn_test,
|
385 |
+
self.severity, self.transform, plot_mode=True)
|
386 |
valid_set = Subset(dataset, indices=state_dict['valid_indices'])
|
387 |
+
|
388 |
img, _ = next(iter(valid_set))
|
389 |
|
390 |
plt.figure()
|
391 |
+
plt.imshow(img.permute(1, 2, 0))
|
392 |
+
plt.savefig(os.path.join(
|
393 |
+
path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.transform}_sev{self.severity}'))
|
394 |
+
|
395 |
def CShowAllImages(self):
|
396 |
if not os.path.exists('results/Cimages'):
|
397 |
os.makedirs('results/Cimages')
|
398 |
|
399 |
+
transforms = ['identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
|
400 |
+
'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
|
401 |
|
402 |
+
for i, t in enumerate(transforms):
|
403 |
+
|
404 |
+
fig = plt.figure(figsize=(10, 6))
|
405 |
columns = 5
|
406 |
rows = 1
|
407 |
|
408 |
+
for sev in range(1, 6):
|
409 |
|
410 |
dataset = self.static_pip_val(severity=sev, transform=t, plot_mode=True)
|
411 |
|
412 |
+
img, _ = dataset[0]
|
413 |
+
|
414 |
fig.add_subplot(rows, columns, sev)
|
415 |
+
plt.imshow(img.permute(1, 2, 0))
|
416 |
plt.title(f'Severity: {sev}')
|
417 |
plt.xticks([])
|
418 |
plt.yticks([])
|
419 |
plt.tight_layout()
|
420 |
+
|
421 |
if '_' in t:
|
422 |
+
t = t.replace('_', ' ')
|
423 |
+
t = t[0].upper() + t[1:]
|
424 |
|
425 |
fig.suptitle(f'{t}', x=0.5, y=0.8, fontsize=24)
|
426 |
plt.show()
|
427 |
plt.savefig(f'results/Cimages/{i+1}_{t.lower()}.png')
|
428 |
|
|
|
|
|
429 |
|
430 |
+
def ABMakeTable(dataset_name: str, augmentation: str,
|
431 |
+
N_runs: int, download_model: bool):
|
432 |
+
|
433 |
+
demosaicings = ['bilinear', 'malvar2004', 'menon2007']
|
434 |
+
sharpenings = ['sharpening_filter', 'unsharp_masking']
|
435 |
+
denoisings = ['median_denoising', 'gaussian_denoising']
|
436 |
|
437 |
+
path = 'results/ABtesting/tables'
|
438 |
if not os.path.exists(path):
|
439 |
os.makedirs(path)
|
440 |
|
441 |
+
runs = {}
|
442 |
+
i = 0
|
443 |
|
444 |
for dm_train in demosaicings:
|
445 |
for s_train in sharpenings:
|
|
|
448 |
for s_test in sharpenings:
|
449 |
for dn_test in denoisings:
|
450 |
train_pip = [dm_train, s_train, dn_train]
|
451 |
+
test_pip = [dm_test, s_test, dn_test]
|
452 |
runs[f'run{i}'] = {
|
453 |
+
'dataset': dataset_name,
|
454 |
+
'augmentation': augmentation,
|
455 |
+
'train_pip': train_pip,
|
456 |
+
'test_pip': test_pip,
|
457 |
+
'N_runs': N_runs
|
458 |
}
|
459 |
ABclass = ABtesting(
|
460 |
+
dataset_name=dataset_name,
|
461 |
+
augmentation=augmentation,
|
462 |
+
dm_train=dm_train,
|
463 |
+
s_train=s_train,
|
464 |
+
dn_train=dn_train,
|
465 |
+
dm_test=dm_test,
|
466 |
+
s_test=s_test,
|
467 |
+
dn_test=dn_test,
|
468 |
+
N_runs=N_runs,
|
469 |
+
download_model=download_model
|
470 |
+
)
|
471 |
+
|
472 |
+
if dataset_name == 'DroneSegmentation':
|
473 |
IoU = ABclass.ABsegmentation()
|
474 |
runs[f'run{i}']['IoU'] = IoU
|
475 |
else:
|
|
|
479 |
runs[f'run{i}']['precision'] = precision
|
480 |
runs[f'run{i}']['recall'] = recall
|
481 |
runs[f'run{i}']['f1_score'] = f1_score
|
482 |
+
|
483 |
+
with open(os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt'), 'w') as outfile:
|
484 |
json.dump(runs, outfile)
|
485 |
|
486 |
+
i += 1
|
487 |
+
|
488 |
|
489 |
def ABShowTable(dataset_name: str, augmentation: str):
|
490 |
+
|
491 |
+
path = 'results/ABtesting/tables'
|
492 |
assert os.path.exists(path), 'No tables to plot'
|
493 |
|
494 |
json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt')
|
|
|
496 |
with open(json_file, 'r') as run_file:
|
497 |
runs = json.load(run_file)
|
498 |
|
499 |
+
metrics = torch.zeros((2, 12, 12))
|
500 |
+
classes = []
|
501 |
|
502 |
+
i, j = 0, 0
|
503 |
|
504 |
for r in range(len(runs)):
|
505 |
+
|
506 |
+
run = runs['run' + str(r)]
|
507 |
if dataset_name == 'DroneSegmentation':
|
508 |
acc = run['IoU']
|
509 |
else:
|
|
|
512 |
class_list = run['test_pip']
|
513 |
class_name = f'{class_list[0][:2]},{class_list[1][:1]},{class_list[2][:2]}'
|
514 |
classes.append(class_name)
|
515 |
+
mu, sigma = round(acc[0], 4), round(acc[1], 4)
|
516 |
+
|
517 |
+
metrics[0, j, i] = mu
|
518 |
+
metrics[1, j, i] = sigma
|
519 |
|
520 |
+
i += 1
|
|
|
|
|
|
|
521 |
|
522 |
if i == 12:
|
523 |
+
i = 0
|
524 |
+
j += 1
|
525 |
|
526 |
differences = torch.zeros_like(metrics)
|
527 |
|
528 |
+
diag_mu = torch.diagonal(metrics[0], 0)
|
529 |
+
diag_sigma = torch.diagonal(metrics[1], 0)
|
530 |
+
|
531 |
for r in range(len(metrics[0])):
|
532 |
+
differences[0, r] = diag_mu[r] - metrics[0, r]
|
533 |
+
differences[1, r] = torch.sqrt(metrics[1, r]**2 + diag_sigma[r]**2)
|
534 |
|
535 |
# Plot with scatter
|
536 |
+
|
537 |
+
for i, img in enumerate([metrics, differences]):
|
538 |
|
539 |
x, y = torch.arange(12), torch.arange(12)
|
540 |
x, y = torch.meshgrid(x, y)
|
541 |
|
542 |
if i == 0:
|
543 |
+
vmin = max(0.65, round(img[0].min().item(), 2))
|
544 |
+
vmax = round(img[0].max().item(), 2)
|
545 |
step = 0.02
|
546 |
elif i == 1:
|
547 |
+
vmin = round(img[0].min().item(), 2)
|
548 |
if augmentation == 'none':
|
549 |
+
vmax = min(0.15, round(img[0].max().item(), 2))
|
550 |
if augmentation == 'weak':
|
551 |
+
vmax = min(0.08, round(img[0].max().item(), 2))
|
552 |
if augmentation == 'strong':
|
553 |
+
vmax = min(0.05, round(img[0].max().item(), 2))
|
554 |
step = 0.01
|
|
|
|
|
|
|
555 |
|
556 |
+
vmin = int(vmin / step) * step
|
557 |
+
vmax = int(vmax / step) * step
|
558 |
+
|
559 |
+
fig = plt.figure(figsize=(10, 6.2))
|
560 |
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
|
561 |
+
marker_size = 350
|
562 |
+
plt.scatter(x, y, c=torch.rot90(img[1][x, y], -1, [0, 1]), vmin=0.,
|
563 |
+
vmax=img[1].max(), cmap='viridis', s=marker_size * 2, marker='s')
|
564 |
+
ticks = torch.arange(0., img[1].max(), 0.03).tolist()
|
565 |
+
ticks = [round(tick, 2) for tick in ticks]
|
566 |
cba = plt.colorbar(pad=0.06)
|
567 |
cba.set_ticks(ticks)
|
568 |
cba.ax.set_yticklabels(ticks)
|
569 |
# cmap = plt.cm.get_cmap('tab20c').reversed()
|
570 |
cmap = plt.cm.get_cmap('Reds')
|
571 |
+
plt.scatter(x, y, c=torch.rot90(img[0][x, y], -1, [0, 1]), vmin=vmin,
|
572 |
+
vmax=vmax, cmap=cmap, s=marker_size, marker='s')
|
573 |
ticks = torch.arange(vmin, vmax, step).tolist()
|
574 |
+
ticks = [round(tick, 2) for tick in ticks]
|
575 |
if ticks[-1] != vmax:
|
576 |
ticks.append(vmax)
|
577 |
cbb = plt.colorbar(pad=0.06)
|
|
|
583 |
cbb.ax.set_yticklabels(ticks)
|
584 |
for x in range(12):
|
585 |
for y in range(12):
|
586 |
+
txt = round(torch.rot90(img[0], -1, [0, 1])[x, y].item(), 2)
|
587 |
if str(txt) == '-0.0':
|
588 |
txt = '0.00'
|
589 |
elif str(txt) == '0.0':
|
590 |
txt = '0.00'
|
591 |
elif len(str(txt)) == 3:
|
592 |
+
txt = str(txt) + '0'
|
593 |
else:
|
594 |
txt = str(txt)
|
|
|
|
|
595 |
|
596 |
+
plt.text(x - 0.25, y - 0.1, txt, color='black', fontsize='x-small')
|
597 |
+
|
598 |
+
ax.set_xticks(torch.linspace(0, 11, 12))
|
599 |
ax.set_xticklabels(classes)
|
600 |
+
ax.set_yticks(torch.linspace(0, 11, 12))
|
601 |
classes.reverse()
|
602 |
ax.set_yticklabels(classes)
|
603 |
classes.reverse()
|
604 |
+
plt.xticks(rotation=45)
|
605 |
+
plt.yticks(rotation=45)
|
606 |
cba.set_label('Standard Deviation')
|
607 |
plt.xlabel("Test pipelines")
|
608 |
plt.ylabel("Train pipelines")
|
|
|
610 |
if i == 0:
|
611 |
if dataset_name == 'DroneSegmentation':
|
612 |
cbb.set_label('IoU')
|
613 |
+
plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_IoU.png"))
|
614 |
else:
|
615 |
cbb.set_label('Accuracy')
|
616 |
+
plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_accuracies.png"))
|
617 |
elif i == 1:
|
618 |
if dataset_name == 'DroneSegmentation':
|
619 |
cbb.set_label('IoU_d-IoU')
|
620 |
else:
|
621 |
cbb.set_label('Accuracy_d - Accuracy')
|
622 |
+
plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_differences.png"))
|
623 |
+
|
624 |
|
625 |
def CMakeTable(dataset_name: str, augmentation: str, severity: int, N_runs: int, download_model: bool):
|
626 |
+
|
627 |
+
path = 'results/Ctesting/tables'
|
628 |
if not os.path.exists(path):
|
629 |
os.makedirs(path)
|
|
|
|
|
|
|
|
|
630 |
|
631 |
+
demosaicings = ['bilinear', 'malvar2004', 'menon2007']
|
632 |
+
sharpenings = ['sharpening_filter', 'unsharp_masking']
|
633 |
+
denoisings = ['median_denoising', 'gaussian_denoising']
|
634 |
+
|
635 |
+
transformations = ['identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
|
636 |
+
'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
|
637 |
|
638 |
+
runs = {}
|
639 |
+
i = 0
|
640 |
|
641 |
for dm in demosaicings:
|
642 |
for s in sharpenings:
|
643 |
for dn in denoisings:
|
644 |
for t in transformations:
|
645 |
+
pip = [dm, s, dn]
|
646 |
runs[f'run{i}'] = {
|
647 |
+
'dataset': dataset_name,
|
648 |
+
'augmentation': augmentation,
|
649 |
+
'pipeline': pip,
|
650 |
+
'N_runs': N_runs,
|
651 |
+
'transform': t,
|
652 |
+
'severity': severity,
|
653 |
}
|
654 |
ABclass = ABtesting(
|
655 |
+
dataset_name=dataset_name,
|
656 |
+
augmentation=augmentation,
|
657 |
+
dm_train=dm,
|
658 |
+
s_train=s,
|
659 |
+
dn_train=dn,
|
660 |
+
dm_test=dm,
|
661 |
+
s_test=s,
|
662 |
+
dn_test=dn,
|
663 |
+
severity=severity,
|
664 |
+
transform=t,
|
665 |
+
N_runs=N_runs,
|
666 |
+
download_model=download_model
|
667 |
+
)
|
668 |
+
|
669 |
+
if dataset_name == 'DroneSegmentation':
|
670 |
IoU = ABclass.ABsegmentation()
|
671 |
runs[f'run{i}']['IoU'] = IoU
|
672 |
else:
|
|
|
677 |
runs[f'run{i}']['recall'] = recall
|
678 |
runs[f'run{i}']['f1_score'] = f1_score
|
679 |
|
680 |
+
with open(os.path.join(path, f'{dataset_name}_{augmentation}_runs.json'), 'w') as outfile:
|
681 |
json.dump(runs, outfile)
|
682 |
|
683 |
+
i += 1
|
684 |
+
|
685 |
|
686 |
def CShowTable(dataset_name, augmentation):
|
687 |
|
688 |
+
path = 'results/Ctesting/tables'
|
689 |
assert os.path.exists(path), 'No tables to plot'
|
690 |
|
691 |
json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt')
|
692 |
|
693 |
+
transforms = ['identity', 'gauss_noise', 'shot', 'impulse', 'speckle',
|
694 |
+
'gauss_blur', 'zoom', 'contrast', 'brightness', 'saturate', 'elastic']
|
695 |
|
696 |
pip = []
|
697 |
+
|
698 |
+
demosaicings = ['bilinear', 'malvar2004', 'menon2007']
|
699 |
+
sharpenings = ['sharpening_filter', 'unsharp_masking']
|
700 |
+
denoisings = ['median_denoising', 'gaussian_denoising']
|
701 |
|
702 |
for dm in demosaicings:
|
703 |
for s in sharpenings:
|
|
|
707 |
with open(json_file, 'r') as run_file:
|
708 |
runs = json.load(run_file)
|
709 |
|
710 |
+
metrics = torch.zeros((2, len(pip), len(transforms)))
|
711 |
|
712 |
+
i, j = 0, 0
|
713 |
|
714 |
for r in range(len(runs)):
|
715 |
+
|
716 |
+
run = runs['run' + str(r)]
|
717 |
if dataset_name == 'DroneSegmentation':
|
718 |
acc = run['IoU']
|
719 |
else:
|
720 |
acc = run['accuracy']
|
721 |
+
mu, sigma = round(acc[0], 4), round(acc[1], 4)
|
722 |
|
723 |
+
metrics[0, j, i] = mu
|
724 |
+
metrics[1, j, i] = sigma
|
725 |
+
|
726 |
+
i += 1
|
727 |
|
728 |
if i == len(transforms):
|
729 |
+
i = 0
|
730 |
+
j += 1
|
731 |
|
732 |
# Plot with scatter
|
733 |
|
734 |
img = metrics
|
735 |
|
736 |
+
vmin = 0.
|
737 |
+
vmax = 1.
|
738 |
+
|
739 |
x, y = torch.arange(12), torch.arange(11)
|
740 |
x, y = torch.meshgrid(x, y)
|
741 |
|
742 |
+
fig = plt.figure(figsize=(10, 6.2))
|
743 |
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
|
744 |
+
marker_size = 350
|
745 |
+
plt.scatter(x, y, c=torch.rot90(img[1][x, y], -1, [0, 1]), vmin=0.,
|
746 |
+
vmax=img[1].max(), cmap='viridis', s=marker_size * 2, marker='s')
|
747 |
+
ticks = torch.arange(0., img[1].max(), 0.03).tolist()
|
748 |
+
ticks = [round(tick, 2) for tick in ticks]
|
749 |
cba = plt.colorbar(pad=0.06)
|
750 |
cba.set_ticks(ticks)
|
751 |
cba.ax.set_yticklabels(ticks)
|
752 |
# cmap = plt.cm.get_cmap('tab20c').reversed()
|
753 |
cmap = plt.cm.get_cmap('Reds')
|
754 |
+
plt.scatter(x, y, c=torch.rot90(img[0][x, y], -1, [0, 1]), vmin=vmin,
|
755 |
+
vmax=vmax, cmap=cmap, s=marker_size, marker='s')
|
756 |
ticks = torch.arange(vmin, vmax, step).tolist()
|
757 |
+
ticks = [round(tick, 2) for tick in ticks]
|
758 |
if ticks[-1] != vmax:
|
759 |
ticks.append(vmax)
|
760 |
cbb = plt.colorbar(pad=0.06)
|
|
|
766 |
cbb.ax.set_yticklabels(ticks)
|
767 |
for x in range(12):
|
768 |
for y in range(12):
|
769 |
+
txt = round(torch.rot90(img[0], -1, [0, 1])[x, y].item(), 2)
|
770 |
if str(txt) == '-0.0':
|
771 |
txt = '0.00'
|
772 |
elif str(txt) == '0.0':
|
773 |
txt = '0.00'
|
774 |
elif len(str(txt)) == 3:
|
775 |
+
txt = str(txt) + '0'
|
776 |
else:
|
777 |
txt = str(txt)
|
|
|
|
|
778 |
|
779 |
+
plt.text(x - 0.25, y - 0.1, txt, color='black', fontsize='x-small')
|
780 |
+
|
781 |
+
ax.set_xticks(torch.linspace(0, 11, 12))
|
782 |
ax.set_xticklabels(transforms)
|
783 |
+
ax.set_yticks(torch.linspace(0, 11, 12))
|
784 |
pip.reverse()
|
785 |
ax.set_yticklabels(pip)
|
786 |
pip.reverse()
|
787 |
+
plt.xticks(rotation=45)
|
788 |
+
plt.yticks(rotation=45)
|
789 |
cba.set_label('Standard Deviation')
|
790 |
plt.xlabel("Pipelines")
|
791 |
plt.ylabel("Distortions")
|
792 |
if dataset_name == 'DroneSegmentation':
|
793 |
cbb.set_label('IoU')
|
794 |
+
plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_IoU.png"))
|
795 |
else:
|
796 |
cbb.set_label('Accuracy')
|
797 |
+
plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_accuracies.png"))
|
798 |
+
|
799 |
|
800 |
if __name__ == '__main__':
|
801 |
+
|
802 |
+
if args.mode == 'ABMakeTable':
|
803 |
ABMakeTable(args.dataset_name, args.augmentation, args.N_runs, args.download_model)
|
804 |
+
elif args.mode == 'ABShowTable':
|
805 |
ABShowTable(args.dataset_name, args.augmentation)
|
806 |
+
elif args.mode == 'ABShowImages':
|
807 |
+
ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
|
808 |
+
args.s_train, args.dn_train, args.dm_test, args.s_test,
|
809 |
args.dn_test, args.N_runs, download_model=args.download_model)
|
810 |
ABclass.ABShowImages()
|
811 |
+
elif args.mode == 'ABShowAllImages':
|
812 |
+
ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
|
813 |
+
args.s_train, args.dn_train, args.dm_test, args.s_test,
|
814 |
+
args.dn_test, args.N_runs, download_model=args.download_model)
|
815 |
ABclass.ABShowAllImages()
|
816 |
+
elif args.mode == 'CMakeTable':
|
817 |
CMakeTable(args.dataset_name, args.augmentation, args.severity, args.N_runs, args.download_model)
|
818 |
+
elif args.mode == 'CShowTable': # TODO test it
|
819 |
CShowTable(args.dataset_name, args.augmentation, args.severity)
|
820 |
+
elif args.mode == 'CShowImages':
|
821 |
+
ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
|
822 |
+
args.s_train, args.dn_train, args.dm_test, args.s_test,
|
823 |
+
args.dn_test, args.N_runs, args.severity, args.transform,
|
824 |
+
download_model=args.download_model)
|
825 |
ABclass.CShowImages()
|
826 |
+
elif args.mode == 'CShowAllImages':
|
827 |
+
ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
|
828 |
+
args.s_train, args.dn_train, args.dm_test, args.s_test,
|
829 |
+
args.dn_test, args.N_runs, args.severity, args.transform,
|
830 |
+
download_model=args.download_model)
|
831 |
ABclass.CShowAllImages()
|
figure1.sh β figures/figure1.sh
RENAMED
File without changes
|
figure2.sh β figures/figure2.sh
RENAMED
File without changes
|
figures.py β figures/figures.py
RENAMED
File without changes
|
{processingpipeline β figures}/numpy_static_pipeline_show.ipynb
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9fd7f0e985d60a9d24fe74ef0faf0e8e10258416190be64d7b06908306f0e7fc
|
3 |
+
size 1906578
|
sanity_checks_and_statistics.ipynb β figures/sanity_checks_and_statistics.ipynb
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:446fa03594d28d8a23aaa4d721db22e0453724edd4199970418a631e2f4e7b18
|
3 |
+
size 6103863
|
show_classification_results.ipynb β figures/show_classification_results.ipynb
RENAMED
File without changes
|
{utils β figures}/show_dataset.ipynb
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:11568132f27aee6b772327127e7dac8b2bd81bc219a2b091e74dd8d6514e419d
|
3 |
+
size 2115521
|
show_results.sh β figures/show_results.sh
RENAMED
File without changes
|
train.sh β figures/train.sh
RENAMED
File without changes
|
models/classifier.py β model.py
RENAMED
@@ -29,6 +29,7 @@ class LitModel(pl.LightningModule):
|
|
29 |
weight_decay=0,
|
30 |
loss_aux=None,
|
31 |
adv_training=False,
|
|
|
32 |
metrics=None,
|
33 |
processor=None,
|
34 |
augmentation=None,
|
@@ -57,11 +58,19 @@ class LitModel(pl.LightningModule):
|
|
57 |
self.freeze_classifier = freeze_classifier
|
58 |
self.freeze_processor = freeze_processor
|
59 |
|
|
|
60 |
if freeze_classifier:
|
61 |
pl.LightningModule.freeze(self.classifier)
|
62 |
if freeze_processor:
|
63 |
pl.LightningModule.freeze(self.processor)
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
def forward(self, x):
|
66 |
x = self.processor(x)
|
67 |
apply_augmentation_step = self.training or self.augmentation_on_eval
|
@@ -97,7 +106,6 @@ class LitModel(pl.LightningModule):
|
|
97 |
y_hat = F.logsigmoid(logits).exp().squeeze()
|
98 |
else:
|
99 |
y_hat = torch.argmax(logits, dim=1)
|
100 |
-
|
101 |
|
102 |
if self.metrics is not None:
|
103 |
for metric in self.metrics:
|
@@ -105,11 +113,15 @@ class LitModel(pl.LightningModule):
|
|
105 |
if metric_name == 'accuracy' or not self.training or self.metrics_on_training:
|
106 |
m = metric(y_hat.cpu().detach(), y.cpu())
|
107 |
self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
|
108 |
-
prog_bar=self.training or metric_name == 'accuracy')
|
109 |
if metric_name == 'iou_score' or not self.training or self.metrics_on_training:
|
110 |
m = metric(y_hat.cpu().detach(), y.cpu())
|
111 |
self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
|
112 |
prog_bar=self.training or metric_name == 'iou_score')
|
|
|
|
|
|
|
|
|
113 |
|
114 |
return loss
|
115 |
|
@@ -124,25 +136,15 @@ class LitModel(pl.LightningModule):
|
|
124 |
|
125 |
def train(self, mode=True):
|
126 |
self.training = mode
|
127 |
-
|
128 |
-
|
|
|
129 |
self.classifier.train(mode=mode and not self.freeze_classifier)
|
130 |
-
if self.adv_training and self.processor.batch_norm is not None: # don't update batchnorm in adversarial training
|
131 |
-
self.processor.batch_norm.track_running_stats = False
|
132 |
return self
|
133 |
|
134 |
def configure_optimizers(self):
|
135 |
self.optimizer = torch.optim.Adam(self.parameters(), self.lr, weight_decay=self.weight_decay)
|
136 |
-
# parameters = [self.processor.additive_layer]
|
137 |
-
# self.optimizer = torch.optim.Adam(parameters, self.lr, weight_decay=self.weight_decay)
|
138 |
return self.optimizer
|
139 |
-
# self.scheduler = {
|
140 |
-
# 'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(
|
141 |
-
# self.optimizer, mode='min', factor=0.2, patience=2, min_lr=1e-6, verbose=True,
|
142 |
-
# ),
|
143 |
-
# 'monitor': 'val_loss',
|
144 |
-
# }
|
145 |
-
# return [self.optimizer], [self.scheduler]
|
146 |
|
147 |
def get_progress_bar_dict(self):
|
148 |
items = super().get_progress_bar_dict()
|
@@ -151,7 +153,7 @@ class LitModel(pl.LightningModule):
|
|
151 |
|
152 |
|
153 |
class TrackImagesCallback(pl.callbacks.base.Callback):
|
154 |
-
def __init__(self, data_loader, track_every_epoch=False, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True):
|
155 |
super().__init__()
|
156 |
self.data_loader = data_loader
|
157 |
|
@@ -162,9 +164,12 @@ class TrackImagesCallback(pl.callbacks.base.Callback):
|
|
162 |
self.track_predictions = track_predictions
|
163 |
self.save_tensors = save_tensors
|
164 |
|
165 |
-
|
166 |
-
|
|
|
|
|
167 |
self.data_loader,
|
|
|
168 |
track_processing=self.track_processing,
|
169 |
track_gradients=self.track_gradients,
|
170 |
track_predictions=self.track_predictions,
|
@@ -175,12 +180,12 @@ class TrackImagesCallback(pl.callbacks.base.Callback):
|
|
175 |
def on_fit_end(self, trainer, pl_module):
|
176 |
if not self.track_every_epoch:
|
177 |
save_loc = 'results'
|
178 |
-
self.callback_track_images(trainer, save_loc)
|
179 |
|
180 |
def on_train_epoch_end(self, trainer, pl_module, outputs):
|
181 |
if self.track_every_epoch:
|
182 |
save_loc = f'results/epoch_{trainer.current_epoch + 1:04d}'
|
183 |
-
self.callback_track_images(trainer, save_loc)
|
184 |
|
185 |
|
186 |
from utils.debug import debug
|
@@ -201,7 +206,7 @@ def log_tensor(batch, path, save_tensors=True, nrow=8):
|
|
201 |
mlflow.log_artifact(img_path, os.path.dirname(path))
|
202 |
|
203 |
|
204 |
-
def track_images(model, data_loader, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True, save_loc='results'):
|
205 |
|
206 |
device = model.device
|
207 |
processor = model.processor
|
@@ -219,6 +224,9 @@ def track_images(model, data_loader, track_processing=True, track_gradients=True
|
|
219 |
logits_full = []
|
220 |
stages_full = defaultdict(list)
|
221 |
grads_full = defaultdict(list)
|
|
|
|
|
|
|
222 |
|
223 |
for inputs, labels in data_loader:
|
224 |
|
@@ -227,6 +235,10 @@ def track_images(model, data_loader, track_processing=True, track_gradients=True
|
|
227 |
|
228 |
processed_rgb = processor(inputs)
|
229 |
|
|
|
|
|
|
|
|
|
230 |
if track_gradients or track_predictions:
|
231 |
logits = classifier(processed_rgb)
|
232 |
|
@@ -241,6 +253,8 @@ def track_images(model, data_loader, track_processing=True, track_gradients=True
|
|
241 |
|
242 |
for stage, batch in processor.stages.items():
|
243 |
stages_full[stage].append(batch.cpu().detach())
|
|
|
|
|
244 |
if track_gradients:
|
245 |
grads_full[stage].append(batch.grad.cpu().detach())
|
246 |
|
@@ -248,19 +262,29 @@ def track_images(model, data_loader, track_processing=True, track_gradients=True
|
|
248 |
|
249 |
stages = stages_full
|
250 |
grads = grads_full
|
|
|
251 |
|
252 |
if track_processing:
|
253 |
-
for stage, batch in
|
254 |
stages[stage] = torch.cat(batch)
|
255 |
|
|
|
|
|
|
|
|
|
256 |
if track_gradients:
|
257 |
-
for stage, batch in
|
258 |
grads[stage] = torch.cat(batch)
|
259 |
|
260 |
for stage_nr, stage_name in enumerate(stages):
|
261 |
if track_processing:
|
262 |
batch = stages[stage_name]
|
263 |
log_tensor(batch, os.path.join(save_loc, f'processing_{stage_nr}_{stage_name}.pt'), save_tensors)
|
|
|
|
|
|
|
|
|
|
|
264 |
if track_gradients:
|
265 |
batch_grad = grads[stage_name]
|
266 |
batch_grad = batch_grad.abs()
|
@@ -270,11 +294,11 @@ def track_images(model, data_loader, track_processing=True, track_gradients=True
|
|
270 |
|
271 |
# inputs = torch.cat(inputs_full)
|
272 |
|
273 |
-
if track_predictions:
|
274 |
labels = torch.cat(labels_full)
|
275 |
logits = torch.cat(logits_full)
|
276 |
masks = labels.unsqueeze(1)
|
277 |
-
predictions = logits
|
278 |
#mask_vis = torch.cat((masks, predictions, masks * predictions), dim=1)
|
279 |
#log_tensor(mask_vis, os.path.join(save_loc, f'masks.pt'), save_tensors)
|
280 |
log_tensor(masks, os.path.join(save_loc, f'targets.pt'), save_tensors)
|
|
|
29 |
weight_decay=0,
|
30 |
loss_aux=None,
|
31 |
adv_training=False,
|
32 |
+
adv_parameters='all',
|
33 |
metrics=None,
|
34 |
processor=None,
|
35 |
augmentation=None,
|
|
|
58 |
self.freeze_classifier = freeze_classifier
|
59 |
self.freeze_processor = freeze_processor
|
60 |
|
61 |
+
self.unfreeze()
|
62 |
if freeze_classifier:
|
63 |
pl.LightningModule.freeze(self.classifier)
|
64 |
if freeze_processor:
|
65 |
pl.LightningModule.freeze(self.processor)
|
66 |
|
67 |
+
if adv_training and adv_parameters != 'all':
|
68 |
+
if adv_parameters != 'all':
|
69 |
+
pl.LightningModule.freeze(self.processor)
|
70 |
+
for name, p in self.processor.named_parameters():
|
71 |
+
if adv_parameters in name:
|
72 |
+
p.requires_grad = True
|
73 |
+
|
74 |
def forward(self, x):
|
75 |
x = self.processor(x)
|
76 |
apply_augmentation_step = self.training or self.augmentation_on_eval
|
|
|
106 |
y_hat = F.logsigmoid(logits).exp().squeeze()
|
107 |
else:
|
108 |
y_hat = torch.argmax(logits, dim=1)
|
|
|
109 |
|
110 |
if self.metrics is not None:
|
111 |
for metric in self.metrics:
|
|
|
113 |
if metric_name == 'accuracy' or not self.training or self.metrics_on_training:
|
114 |
m = metric(y_hat.cpu().detach(), y.cpu())
|
115 |
self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
|
116 |
+
prog_bar=self.training or metric_name == 'accuracy')
|
117 |
if metric_name == 'iou_score' or not self.training or self.metrics_on_training:
|
118 |
m = metric(y_hat.cpu().detach(), y.cpu())
|
119 |
self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
|
120 |
prog_bar=self.training or metric_name == 'iou_score')
|
121 |
+
elif metric_name == 'accuracy' or not self.training or self.metrics_on_training:
|
122 |
+
m = metric(y_hat.cpu().detach(), y.cpu())
|
123 |
+
self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
|
124 |
+
prog_bar=self.training or metric_name == 'accuracy')
|
125 |
|
126 |
return loss
|
127 |
|
|
|
136 |
|
137 |
def train(self, mode=True):
|
138 |
self.training = mode
|
139 |
+
|
140 |
+
# don't update batchnorm in adversarial training
|
141 |
+
self.processor.train(mode=mode and not self.freeze_processor and not self.adv_training)
|
142 |
self.classifier.train(mode=mode and not self.freeze_classifier)
|
|
|
|
|
143 |
return self
|
144 |
|
145 |
def configure_optimizers(self):
|
146 |
self.optimizer = torch.optim.Adam(self.parameters(), self.lr, weight_decay=self.weight_decay)
|
|
|
|
|
147 |
return self.optimizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
def get_progress_bar_dict(self):
|
150 |
items = super().get_progress_bar_dict()
|
|
|
153 |
|
154 |
|
155 |
class TrackImagesCallback(pl.callbacks.base.Callback):
|
156 |
+
def __init__(self, data_loader, reference_processor=None, track_every_epoch=False, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True):
|
157 |
super().__init__()
|
158 |
self.data_loader = data_loader
|
159 |
|
|
|
164 |
self.track_predictions = track_predictions
|
165 |
self.save_tensors = save_tensors
|
166 |
|
167 |
+
self.reference_processor = reference_processor
|
168 |
+
|
169 |
+
def callback_track_images(self, model, save_loc):
|
170 |
+
track_images(model,
|
171 |
self.data_loader,
|
172 |
+
reference_processor=self.reference_processor,
|
173 |
track_processing=self.track_processing,
|
174 |
track_gradients=self.track_gradients,
|
175 |
track_predictions=self.track_predictions,
|
|
|
180 |
def on_fit_end(self, trainer, pl_module):
|
181 |
if not self.track_every_epoch:
|
182 |
save_loc = 'results'
|
183 |
+
self.callback_track_images(trainer.model, save_loc)
|
184 |
|
185 |
def on_train_epoch_end(self, trainer, pl_module, outputs):
|
186 |
if self.track_every_epoch:
|
187 |
save_loc = f'results/epoch_{trainer.current_epoch + 1:04d}'
|
188 |
+
self.callback_track_images(trainer.model, save_loc)
|
189 |
|
190 |
|
191 |
from utils.debug import debug
|
|
|
206 |
mlflow.log_artifact(img_path, os.path.dirname(path))
|
207 |
|
208 |
|
209 |
+
def track_images(model, data_loader, reference_processor=None, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True, save_loc='results'):
|
210 |
|
211 |
device = model.device
|
212 |
processor = model.processor
|
|
|
224 |
logits_full = []
|
225 |
stages_full = defaultdict(list)
|
226 |
grads_full = defaultdict(list)
|
227 |
+
diffs_full = defaultdict(list)
|
228 |
+
|
229 |
+
track_differences = reference_processor is not None
|
230 |
|
231 |
for inputs, labels in data_loader:
|
232 |
|
|
|
235 |
|
236 |
processed_rgb = processor(inputs)
|
237 |
|
238 |
+
if track_differences:
|
239 |
+
# debug(processor)
|
240 |
+
processed_rgb_ref = reference_processor(inputs)
|
241 |
+
|
242 |
if track_gradients or track_predictions:
|
243 |
logits = classifier(processed_rgb)
|
244 |
|
|
|
253 |
|
254 |
for stage, batch in processor.stages.items():
|
255 |
stages_full[stage].append(batch.cpu().detach())
|
256 |
+
if track_differences:
|
257 |
+
diffs_full[stage].append((reference_processor.stages[stage] - batch).cpu().detach())
|
258 |
if track_gradients:
|
259 |
grads_full[stage].append(batch.grad.cpu().detach())
|
260 |
|
|
|
262 |
|
263 |
stages = stages_full
|
264 |
grads = grads_full
|
265 |
+
diffs = diffs_full
|
266 |
|
267 |
if track_processing:
|
268 |
+
for stage, batch in stages.items():
|
269 |
stages[stage] = torch.cat(batch)
|
270 |
|
271 |
+
if track_differences:
|
272 |
+
for stage, batch in diffs.items():
|
273 |
+
diffs[stage] = torch.cat(batch)
|
274 |
+
|
275 |
if track_gradients:
|
276 |
+
for stage, batch in grads.items():
|
277 |
grads[stage] = torch.cat(batch)
|
278 |
|
279 |
for stage_nr, stage_name in enumerate(stages):
|
280 |
if track_processing:
|
281 |
batch = stages[stage_name]
|
282 |
log_tensor(batch, os.path.join(save_loc, f'processing_{stage_nr}_{stage_name}.pt'), save_tensors)
|
283 |
+
|
284 |
+
if track_differences:
|
285 |
+
batch = diffs[stage_name]
|
286 |
+
log_tensor(batch, os.path.join(save_loc, f'diffs_{stage_nr}_{stage_name}.pt'), False)
|
287 |
+
|
288 |
if track_gradients:
|
289 |
batch_grad = grads[stage_name]
|
290 |
batch_grad = batch_grad.abs()
|
|
|
294 |
|
295 |
# inputs = torch.cat(inputs_full)
|
296 |
|
297 |
+
if track_predictions: # and model.is_segmentation_task:
|
298 |
labels = torch.cat(labels_full)
|
299 |
logits = torch.cat(logits_full)
|
300 |
masks = labels.unsqueeze(1)
|
301 |
+
predictions = logits # torch.sigmoid(logits).unsqueeze(1)
|
302 |
#mask_vis = torch.cat((masks, predictions, masks * predictions), dim=1)
|
303 |
#log_tensor(mask_vis, os.path.join(save_loc, f'masks.pt'), save_tensors)
|
304 |
log_tensor(masks, os.path.join(save_loc, f'targets.pt'), save_tensors)
|
processingpipeline/pipeline.py β processing/pipeline_numpy.py
RENAMED
@@ -23,7 +23,7 @@ from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear,
|
|
23 |
import torch
|
24 |
import numpy as np
|
25 |
|
26 |
-
from
|
27 |
from torch.utils.data import DataLoader
|
28 |
|
29 |
from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear,
|
|
|
23 |
import torch
|
24 |
import numpy as np
|
25 |
|
26 |
+
from dataset import Subset
|
27 |
from torch.utils.data import DataLoader
|
28 |
|
29 |
from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear,
|
processingpipeline/torch_pipeline.py β processing/pipeline_torch.py
RENAMED
@@ -5,7 +5,7 @@ import torch.nn as nn
|
|
5 |
if not os.path.exists('README.md'):
|
6 |
os.chdir('..')
|
7 |
|
8 |
-
from
|
9 |
from utils.base import np2torch, torch2np
|
10 |
|
11 |
import segmentation_models_pytorch as smp
|
@@ -83,7 +83,7 @@ class NNProcessing(nn.Module):
|
|
83 |
in_channels=3,
|
84 |
classes=3,
|
85 |
)
|
86 |
-
self.batch_norm = None if not batch_norm_output else nn.BatchNorm2d(3)
|
87 |
self.normalize_mosaic = normalize_mosaic
|
88 |
|
89 |
def forward(self, raw):
|
@@ -108,20 +108,23 @@ class NNProcessing(nn.Module):
|
|
108 |
return rgb
|
109 |
|
110 |
|
|
|
|
|
|
|
|
|
|
|
111 |
class ParametrizedProcessing(nn.Module):
|
112 |
-
def __init__(self, camera_parameters, track_stages=False, batch_norm_output=True
|
113 |
super().__init__()
|
114 |
self.stages = None
|
115 |
self.buffer = None
|
116 |
self.track_stages = track_stages
|
117 |
|
118 |
black_level, white_balance, colour_matrix = camera_parameters
|
119 |
-
|
120 |
-
self.
|
121 |
-
|
122 |
-
|
123 |
-
self.register_buffer('M_RGB_2_YUV', M_RGB_2_YUV.clone())
|
124 |
-
self.register_buffer('M_YUV_2_RGB', M_YUV_2_RGB.clone())
|
125 |
|
126 |
self.gamma_correct = nn.Parameter(torch.Tensor([2.2]))
|
127 |
|
@@ -133,14 +136,12 @@ class ParametrizedProcessing(nn.Module):
|
|
133 |
self.gaussian_blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, padding_mode='reflect', bias=False)
|
134 |
self.gaussian_blur.weight.data[0][0] = K_BLUR.clone()
|
135 |
|
136 |
-
self.batch_norm = nn.BatchNorm2d(3) if batch_norm_output else None
|
137 |
|
138 |
-
|
139 |
-
|
140 |
-
# param.requires_grad = False
|
141 |
|
142 |
-
self.additive_layer =
|
143 |
-
) if noise_layer else None # XXX: can this be 0?
|
144 |
|
145 |
def forward(self, raw):
|
146 |
assert raw.ndim == 3, f"needs dims (B, H, W), got {raw.shape}"
|
@@ -157,6 +158,7 @@ class ParametrizedProcessing(nn.Module):
|
|
157 |
rgb = self.debayer(rgb)
|
158 |
# self.stages['debayer'] = rgb
|
159 |
|
|
|
160 |
rgb = torch.einsum('bchw,kc->bkhw', rgb, self.colour_correction).contiguous()
|
161 |
self.stages['color_correct'] = rgb
|
162 |
|
@@ -179,7 +181,6 @@ class ParametrizedProcessing(nn.Module):
|
|
179 |
self.stages['gamma_correct'] = rgb
|
180 |
|
181 |
if self.additive_layer is not None:
|
182 |
-
# rgb = rgb + 0 * self.additive_layer
|
183 |
rgb = rgb + self.additive_layer
|
184 |
self.stages['noise'] = rgb
|
185 |
|
@@ -259,11 +260,11 @@ if __name__ == "__main__":
|
|
259 |
os.chdir('..')
|
260 |
|
261 |
import matplotlib.pyplot as plt
|
262 |
-
from
|
263 |
from utils.base import np2torch, torch2np
|
264 |
|
265 |
from utils.debug import debug
|
266 |
-
from
|
267 |
|
268 |
raw_dataset = get_dataset('DS')
|
269 |
loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1)
|
|
|
5 |
if not os.path.exists('README.md'):
|
6 |
os.chdir('..')
|
7 |
|
8 |
+
from processing.pipeline_numpy import processing as default_processing
|
9 |
from utils.base import np2torch, torch2np
|
10 |
|
11 |
import segmentation_models_pytorch as smp
|
|
|
83 |
in_channels=3,
|
84 |
classes=3,
|
85 |
)
|
86 |
+
self.batch_norm = None if not batch_norm_output else nn.BatchNorm2d(3, affine=False)
|
87 |
self.normalize_mosaic = normalize_mosaic
|
88 |
|
89 |
def forward(self, raw):
|
|
|
108 |
return rgb
|
109 |
|
110 |
|
111 |
+
def add_additive_layer(processor):
|
112 |
+
processor.additive_layer = nn.Parameter(torch.zeros((1, 3, 256, 256)))
|
113 |
+
# processor.additive_layer = nn.Parameter(0.001 * torch.randn((1, 3, 256, 256)))
|
114 |
+
|
115 |
+
|
116 |
class ParametrizedProcessing(nn.Module):
|
117 |
+
def __init__(self, camera_parameters, track_stages=False, batch_norm_output=True):
|
118 |
super().__init__()
|
119 |
self.stages = None
|
120 |
self.buffer = None
|
121 |
self.track_stages = track_stages
|
122 |
|
123 |
black_level, white_balance, colour_matrix = camera_parameters
|
124 |
+
|
125 |
+
self.black_level = nn.Parameter(torch.as_tensor(black_level))
|
126 |
+
self.white_balance = nn.Parameter(torch.as_tensor(white_balance).reshape(1, 3))
|
127 |
+
self.colour_correction = nn.Parameter(torch.as_tensor(colour_matrix).reshape(3, 3))
|
|
|
|
|
128 |
|
129 |
self.gamma_correct = nn.Parameter(torch.Tensor([2.2]))
|
130 |
|
|
|
136 |
self.gaussian_blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, padding_mode='reflect', bias=False)
|
137 |
self.gaussian_blur.weight.data[0][0] = K_BLUR.clone()
|
138 |
|
139 |
+
self.batch_norm = nn.BatchNorm2d(3, affine=False) if batch_norm_output else None
|
140 |
|
141 |
+
self.register_buffer('M_RGB_2_YUV', M_RGB_2_YUV.clone())
|
142 |
+
self.register_buffer('M_YUV_2_RGB', M_YUV_2_RGB.clone())
|
|
|
143 |
|
144 |
+
self.additive_layer = None # this can be added in later
|
|
|
145 |
|
146 |
def forward(self, raw):
|
147 |
assert raw.ndim == 3, f"needs dims (B, H, W), got {raw.shape}"
|
|
|
158 |
rgb = self.debayer(rgb)
|
159 |
# self.stages['debayer'] = rgb
|
160 |
|
161 |
+
rgb = torch.einsum('bchw,kc->bchw', rgb, self.white_balance).contiguous()
|
162 |
rgb = torch.einsum('bchw,kc->bkhw', rgb, self.colour_correction).contiguous()
|
163 |
self.stages['color_correct'] = rgb
|
164 |
|
|
|
181 |
self.stages['gamma_correct'] = rgb
|
182 |
|
183 |
if self.additive_layer is not None:
|
|
|
184 |
rgb = rgb + self.additive_layer
|
185 |
self.stages['noise'] = rgb
|
186 |
|
|
|
260 |
os.chdir('..')
|
261 |
|
262 |
import matplotlib.pyplot as plt
|
263 |
+
from dataset import get_dataset
|
264 |
from utils.base import np2torch, torch2np
|
265 |
|
266 |
from utils.debug import debug
|
267 |
+
from processing.pipeline_numpy import processing as default_processing
|
268 |
|
269 |
raw_dataset = get_dataset('DS')
|
270 |
loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1)
|
readme/Slice 8.png
DELETED
Binary file (181 kB)
|
|
readme/init.md
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
|
|
|
|
readme/mlflow (1).png
DELETED
Binary file (206 kB)
|
|
train.py
CHANGED
@@ -4,6 +4,7 @@ import copy
|
|
4 |
import argparse
|
5 |
|
6 |
import torch
|
|
|
7 |
import torch.nn as nn
|
8 |
|
9 |
import mlflow.pytorch
|
@@ -16,17 +17,18 @@ from pytorch_lightning.callbacks import ModelCheckpoint
|
|
16 |
|
17 |
from utils.base import display_mlflow_run_info, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
|
18 |
from utils.debug import debug
|
|
|
19 |
from utils.augmentation import get_augmentation
|
20 |
-
from
|
21 |
|
22 |
-
from
|
23 |
-
from
|
24 |
|
25 |
-
from
|
26 |
|
27 |
import segmentation_models_pytorch as smp
|
28 |
|
29 |
-
from utils.
|
30 |
|
31 |
# args to set up task
|
32 |
parser = argparse.ArgumentParser(description="classification_task")
|
@@ -106,41 +108,33 @@ parser.add_argument("--adv_training", action='store_true', help="Enable adversar
|
|
106 |
parser.add_argument("--adv_aux_weight", type=float, default=1, help="Weighting of the adversarial auxilliary loss")
|
107 |
parser.add_argument("--adv_aux_loss", type=str, default='ssim', choices=['l2', 'ssim'],
|
108 |
help="Type of adversarial auxilliary regularization loss")
|
|
|
|
|
|
|
|
|
109 |
|
110 |
parser.add_argument("--cache_downloaded_models", type=str2bool, default=True)
|
111 |
|
112 |
parser.add_argument('--test_run', action='store_true')
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
'--augmentation=strong',
|
119 |
-
'--lr=1e-5',
|
120 |
-
'--freeze_processor',
|
121 |
-
# '--track_processing',
|
122 |
-
# '--test_run',
|
123 |
-
# '--track_predictions',
|
124 |
-
# '--track_every_epoch',
|
125 |
-
# '--adv_training',
|
126 |
-
# '--adv_aux_weight=100',
|
127 |
-
# '--adv_aux_loss=l2',
|
128 |
-
# '--log_model=',
|
129 |
-
])
|
130 |
-
else:
|
131 |
-
args = parser.parse_args()
|
132 |
|
133 |
|
134 |
def run_train(args):
|
135 |
|
|
|
|
|
136 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
137 |
training_mode = 'adversarial' if args.adv_training else 'default'
|
138 |
|
139 |
# set tracking uri, this is the address of the mlflow server where light experimental data will be stored
|
140 |
mlflow.set_tracking_uri(args.tracking_uri)
|
141 |
mlflow.set_experiment(args.experiment_name)
|
142 |
-
os.environ["AWS_ACCESS_KEY_ID"] =
|
143 |
-
os.environ["AWS_SECRET_ACCESS_KEY"] =
|
144 |
|
145 |
# dataset
|
146 |
|
@@ -197,8 +191,9 @@ def run_train(args):
|
|
197 |
if args.processing_mode == 'parametrized':
|
198 |
processor = ParametrizedProcessing(
|
199 |
camera_parameters=dataset.camera_parameters, track_stages=track_stages, batch_norm_output=True,
|
200 |
-
noise_layer=args.
|
201 |
)
|
|
|
202 |
elif args.processing_mode == 'neural_network':
|
203 |
processor = NNProcessing(track_stages=track_stages,
|
204 |
normalize_mosaic=normalize_mosaic, batch_norm_output=True)
|
@@ -252,13 +247,19 @@ def run_train(args):
|
|
252 |
assert not args.freeze_processor, "Processor should not be frozen for adversarial training"
|
253 |
|
254 |
processor_default = copy.deepcopy(processor)
|
255 |
-
processor_default.track_stages =
|
256 |
processor_default.eval()
|
257 |
processor_default.to(DEVICE)
|
258 |
# debug(processor_default)
|
|
|
|
|
|
|
|
|
|
|
259 |
|
260 |
def l2_regularization(x, y):
|
261 |
-
return (x - y).
|
|
|
262 |
|
263 |
if args.adv_aux_loss == 'l2':
|
264 |
regularization = l2_regularization
|
@@ -274,7 +275,8 @@ def run_train(args):
|
|
274 |
self.weight = weight
|
275 |
|
276 |
def forward(self, x):
|
277 |
-
|
|
|
278 |
x_processed = processor.buffer['processed_rgb']
|
279 |
return self.weight * self.loss_aux(x_reference, x_processed)
|
280 |
|
@@ -290,7 +292,7 @@ def run_train(args):
|
|
290 |
def __repr__(self):
|
291 |
return f'{self.weight} * {get_name(self.loss)}'
|
292 |
|
293 |
-
loss = WeightedLoss(loss=
|
294 |
# loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=0)
|
295 |
loss_aux = AuxLoss(
|
296 |
loss_aux=regularization,
|
@@ -303,8 +305,10 @@ def run_train(args):
|
|
303 |
classifier=classifier,
|
304 |
processor=processor,
|
305 |
loss=loss,
|
|
|
306 |
loss_aux=loss_aux,
|
307 |
adv_training=args.adv_training,
|
|
|
308 |
metrics=metrics,
|
309 |
augmentation=augmentation,
|
310 |
is_segmentation_task=dataset.task == 'segmentation',
|
@@ -346,7 +350,7 @@ def run_train(args):
|
|
346 |
|
347 |
with mlflow.start_run(run_name=f"{args.run_name}_{k_iter}", nested=True) as child_run:
|
348 |
|
349 |
-
#mlflow.pytorch.autolog(silent=True)
|
350 |
|
351 |
if k_iter == 0:
|
352 |
display_mlflow_run_info(child_run)
|
@@ -373,19 +377,22 @@ def run_train(args):
|
|
373 |
tracking_uri=args.tracking_uri,)
|
374 |
mlf_logger._run_id = child_run.info.run_id
|
375 |
|
|
|
|
|
376 |
callbacks = []
|
377 |
if args.track_processing:
|
378 |
callbacks += [TrackImagesCallback(track_loader,
|
|
|
379 |
track_every_epoch=args.track_every_epoch,
|
380 |
track_processing=args.track_processing,
|
381 |
track_gradients=args.track_processing_gradients,
|
382 |
track_predictions=args.track_predictions,
|
383 |
save_tensors=args.track_save_tensors)]
|
384 |
|
385 |
-
#if True: #args.save_best:
|
386 |
# if dataset.task == 'classification':
|
387 |
-
|
388 |
-
# checkpoint_callback = ModelCheckpoint(dirpath=args.tracking_uri, save_top_k=1, verbose=True, monitor="val_accuracy", mode="max") #dirpath=args.tracking_uri,
|
389 |
# else:
|
390 |
# checkpoint_callback = ModelCheckpoint(monitor="val_iou_score")
|
391 |
#callbacks += [checkpoint_callback]
|
@@ -397,7 +404,7 @@ def run_train(args):
|
|
397 |
logger=mlf_logger,
|
398 |
callbacks=callbacks,
|
399 |
check_val_every_n_epoch=args.check_val_every_n_epoch,
|
400 |
-
#checkpoint_callback=True,
|
401 |
)
|
402 |
|
403 |
if args.log_model:
|
@@ -410,9 +417,8 @@ def run_train(args):
|
|
410 |
val_dataloaders=valid_loader,
|
411 |
)
|
412 |
|
413 |
-
|
414 |
-
|
415 |
-
# print(f"param '{name}' diff: {p2 - p1}, l2: {(p2-p1).norm().item()}")
|
416 |
return model
|
417 |
|
418 |
|
|
|
4 |
import argparse
|
5 |
|
6 |
import torch
|
7 |
+
from torch import optim
|
8 |
import torch.nn as nn
|
9 |
|
10 |
import mlflow.pytorch
|
|
|
17 |
|
18 |
from utils.base import display_mlflow_run_info, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
|
19 |
from utils.debug import debug
|
20 |
+
from utils.dataset_utils import k_fold
|
21 |
from utils.augmentation import get_augmentation
|
22 |
+
from dataset import Subset, get_dataset
|
23 |
|
24 |
+
from processing.pipeline_numpy import RawProcessingPipeline
|
25 |
+
from processing.pipeline_torch import add_additive_layer, raw2rgb, RawToRGB, ParametrizedProcessing, NNProcessing
|
26 |
|
27 |
+
from model import log_tensor, resnet_model, LitModel, TrackImagesCallback
|
28 |
|
29 |
import segmentation_models_pytorch as smp
|
30 |
|
31 |
+
from utils.ssim import SSIM
|
32 |
|
33 |
# args to set up task
|
34 |
parser = argparse.ArgumentParser(description="classification_task")
|
|
|
108 |
parser.add_argument("--adv_aux_weight", type=float, default=1, help="Weighting of the adversarial auxilliary loss")
|
109 |
parser.add_argument("--adv_aux_loss", type=str, default='ssim', choices=['l2', 'ssim'],
|
110 |
help="Type of adversarial auxilliary regularization loss")
|
111 |
+
parser.add_argument("--adv_noise_layer", action='store_true', help="Adds an additive layer to Parametrized Processing")
|
112 |
+
parser.add_argument("--adv_track_differences", action='store_true', help='Save difference to default pipeline')
|
113 |
+
parser.add_argument('--adv_parameters', choices=['all', 'black_level', 'white_balance',
|
114 |
+
'colour_correction', 'gamma_correct', 'sharpening_filter', 'gaussian_blur', 'additive_layer'])
|
115 |
|
116 |
parser.add_argument("--cache_downloaded_models", type=str2bool, default=True)
|
117 |
|
118 |
parser.add_argument('--test_run', action='store_true')
|
119 |
|
120 |
+
|
121 |
+
args = parser.parse_args()
|
122 |
+
|
123 |
+
os.makedirs('results', exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
|
126 |
def run_train(args):
|
127 |
|
128 |
+
print(args)
|
129 |
+
|
130 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
131 |
training_mode = 'adversarial' if args.adv_training else 'default'
|
132 |
|
133 |
# set tracking uri, this is the address of the mlflow server where light experimental data will be stored
|
134 |
mlflow.set_tracking_uri(args.tracking_uri)
|
135 |
mlflow.set_experiment(args.experiment_name)
|
136 |
+
os.environ["AWS_ACCESS_KEY_ID"] = "AKIAYYIYHFHKBIJHOJPA"
|
137 |
+
os.environ["AWS_SECRET_ACCESS_KEY"] = "eUSKKy+T+KzBKWvAw5PrM/MDwEgkE0LcpNWgnmir"
|
138 |
|
139 |
# dataset
|
140 |
|
|
|
191 |
if args.processing_mode == 'parametrized':
|
192 |
processor = ParametrizedProcessing(
|
193 |
camera_parameters=dataset.camera_parameters, track_stages=track_stages, batch_norm_output=True,
|
194 |
+
# noise_layer=args.adv_noise_layer, # this has to be added manually afterwards for when a model is loaded that doesn't have one yet
|
195 |
)
|
196 |
+
|
197 |
elif args.processing_mode == 'neural_network':
|
198 |
processor = NNProcessing(track_stages=track_stages,
|
199 |
normalize_mosaic=normalize_mosaic, batch_norm_output=True)
|
|
|
247 |
assert not args.freeze_processor, "Processor should not be frozen for adversarial training"
|
248 |
|
249 |
processor_default = copy.deepcopy(processor)
|
250 |
+
processor_default.track_stages = args.track_processing
|
251 |
processor_default.eval()
|
252 |
processor_default.to(DEVICE)
|
253 |
# debug(processor_default)
|
254 |
+
for p in processor_default.parameters():
|
255 |
+
p.requires_grad = False
|
256 |
+
|
257 |
+
if args.adv_noise_layer:
|
258 |
+
add_additive_layer(processor)
|
259 |
|
260 |
def l2_regularization(x, y):
|
261 |
+
return ((x - y) ** 2).sum()
|
262 |
+
# return (x - y).norm()
|
263 |
|
264 |
if args.adv_aux_loss == 'l2':
|
265 |
regularization = l2_regularization
|
|
|
275 |
self.weight = weight
|
276 |
|
277 |
def forward(self, x):
|
278 |
+
with torch.no_grad():
|
279 |
+
x_reference = processor_default(x)
|
280 |
x_processed = processor.buffer['processed_rgb']
|
281 |
return self.weight * self.loss_aux(x_reference, x_processed)
|
282 |
|
|
|
292 |
def __repr__(self):
|
293 |
return f'{self.weight} * {get_name(self.loss)}'
|
294 |
|
295 |
+
loss = WeightedLoss(loss=loss, weight=-1)
|
296 |
# loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=0)
|
297 |
loss_aux = AuxLoss(
|
298 |
loss_aux=regularization,
|
|
|
305 |
classifier=classifier,
|
306 |
processor=processor,
|
307 |
loss=loss,
|
308 |
+
lr=args.lr,
|
309 |
loss_aux=loss_aux,
|
310 |
adv_training=args.adv_training,
|
311 |
+
adv_parameters=args.adv_parameters,
|
312 |
metrics=metrics,
|
313 |
augmentation=augmentation,
|
314 |
is_segmentation_task=dataset.task == 'segmentation',
|
|
|
350 |
|
351 |
with mlflow.start_run(run_name=f"{args.run_name}_{k_iter}", nested=True) as child_run:
|
352 |
|
353 |
+
# mlflow.pytorch.autolog(silent=True)
|
354 |
|
355 |
if k_iter == 0:
|
356 |
display_mlflow_run_info(child_run)
|
|
|
377 |
tracking_uri=args.tracking_uri,)
|
378 |
mlf_logger._run_id = child_run.info.run_id
|
379 |
|
380 |
+
reference_processor = processor_default if args.adv_training and args.adv_track_differences else None
|
381 |
+
|
382 |
callbacks = []
|
383 |
if args.track_processing:
|
384 |
callbacks += [TrackImagesCallback(track_loader,
|
385 |
+
reference_processor,
|
386 |
track_every_epoch=args.track_every_epoch,
|
387 |
track_processing=args.track_processing,
|
388 |
track_gradients=args.track_processing_gradients,
|
389 |
track_predictions=args.track_predictions,
|
390 |
save_tensors=args.track_save_tensors)]
|
391 |
|
392 |
+
# if True: #args.save_best:
|
393 |
# if dataset.task == 'classification':
|
394 |
+
#checkpoint_callback = ModelCheckpoint(pathmonitor="val_accuracy", mode='max')
|
395 |
+
# checkpoint_callback = ModelCheckpoint(dirpath=args.tracking_uri, save_top_k=1, verbose=True, monitor="val_accuracy", mode="max") #dirpath=args.tracking_uri,
|
396 |
# else:
|
397 |
# checkpoint_callback = ModelCheckpoint(monitor="val_iou_score")
|
398 |
#callbacks += [checkpoint_callback]
|
|
|
404 |
logger=mlf_logger,
|
405 |
callbacks=callbacks,
|
406 |
check_val_every_n_epoch=args.check_val_every_n_epoch,
|
407 |
+
# checkpoint_callback=True,
|
408 |
)
|
409 |
|
410 |
if args.log_model:
|
|
|
417 |
val_dataloaders=valid_loader,
|
418 |
)
|
419 |
|
420 |
+
globals().update(locals()) # for convenient access
|
421 |
+
|
|
|
422 |
return model
|
423 |
|
424 |
|
utils/augmentation.py
CHANGED
@@ -99,7 +99,7 @@ if __name__ == '__main__':
|
|
99 |
os.chdir('..')
|
100 |
|
101 |
# from utils.debug import debug
|
102 |
-
from
|
103 |
import matplotlib.pyplot as plt
|
104 |
|
105 |
dataset = get_dataset('DS') # drone segmentation
|
|
|
99 |
os.chdir('..')
|
100 |
|
101 |
# from utils.debug import debug
|
102 |
+
from dataset import get_dataset
|
103 |
import matplotlib.pyplot as plt
|
104 |
|
105 |
dataset = get_dataset('DS') # drone segmentation
|
utils/{splitting.py β dataset_utils.py}
RENAMED
@@ -1,6 +1,3 @@
|
|
1 |
-
"""
|
2 |
-
Split images in blocks and vice versa
|
3 |
-
"""
|
4 |
|
5 |
import random
|
6 |
import numpy as np
|
@@ -10,127 +7,183 @@ import torch
|
|
10 |
from skimage.util.shape import view_as_windows
|
11 |
|
12 |
|
13 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
"""Split the imgs in regions of size ROIs.
|
15 |
|
16 |
Args:
|
17 |
imgs (ndarray): images which you want to split
|
18 |
ROIs (tuple): size of sub-regions splitted (ROIs=region of interests)
|
19 |
step (tuple): step path from one sub-region to the next one (in the x,y axis)
|
20 |
-
|
21 |
Returns:
|
22 |
ndarray: splitted subimages.
|
23 |
The size is (x_num_subROIs*y_num_subROIs, **) where:
|
24 |
x_num_subROIs = ( imgs.shape[1]-int(ROIs[1]/2)*2 )/step[1]
|
25 |
y_num_subROIs = ( imgs.shape[0]-int(ROIs[0]/2)*2 )/step[0]
|
26 |
-
|
27 |
Example:
|
28 |
>>> from dataset_generator import split
|
29 |
>>> imgs_splitted = split(imgs, ROI_size = (5,5), step=(2,3))
|
30 |
"""
|
31 |
-
|
32 |
if len(ROIs) > 2:
|
33 |
return print("ROIs is a 2 element list")
|
34 |
-
|
35 |
if len(step) > 2:
|
36 |
return print("step is a 2 element list")
|
37 |
-
|
38 |
if type(imgs) != type(np.array(1)):
|
39 |
return print("imgs should be a ndarray")
|
40 |
-
|
41 |
-
if len(imgs.shape) == 2:
|
42 |
-
splitted = view_as_windows(imgs, (ROIs[0],ROIs[1]), (step[0], step[1]))
|
43 |
return splitted.reshape((-1, ROIs[0], ROIs[1]))
|
44 |
-
|
45 |
-
if len(imgs.shape) == 3:
|
46 |
_, _, channels = imgs.shape
|
47 |
-
if channels <= 3:
|
48 |
-
splitted = view_as_windows(imgs, (ROIs[0],ROIs[1], channels), (step[0], step[1], channels))
|
49 |
return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
|
50 |
-
else:
|
51 |
-
splitted = view_as_windows(imgs, (1, ROIs[0],ROIs[1]), (1, step[0], step[1]))
|
52 |
return splitted.reshape((-1, ROIs[0], ROIs[1]))
|
53 |
-
|
54 |
-
if len(imgs.shape) == 4:
|
55 |
_, _, _, channels = imgs.shape
|
56 |
-
splitted = view_as_windows(imgs, (1, ROIs[0],ROIs[1], channels), (1, step[0], step[1], channels))
|
57 |
return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
|
58 |
|
|
|
59 |
def join_blocks(splitted, final_shape):
|
60 |
"""Join blocks to reobtain a splitted image
|
61 |
-
|
62 |
Attribute:
|
63 |
splitted (tensor) = image splitted in blocks, size = (N_blocks, Channels, Height, Width)
|
64 |
final_shape (tuple) = size of the final image reconstructed (Height, Width)
|
65 |
Return:
|
66 |
tensor: image restored from blocks. size = (Channels, Height, Width)
|
67 |
-
|
68 |
"""
|
69 |
n_blocks, channels, ROI_height, ROI_width = splitted.shape
|
70 |
-
|
71 |
rows = final_shape[0] // ROI_height
|
72 |
columns = final_shape[1] // ROI_width
|
73 |
|
74 |
-
final_img = torch.empty(rows, channels, ROI_height, ROI_width*columns)
|
75 |
-
for r in range(rows):
|
76 |
-
stackblocks = splitted[r*columns]
|
77 |
for c in range(1, columns):
|
78 |
-
stackblocks = torch.cat((stackblocks, splitted[r*columns+c]), axis=2)
|
79 |
final_img[r] = stackblocks
|
80 |
-
|
81 |
joined_img = final_img[0]
|
82 |
-
|
83 |
-
for i in np.arange(1, len(final_img)):
|
84 |
-
joined_img = torch.cat((joined_img,final_img[i]), axis
|
85 |
-
|
86 |
return joined_img
|
87 |
|
88 |
-
|
|
|
89 |
""" Return a random region for each input/target pair images of the dataset
|
90 |
Args:
|
91 |
Y (ndarray): target of your dataset --> size: (BxHxWxC)
|
92 |
X (ndarray): input of your dataset --> size: (BxHxWxC)
|
93 |
ROIs (tuple): size of random region (ROIs=region of interests)
|
94 |
-
|
95 |
Returns:
|
96 |
For each pair images (input/target) of the dataset, return respectively random ROIs
|
97 |
Y_cut (ndarray): target of your dataset --> size: (Batch,Channels,ROIs[0],ROIs[1])
|
98 |
X_cut (ndarray): input of your dataset --> size: (Batch,Channels,ROIs[0],ROIs[1])
|
99 |
-
|
100 |
Example:
|
101 |
>>> from dataset_generator import random_ROI
|
102 |
>>> X,Y = random_ROI(X,Y, ROIs = (10,10))
|
103 |
-
"""
|
104 |
-
|
105 |
batch, channels, height, width = X.shape
|
106 |
-
|
107 |
-
X_cut=np.empty((batch, ROIs[0], ROIs[1], channels))
|
108 |
-
Y_cut=np.empty((batch, ROIs[0], ROIs[1], channels))
|
109 |
-
|
110 |
for i in np.arange(len(X)):
|
111 |
-
x_size=int(random.random()*(height-(ROIs[0]+1)))
|
112 |
-
y_size=int(random.random()*(width-(ROIs[1]+1)))
|
113 |
-
X_cut[i]=X[i, x_size:x_size+ROIs[0],y_size:y_size+ROIs[1], :]
|
114 |
-
Y_cut[i]=Y[i, x_size:x_size+ROIs[0],y_size:y_size+ROIs[1], :]
|
115 |
return X_cut, Y_cut
|
116 |
|
117 |
-
|
|
|
118 |
""" Return a dataset of N subimages obtained from random regions of the same image
|
119 |
Args:
|
120 |
Y (ndarray): target of your dataset --> size: (1,H,W,C)
|
121 |
X (ndarray): input of your dataset --> size: (1,H,W,C)
|
122 |
datasize = number of random ROIs to generate
|
123 |
ROIs (tuple): size of random region (ROIs=region of interests)
|
124 |
-
|
125 |
Returns:
|
126 |
Y_cut (ndarray): target of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
|
127 |
X_cut (ndarray): input of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
|
128 |
-
"""
|
129 |
|
130 |
batch, channels, height, width = X.shape
|
131 |
-
|
132 |
-
X_cut=np.empty((datasize, ROIs[0], ROIs[1], channels))
|
133 |
-
Y_cut=np.empty((datasize, ROIs[0], ROIs[1], channels))
|
134 |
|
135 |
for i in np.arange(datasize):
|
136 |
X_cut[i], Y_cut[i] = random_ROI(X, Y, ROIs)
|
|
|
|
|
|
|
|
|
1 |
|
2 |
import random
|
3 |
import numpy as np
|
|
|
7 |
from skimage.util.shape import view_as_windows
|
8 |
|
9 |
|
10 |
+
def load_image(path):
|
11 |
+
file_type = path.split('.')[-1].lower()
|
12 |
+
if file_type == 'dng':
|
13 |
+
img = rawpy.imread(path).raw_image_visible
|
14 |
+
elif file_type == 'tiff' or file_type == 'tif':
|
15 |
+
img = np.array(tiff.imread(path), dtype=np.float32)
|
16 |
+
else:
|
17 |
+
img = np.array(Image.open(path), dtype=np.float32)
|
18 |
+
return img
|
19 |
+
|
20 |
+
|
21 |
+
def list_images_in_dir(path):
|
22 |
+
image_list = [os.path.join(path, img_name)
|
23 |
+
for img_name in sorted(os.listdir(path))
|
24 |
+
if img_name.split('.')[-1].lower() in IMAGE_FILE_TYPES]
|
25 |
+
return image_list
|
26 |
+
|
27 |
+
|
28 |
+
def k_fold(dataset, n_splits: int, seed: int, train_size: float):
|
29 |
+
"""Split dataset in subsets for cross-validation
|
30 |
+
|
31 |
+
Args:
|
32 |
+
dataset (class): dataset to split
|
33 |
+
n_split (int): Number of re-shuffling & splitting iterations.
|
34 |
+
seed (int): seed for k_fold splitting
|
35 |
+
train_size (float): should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the train split.
|
36 |
+
Returns:
|
37 |
+
idxs (list): indeces for splitting the dataset. The list contain n_split pair of train/test indeces.
|
38 |
+
"""
|
39 |
+
if hasattr(dataset, 'labels'):
|
40 |
+
x = dataset.images
|
41 |
+
y = dataset.labels
|
42 |
+
elif hasattr(dataset, 'masks'):
|
43 |
+
x = dataset.images
|
44 |
+
y = dataset.masks
|
45 |
+
|
46 |
+
idxs = []
|
47 |
+
|
48 |
+
if dataset.task == 'classification':
|
49 |
+
sss = StratifiedShuffleSplit(n_splits=n_splits, train_size=train_size, random_state=seed)
|
50 |
+
|
51 |
+
for idxs_train, idxs_test in sss.split(x, y):
|
52 |
+
idxs.append((idxs_train.tolist(), idxs_test.tolist()))
|
53 |
+
|
54 |
+
elif dataset.task == 'segmentation':
|
55 |
+
for n in range(n_splits):
|
56 |
+
split_idx = int(len(dataset) * train_size)
|
57 |
+
indices = np.random.permutation(len(dataset))
|
58 |
+
idxs.append((indices[:split_idx].tolist(), indices[split_idx:].tolist()))
|
59 |
+
|
60 |
+
return idxs
|
61 |
+
|
62 |
+
|
63 |
+
def split_img(imgs, ROIs=(3, 3), step=(1, 1)):
|
64 |
"""Split the imgs in regions of size ROIs.
|
65 |
|
66 |
Args:
|
67 |
imgs (ndarray): images which you want to split
|
68 |
ROIs (tuple): size of sub-regions splitted (ROIs=region of interests)
|
69 |
step (tuple): step path from one sub-region to the next one (in the x,y axis)
|
70 |
+
|
71 |
Returns:
|
72 |
ndarray: splitted subimages.
|
73 |
The size is (x_num_subROIs*y_num_subROIs, **) where:
|
74 |
x_num_subROIs = ( imgs.shape[1]-int(ROIs[1]/2)*2 )/step[1]
|
75 |
y_num_subROIs = ( imgs.shape[0]-int(ROIs[0]/2)*2 )/step[0]
|
76 |
+
|
77 |
Example:
|
78 |
>>> from dataset_generator import split
|
79 |
>>> imgs_splitted = split(imgs, ROI_size = (5,5), step=(2,3))
|
80 |
"""
|
81 |
+
|
82 |
if len(ROIs) > 2:
|
83 |
return print("ROIs is a 2 element list")
|
84 |
+
|
85 |
if len(step) > 2:
|
86 |
return print("step is a 2 element list")
|
87 |
+
|
88 |
if type(imgs) != type(np.array(1)):
|
89 |
return print("imgs should be a ndarray")
|
90 |
+
|
91 |
+
if len(imgs.shape) == 2: # Single image with one channel (HxW)
|
92 |
+
splitted = view_as_windows(imgs, (ROIs[0], ROIs[1]), (step[0], step[1]))
|
93 |
return splitted.reshape((-1, ROIs[0], ROIs[1]))
|
94 |
+
|
95 |
+
if len(imgs.shape) == 3:
|
96 |
_, _, channels = imgs.shape
|
97 |
+
if channels <= 3: # Single image more channels (HxWxC)
|
98 |
+
splitted = view_as_windows(imgs, (ROIs[0], ROIs[1], channels), (step[0], step[1], channels))
|
99 |
return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
|
100 |
+
else: # More images with 1 channel
|
101 |
+
splitted = view_as_windows(imgs, (1, ROIs[0], ROIs[1]), (1, step[0], step[1]))
|
102 |
return splitted.reshape((-1, ROIs[0], ROIs[1]))
|
103 |
+
|
104 |
+
if len(imgs.shape) == 4: # More images with more channels(BxHxWxC)
|
105 |
_, _, _, channels = imgs.shape
|
106 |
+
splitted = view_as_windows(imgs, (1, ROIs[0], ROIs[1], channels), (1, step[0], step[1], channels))
|
107 |
return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
|
108 |
|
109 |
+
|
110 |
def join_blocks(splitted, final_shape):
|
111 |
"""Join blocks to reobtain a splitted image
|
112 |
+
|
113 |
Attribute:
|
114 |
splitted (tensor) = image splitted in blocks, size = (N_blocks, Channels, Height, Width)
|
115 |
final_shape (tuple) = size of the final image reconstructed (Height, Width)
|
116 |
Return:
|
117 |
tensor: image restored from blocks. size = (Channels, Height, Width)
|
118 |
+
|
119 |
"""
|
120 |
n_blocks, channels, ROI_height, ROI_width = splitted.shape
|
121 |
+
|
122 |
rows = final_shape[0] // ROI_height
|
123 |
columns = final_shape[1] // ROI_width
|
124 |
|
125 |
+
final_img = torch.empty(rows, channels, ROI_height, ROI_width * columns)
|
126 |
+
for r in range(rows):
|
127 |
+
stackblocks = splitted[r * columns]
|
128 |
for c in range(1, columns):
|
129 |
+
stackblocks = torch.cat((stackblocks, splitted[r * columns + c]), axis=2)
|
130 |
final_img[r] = stackblocks
|
131 |
+
|
132 |
joined_img = final_img[0]
|
133 |
+
|
134 |
+
for i in np.arange(1, len(final_img)):
|
135 |
+
joined_img = torch.cat((joined_img, final_img[i]), axis=1)
|
136 |
+
|
137 |
return joined_img
|
138 |
|
139 |
+
|
140 |
+
def random_ROI(X, Y, ROIs=(512, 512)):
|
141 |
""" Return a random region for each input/target pair images of the dataset
|
142 |
Args:
|
143 |
Y (ndarray): target of your dataset --> size: (BxHxWxC)
|
144 |
X (ndarray): input of your dataset --> size: (BxHxWxC)
|
145 |
ROIs (tuple): size of random region (ROIs=region of interests)
|
146 |
+
|
147 |
Returns:
|
148 |
For each pair images (input/target) of the dataset, return respectively random ROIs
|
149 |
Y_cut (ndarray): target of your dataset --> size: (Batch,Channels,ROIs[0],ROIs[1])
|
150 |
X_cut (ndarray): input of your dataset --> size: (Batch,Channels,ROIs[0],ROIs[1])
|
151 |
+
|
152 |
Example:
|
153 |
>>> from dataset_generator import random_ROI
|
154 |
>>> X,Y = random_ROI(X,Y, ROIs = (10,10))
|
155 |
+
"""
|
156 |
+
|
157 |
batch, channels, height, width = X.shape
|
158 |
+
|
159 |
+
X_cut = np.empty((batch, ROIs[0], ROIs[1], channels))
|
160 |
+
Y_cut = np.empty((batch, ROIs[0], ROIs[1], channels))
|
161 |
+
|
162 |
for i in np.arange(len(X)):
|
163 |
+
x_size = int(random.random() * (height - (ROIs[0] + 1)))
|
164 |
+
y_size = int(random.random() * (width - (ROIs[1] + 1)))
|
165 |
+
X_cut[i] = X[i, x_size:x_size + ROIs[0], y_size:y_size + ROIs[1], :]
|
166 |
+
Y_cut[i] = Y[i, x_size:x_size + ROIs[0], y_size:y_size + ROIs[1], :]
|
167 |
return X_cut, Y_cut
|
168 |
|
169 |
+
|
170 |
+
def one2many_random_ROI(X, Y, datasize=1000, ROIs=(512, 512)):
|
171 |
""" Return a dataset of N subimages obtained from random regions of the same image
|
172 |
Args:
|
173 |
Y (ndarray): target of your dataset --> size: (1,H,W,C)
|
174 |
X (ndarray): input of your dataset --> size: (1,H,W,C)
|
175 |
datasize = number of random ROIs to generate
|
176 |
ROIs (tuple): size of random region (ROIs=region of interests)
|
177 |
+
|
178 |
Returns:
|
179 |
Y_cut (ndarray): target of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
|
180 |
X_cut (ndarray): input of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
|
181 |
+
"""
|
182 |
|
183 |
batch, channels, height, width = X.shape
|
184 |
+
|
185 |
+
X_cut = np.empty((datasize, ROIs[0], ROIs[1], channels))
|
186 |
+
Y_cut = np.empty((datasize, ROIs[0], ROIs[1], channels))
|
187 |
|
188 |
for i in np.arange(datasize):
|
189 |
X_cut[i], Y_cut[i] = random_ROI(X, Y, ROIs)
|
utils/debug.py
DELETED
@@ -1,371 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import numpy as np
|
3 |
-
import inspect
|
4 |
-
from functools import reduce, wraps
|
5 |
-
from collections.abc import Iterable
|
6 |
-
from IPython import embed
|
7 |
-
|
8 |
-
try:
|
9 |
-
get_ipython() # pylint: disable=undefined-variable
|
10 |
-
interactive_notebook = True
|
11 |
-
except:
|
12 |
-
interactive_notebook = False
|
13 |
-
|
14 |
-
_NONE = "__UNSET_VARIABLE__"
|
15 |
-
|
16 |
-
|
17 |
-
def debug_init():
|
18 |
-
debug.disable = False
|
19 |
-
debug.silent = False
|
20 |
-
debug.verbose = 2
|
21 |
-
debug.expand_ignore = ["DataLoader", "Dataset", "Subset"]
|
22 |
-
debug.max_expand = 10
|
23 |
-
debug.show_tensor = False
|
24 |
-
debug.raise_exception = True
|
25 |
-
debug.full_stack = True
|
26 |
-
debug.restore_defaults_on_exception = not interactive_notebook
|
27 |
-
debug._indent = 0
|
28 |
-
debug._stack = ""
|
29 |
-
|
30 |
-
debug.embed = embed
|
31 |
-
debug.show = debug_show
|
32 |
-
debug.pause = debug_pause
|
33 |
-
|
34 |
-
|
35 |
-
def debug_pause():
|
36 |
-
input("Press Enter to continue...")
|
37 |
-
|
38 |
-
|
39 |
-
def debug(*args, assert_true=False):
|
40 |
-
"""Decorator for debugging functions and tensors.
|
41 |
-
Will throw an exception as soon as a nan is encountered.
|
42 |
-
If used on iterables, these will be expanded and also searched for nans.
|
43 |
-
Usage:
|
44 |
-
debug(x)
|
45 |
-
Or:
|
46 |
-
@debug
|
47 |
-
def function():
|
48 |
-
...
|
49 |
-
If used as a function wrapper, all arguments will be searched and printed.
|
50 |
-
"""
|
51 |
-
|
52 |
-
single_arg = len(args) == 1
|
53 |
-
|
54 |
-
if debug.disable:
|
55 |
-
return args[0] if single_arg else None
|
56 |
-
|
57 |
-
try:
|
58 |
-
call_line = ''.join(inspect.stack()[1][4]).strip()
|
59 |
-
except:
|
60 |
-
call_line = '...'
|
61 |
-
used_as_wrapper = 'def ' == call_line[:4]
|
62 |
-
expect_return_arg = single_arg and 'debug' in call_line and call_line.split('debug')[0].strip() != ''
|
63 |
-
is_func = single_arg and hasattr(args[0], '__call__')
|
64 |
-
|
65 |
-
if is_func and (used_as_wrapper or expect_return_arg):
|
66 |
-
func = args[0]
|
67 |
-
sig_parameters = inspect.signature(func).parameters
|
68 |
-
sig_argnames = [p.name for p in sig_parameters.values()]
|
69 |
-
sig_defaults = {
|
70 |
-
k: v.default
|
71 |
-
for k, v in sig_parameters.items()
|
72 |
-
if v.default is not inspect.Parameter.empty
|
73 |
-
}
|
74 |
-
|
75 |
-
@wraps(func)
|
76 |
-
def _func(*args, **kwargs):
|
77 |
-
if debug.disable:
|
78 |
-
return func(*args, **kwargs)
|
79 |
-
|
80 |
-
if debug._indent == 0:
|
81 |
-
debug._stack = ""
|
82 |
-
stack_before = debug._stack
|
83 |
-
indent = ' ' * 4 * debug._indent
|
84 |
-
debug._indent += 1
|
85 |
-
|
86 |
-
args_kw = dict(zip(sig_argnames, args))
|
87 |
-
defaults = {k: v for k, v in sig_defaults.items()
|
88 |
-
if k not in kwargs
|
89 |
-
if k not in args_kw}
|
90 |
-
all_args = {**args_kw, **kwargs, **defaults}
|
91 |
-
|
92 |
-
func_name = None
|
93 |
-
if hasattr(func, '__name__'):
|
94 |
-
func_name = func.__name__
|
95 |
-
elif hasattr(func, '__class__'):
|
96 |
-
func_name = func.__class__.__name__
|
97 |
-
|
98 |
-
if func_name is None:
|
99 |
-
func_name = '... ' + call_line + '...'
|
100 |
-
else:
|
101 |
-
func_name = '@' + func_name + '()'
|
102 |
-
|
103 |
-
_debug_log('', indent=indent)
|
104 |
-
_debug_log(func_name, indent=indent)
|
105 |
-
|
106 |
-
debug._last_call = func
|
107 |
-
debug._last_args = all_args
|
108 |
-
debug._last_args_sig = sig_argnames
|
109 |
-
|
110 |
-
for argtype, params in [("args", args_kw.items()),
|
111 |
-
("kwargs", kwargs.items()),
|
112 |
-
("defaults", defaults.items())]:
|
113 |
-
if params:
|
114 |
-
_debug_log(f"{argtype}:", indent=indent + ' ' * 6)
|
115 |
-
for argname, arg in params:
|
116 |
-
if argname == 'self':
|
117 |
-
# _debug_log(f"- self: ...", indent=indent + ' ' * 8)
|
118 |
-
pass
|
119 |
-
else:
|
120 |
-
_debug_log(f"- {argname}: ", arg, indent + ' ' * 8, assert_true)
|
121 |
-
try:
|
122 |
-
out = func(*args, **kwargs)
|
123 |
-
except:
|
124 |
-
_debug_crash_save()
|
125 |
-
debug._stack = ""
|
126 |
-
debug._indent = 0
|
127 |
-
raise
|
128 |
-
debug.out = out
|
129 |
-
_debug_log("returned: ", out, indent, assert_true)
|
130 |
-
_debug_log('', indent=indent)
|
131 |
-
debug._indent -= 1
|
132 |
-
if not debug.full_stack:
|
133 |
-
debug._stack = stack_before
|
134 |
-
return out
|
135 |
-
return _func
|
136 |
-
else:
|
137 |
-
if debug._indent == 0:
|
138 |
-
debug._stack = ""
|
139 |
-
argname = ')'.join('('.join(call_line.split('(')[1:]).split(')')[:-1])
|
140 |
-
if assert_true:
|
141 |
-
argname = ','.join(argname.split(',')[:-1])
|
142 |
-
_debug_log(f"assert{{{argname}}} ", args[0], ' ' * 4 * debug._indent, assert_true)
|
143 |
-
else:
|
144 |
-
for arg in args:
|
145 |
-
_debug_log(f"{{{argname}}} = ", arg, ' ' * 4 * debug._indent, assert_true)
|
146 |
-
if expect_return_arg:
|
147 |
-
return args[0]
|
148 |
-
return
|
149 |
-
|
150 |
-
|
151 |
-
def is_iterable(x):
|
152 |
-
return isinstance(x, Iterable) or hasattr(x, '__getitem__') and not isinstance(x, str)
|
153 |
-
|
154 |
-
|
155 |
-
def ndarray_repr(t, assert_all=False):
|
156 |
-
exception_encountered = False
|
157 |
-
info = []
|
158 |
-
shape = tuple(t.shape)
|
159 |
-
single_entry = shape == () or shape == (1,)
|
160 |
-
if single_entry:
|
161 |
-
info.append(f"[{t.item():.4f}]")
|
162 |
-
else:
|
163 |
-
info.append(f"({', '.join(map(repr, shape))})")
|
164 |
-
invalid_sum = (~np.isfinite(t)).sum().item()
|
165 |
-
if invalid_sum:
|
166 |
-
info.append(
|
167 |
-
f"{invalid_sum} INVALID ENTR{'Y' if invalid_sum == 1 else 'IES'}")
|
168 |
-
exception_encountered = True
|
169 |
-
if debug.verbose > 1:
|
170 |
-
if not invalid_sum and not single_entry:
|
171 |
-
info.append(f"|x|={np.linalg.norm(t):.1f}")
|
172 |
-
if t.size:
|
173 |
-
info.append(f"x in [{t.min():.1f}, {t.max():.1f}]")
|
174 |
-
if debug.verbose and t.dtype != np.float:
|
175 |
-
info.append(f"dtype={str(t.dtype)}".replace("'", ''))
|
176 |
-
if assert_all:
|
177 |
-
assert_val = t.all()
|
178 |
-
if not assert_val:
|
179 |
-
exception_encountered = True
|
180 |
-
if assert_all and not exception_encountered:
|
181 |
-
output = "passed"
|
182 |
-
else:
|
183 |
-
if assert_all and not assert_val:
|
184 |
-
output = f"ndarray({info[0]})"
|
185 |
-
else:
|
186 |
-
output = f"ndarray({', '.join(info)})"
|
187 |
-
if exception_encountered and (not hasattr(debug, 'raise_exception') or debug.raise_exception):
|
188 |
-
if debug.restore_defaults_on_exception:
|
189 |
-
debug.raise_exception = False
|
190 |
-
debug.silent = False
|
191 |
-
debug.x = t
|
192 |
-
msg = output
|
193 |
-
debug._stack += output
|
194 |
-
if debug._stack and '\n' in debug._stack:
|
195 |
-
msg += '\nSTACK: ' + debug._stack
|
196 |
-
if assert_all:
|
197 |
-
assert assert_val, "Assert did not pass on " + msg
|
198 |
-
raise Exception("Invalid entries encountered in " + msg)
|
199 |
-
return output
|
200 |
-
|
201 |
-
|
202 |
-
def tensor_repr(t, assert_all=False):
|
203 |
-
exception_encountered = False
|
204 |
-
info = []
|
205 |
-
shape = tuple(t.shape)
|
206 |
-
single_entry = shape == () or shape == (1,)
|
207 |
-
if single_entry:
|
208 |
-
info.append(f"[{t.item():.3f}]")
|
209 |
-
else:
|
210 |
-
info.append(f"({', '.join(map(repr, shape))})")
|
211 |
-
invalid_sum = (~torch.isfinite(t)).sum().item()
|
212 |
-
if invalid_sum:
|
213 |
-
info.append(
|
214 |
-
f"{invalid_sum} INVALID ENTR{'Y' if invalid_sum == 1 else 'IES'}")
|
215 |
-
exception_encountered = True
|
216 |
-
if debug.verbose and t.requires_grad:
|
217 |
-
info.append('req_grad')
|
218 |
-
if debug.verbose > 2:
|
219 |
-
if t.is_leaf:
|
220 |
-
info.append('leaf')
|
221 |
-
if hasattr(t, 'retains_grad') and t.retains_grad:
|
222 |
-
info.append('retains_grad')
|
223 |
-
has_grad = (t.is_leaf or hasattr(t, 'retains_grad') and t.retains_grad) and t.grad is not None
|
224 |
-
if has_grad:
|
225 |
-
grad_invalid_sum = (~torch.isfinite(t.grad)).sum().item()
|
226 |
-
if grad_invalid_sum:
|
227 |
-
info.append(
|
228 |
-
f"GRAD {grad_invalid_sum} INVALID ENTR{'Y' if grad_invalid_sum == 1 else 'IES'}")
|
229 |
-
exception_encountered = True
|
230 |
-
if debug.verbose > 1:
|
231 |
-
if not invalid_sum and not single_entry:
|
232 |
-
info.append(f"|x|={t.float().norm():.1f}")
|
233 |
-
if t.numel():
|
234 |
-
info.append(f"x in [{t.min():.2f}, {t.max():.2f}]")
|
235 |
-
if has_grad and not grad_invalid_sum:
|
236 |
-
if single_entry:
|
237 |
-
info.append(f"grad={t.grad.float().item():.3f}")
|
238 |
-
else:
|
239 |
-
info.append(f"|grad|={t.grad.float().norm():.1f}")
|
240 |
-
if debug.verbose and t.dtype != torch.float:
|
241 |
-
info.append(f"dtype={str(t.dtype).split('.')[-1]}")
|
242 |
-
if debug.verbose and t.device.type != 'cpu':
|
243 |
-
info.append(f"device={t.device.type}")
|
244 |
-
if assert_all:
|
245 |
-
assert_val = t.all()
|
246 |
-
if not assert_val:
|
247 |
-
exception_encountered = True
|
248 |
-
if assert_all and not exception_encountered:
|
249 |
-
output = "passed"
|
250 |
-
else:
|
251 |
-
if assert_all and not assert_val:
|
252 |
-
output = f"tensor({info[0]})"
|
253 |
-
else:
|
254 |
-
output = f"tensor({', '.join(info)})"
|
255 |
-
if exception_encountered and (not hasattr(debug, 'raise_exception') or debug.raise_exception):
|
256 |
-
if debug.restore_defaults_on_exception:
|
257 |
-
debug.raise_exception = False
|
258 |
-
debug.silent = False
|
259 |
-
debug.x = t
|
260 |
-
msg = output
|
261 |
-
debug._stack += output
|
262 |
-
if debug._stack and '\n' in debug._stack:
|
263 |
-
msg += '\nSTACK: ' + debug._stack
|
264 |
-
if assert_all:
|
265 |
-
assert assert_val, "Assert did not pass on " + msg
|
266 |
-
raise Exception("Invalid entries encountered in " + msg)
|
267 |
-
return output
|
268 |
-
|
269 |
-
|
270 |
-
def _debug_crash_save():
|
271 |
-
if debug._indent:
|
272 |
-
debug.args = debug._last_args
|
273 |
-
debug.func = debug._last_call
|
274 |
-
|
275 |
-
@wraps(debug.func)
|
276 |
-
def _recall(*args, **kwargs):
|
277 |
-
call_args = {**debug.args, **kwargs, **dict(zip(debug._last_args_sig, args))}
|
278 |
-
return debug(debug.func)(**call_args)
|
279 |
-
|
280 |
-
def print_stack(stack=debug._stack):
|
281 |
-
print('\nSTACK: ' + stack)
|
282 |
-
debug.stack = print_stack
|
283 |
-
|
284 |
-
debug.recall = _recall
|
285 |
-
debug._indent = 0
|
286 |
-
|
287 |
-
|
288 |
-
def _debug_log(output, var=_NONE, indent='', assert_true=False, expand=True):
|
289 |
-
debug._stack += indent + output
|
290 |
-
if not debug.silent:
|
291 |
-
print(indent + output, end='')
|
292 |
-
if var is not _NONE:
|
293 |
-
type_str = type(var).__name__.lower()
|
294 |
-
if var is None:
|
295 |
-
_debug_log('None')
|
296 |
-
elif isinstance(var, str):
|
297 |
-
_debug_log(f"'{var}'")
|
298 |
-
elif type_str == 'ndarray':
|
299 |
-
_debug_log(ndarray_repr(var, assert_true))
|
300 |
-
if debug.show_tensor:
|
301 |
-
_debug_show_print(var, indent=indent + 4 * ' ')
|
302 |
-
# elif type_str in ('tensor', 'parameter'):
|
303 |
-
elif type_str == 'tensor':
|
304 |
-
_debug_log(tensor_repr(var, assert_true))
|
305 |
-
if debug.show_tensor:
|
306 |
-
_debug_show_print(var, indent=indent + 4 * ' ')
|
307 |
-
elif hasattr(var, 'named_parameters'):
|
308 |
-
_debug_log(type_str)
|
309 |
-
params = list(var.named_parameters())
|
310 |
-
_debug_log(f"{type_str}[{len(params)}] {{")
|
311 |
-
for k, v in params:
|
312 |
-
_debug_log(f"'{k}': ", v, indent + 6 * ' ')
|
313 |
-
_debug_log(indent + 4 * ' ' + '}')
|
314 |
-
elif is_iterable(var):
|
315 |
-
expand = debug.expand_ignore != '*' and expand
|
316 |
-
if expand:
|
317 |
-
if isinstance(debug.expand_ignore, str):
|
318 |
-
if type_str == str(debug.expand_ignore).lower():
|
319 |
-
expand = False
|
320 |
-
elif is_iterable(debug.expand_ignore):
|
321 |
-
for ignore in debug.expand_ignore:
|
322 |
-
if type_str == ignore.lower():
|
323 |
-
expand = False
|
324 |
-
if hasattr(var, '__len__'):
|
325 |
-
length = len(var)
|
326 |
-
else:
|
327 |
-
var = list(var)
|
328 |
-
length = len(var)
|
329 |
-
if expand and length > 0:
|
330 |
-
_debug_log(f"{type_str}[{length}] {{")
|
331 |
-
if isinstance(var, dict):
|
332 |
-
for k, v in var.items():
|
333 |
-
_debug_log(f"'{k}': ", v, indent + 6 * ' ', assert_true)
|
334 |
-
else:
|
335 |
-
i = 0
|
336 |
-
for k, i in zip(var, range(debug.max_expand)):
|
337 |
-
_debug_log('- ', k, indent + 6 * ' ', assert_true)
|
338 |
-
if i < length - 1:
|
339 |
-
_debug_log('- ' + ' ' * 6 + '...', indent=indent + 6 * ' ')
|
340 |
-
_debug_log(indent + 4 * ' ' + '}')
|
341 |
-
else:
|
342 |
-
_debug_log(f"{type_str}[{length}]")
|
343 |
-
else:
|
344 |
-
_debug_log(str(var))
|
345 |
-
else:
|
346 |
-
debug._stack += '\n'
|
347 |
-
if not debug.silent:
|
348 |
-
print()
|
349 |
-
|
350 |
-
|
351 |
-
def debug_show(x):
|
352 |
-
assert is_iterable(x)
|
353 |
-
debug(x)
|
354 |
-
_debug_show_print(x, indent=' ' * 4 * debug._indent)
|
355 |
-
|
356 |
-
|
357 |
-
def _debug_show_print(x, indent=''):
|
358 |
-
is_tensor = type(x).__name__ in ('Tensor', 'ndarray')
|
359 |
-
if is_tensor:
|
360 |
-
x = x.flatten()
|
361 |
-
if type(x).__name__ == 'Tensor' and x.dim() == 0:
|
362 |
-
return
|
363 |
-
n_samples = min(10, len(x))
|
364 |
-
di = len(x) // n_samples
|
365 |
-
var = list(x[i * di] for i in range(n_samples))
|
366 |
-
if is_tensor or type(var[0]) == float:
|
367 |
-
var = [round(float(v), 4) for v in var]
|
368 |
-
_debug_log('--> ', str(var), indent, expand=False)
|
369 |
-
|
370 |
-
|
371 |
-
debug_init()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/{Cperturb.py β hendrycks_robustness.py}
RENAMED
File without changes
|
utils/mutual_entropy.py
DELETED
@@ -1,193 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
from PIL import Image
|
3 |
-
import matplotlib.pyplot as plt
|
4 |
-
from scipy.signal import convolve2d
|
5 |
-
|
6 |
-
def mse(x,y):
|
7 |
-
return ((x-y)**2).mean()
|
8 |
-
|
9 |
-
def gaussian_noise_entropies(t1, bins=20):
|
10 |
-
all_MI= []
|
11 |
-
all_mse = []
|
12 |
-
for sigma in np.linspace(0,100,201):
|
13 |
-
t2 = np.random.normal(t1.copy(), scale=sigma, size = t1.shape)
|
14 |
-
hist_2d, x_edges, y_edges = np.histogram2d(
|
15 |
-
t1.ravel(),
|
16 |
-
t2.ravel(),
|
17 |
-
bins=bins)
|
18 |
-
all_mse.append(mse(t1,t2))
|
19 |
-
MI = mutual_information(hist_2d)
|
20 |
-
all_MI.append(MI)
|
21 |
-
|
22 |
-
return np.array((all_MI)), np.array((all_mse))
|
23 |
-
|
24 |
-
def shifts_entropies(t1, bins=20):
|
25 |
-
all_MI=[]
|
26 |
-
all_mse=[]
|
27 |
-
for N in np.linspace(1,50,50):
|
28 |
-
N = int(N)
|
29 |
-
temp_t2 = t1[:-N].copy()
|
30 |
-
temp_t1 = t1[N:].copy()
|
31 |
-
hist_2d, x_edges, y_edges = np.histogram2d(
|
32 |
-
t1.ravel(),
|
33 |
-
t2.ravel(),
|
34 |
-
bins=bins)
|
35 |
-
MI = mutual_information(hist_2d)
|
36 |
-
|
37 |
-
all_mse.append(mse(temp_t1,temp_t2))
|
38 |
-
all_MI.append(MI)
|
39 |
-
|
40 |
-
return np.array((all_MI)), np.array((all_mse))
|
41 |
-
|
42 |
-
def mutual_information(hgram):
|
43 |
-
""" Mutual information for joint histogram
|
44 |
-
"""
|
45 |
-
# Convert bins counts to probability values
|
46 |
-
pxy = hgram / float(np.sum(hgram))
|
47 |
-
px = np.sum(pxy, axis=1) # marginal for x over y
|
48 |
-
py = np.sum(pxy, axis=0) # marginal for y over x
|
49 |
-
px_py = px[:, None] * py[None, :] # Broadcast to multiply marginals
|
50 |
-
# Now we can do the calculation using the pxy, px_py 2D arrays
|
51 |
-
nzs = pxy > 0 # Only non-zero pxy values contribute to the sum
|
52 |
-
return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs]))
|
53 |
-
|
54 |
-
def entropy(image, bins=20):
|
55 |
-
image = image.ravel()
|
56 |
-
hist, bin_edges = np.histogram(image, bins = bins)
|
57 |
-
hist = hist/hist.sum()
|
58 |
-
entropy_term = np.where(hist != 0, hist*np.log(hist), 0)
|
59 |
-
entropy = - np.sum(entropy_term)
|
60 |
-
|
61 |
-
return entropy
|
62 |
-
|
63 |
-
# Gray Image
|
64 |
-
# t1 = np.array(Image.open("img.png"))[:,:,0].astype(float)
|
65 |
-
|
66 |
-
# Colour Image
|
67 |
-
t1 = np.array(Image.open("img.png").resize((255,255)))
|
68 |
-
|
69 |
-
perturb = "gauss"
|
70 |
-
show_image = True
|
71 |
-
bins=20
|
72 |
-
|
73 |
-
print(perturb)
|
74 |
-
|
75 |
-
# Identity
|
76 |
-
if perturb == "identity":
|
77 |
-
t2 = t1
|
78 |
-
title = "Identity"
|
79 |
-
image1 = "Clean"
|
80 |
-
image2 = "Clean"
|
81 |
-
|
82 |
-
# Poisson Noise on t2
|
83 |
-
if perturb == "poisson":
|
84 |
-
t2 = np.random.poisson(t1)
|
85 |
-
title = "Poisson Noise"
|
86 |
-
image1 = "Clean"
|
87 |
-
image2 = "Noisy"
|
88 |
-
|
89 |
-
# Gaussian Noise on t2
|
90 |
-
if perturb == "gauss":
|
91 |
-
print(np.shape(t1))
|
92 |
-
sigma = 50.0
|
93 |
-
t2 = np.random.normal(t1.copy(), scale=sigma, size = t1.shape)
|
94 |
-
if "grad" in locals():
|
95 |
-
title = f"Gaussian Noise, grad= True, sigma = {sigma:.2f}"
|
96 |
-
else:
|
97 |
-
title = f"Gaussian Noise, sigma = {sigma:.2f}"
|
98 |
-
image1 = "Clean"
|
99 |
-
image2 = "Noisy"
|
100 |
-
|
101 |
-
if perturb == "box":
|
102 |
-
sigma = 50.0
|
103 |
-
mean = np.mean(t1)
|
104 |
-
print(np.shape(t1))
|
105 |
-
t2 = t1.copy()
|
106 |
-
t2[30:220,50:120,:] = mean
|
107 |
-
title = "Box with mean pixels"
|
108 |
-
image1 = "Clean"
|
109 |
-
image2 = "Noisy"
|
110 |
-
|
111 |
-
|
112 |
-
# Shift t2 on y axis
|
113 |
-
if perturb == "shift":
|
114 |
-
N=30
|
115 |
-
t2 = t1[:-N]
|
116 |
-
t1 = t1[N:]
|
117 |
-
title = "y shift"
|
118 |
-
image1 = "Clean"
|
119 |
-
image2 = "Shifted"
|
120 |
-
|
121 |
-
t2 = np.clip(t2,0,255).astype(int)
|
122 |
-
|
123 |
-
print("Correlation Coefficient: ",np.corrcoef(t1.ravel(), t2.ravel())[0, 1])
|
124 |
-
|
125 |
-
# 2D Histogram
|
126 |
-
hist_2d, x_edges, y_edges = np.histogram2d(
|
127 |
-
t1.ravel(),
|
128 |
-
t2.ravel(),
|
129 |
-
bins=bins)
|
130 |
-
|
131 |
-
MI = mutual_information(hist_2d)
|
132 |
-
|
133 |
-
print("Mutual Information", MI)
|
134 |
-
print("Mean squared error:", mse(t1,t2))
|
135 |
-
|
136 |
-
if show_image == True:
|
137 |
-
plt.figure()
|
138 |
-
plt.imshow(np.hstack((t2, t1)))
|
139 |
-
plt.title(title)
|
140 |
-
|
141 |
-
plt.figure()
|
142 |
-
|
143 |
-
plt.plot(t1.ravel(), t2.ravel(), '.')
|
144 |
-
plt.xlabel(image1)
|
145 |
-
plt.ylabel(image2)
|
146 |
-
plt.title('I1 vs I2')
|
147 |
-
|
148 |
-
plt.figure()
|
149 |
-
plt.imshow((hist_2d.T)/hist_2d.max(), origin='lower')
|
150 |
-
plt.xlabel(image1)
|
151 |
-
plt.ylabel(image2)
|
152 |
-
plt.xticks(ticks=np.linspace(0,bins-1,10), labels=np.linspace(x_edges.min(),x_edges.max(),10).astype(int))
|
153 |
-
plt.yticks(ticks=np.linspace(0,bins-1,10), labels=np.linspace(y_edges.min(),y_edges.max(),10).astype(int))
|
154 |
-
plt.title('p(x,y)')
|
155 |
-
plt.colorbar()
|
156 |
-
|
157 |
-
# Show log histogram, avoiding divide by 0
|
158 |
-
plt.figure(figsize=(4,4))
|
159 |
-
hist_2d_log = np.zeros(hist_2d.shape)
|
160 |
-
non_zeros = hist_2d != 0
|
161 |
-
hist_2d_log[non_zeros] = np.log(hist_2d[non_zeros])
|
162 |
-
plt.imshow((hist_2d_log.T)/hist_2d_log.max(), origin='lower')
|
163 |
-
plt.xlabel(image1)
|
164 |
-
plt.ylabel(image2)
|
165 |
-
plt.xticks(ticks=np.linspace(0,bins-1,10), labels=np.linspace(x_edges.min(),x_edges.max(),10).astype(int))
|
166 |
-
plt.yticks(ticks=np.linspace(0,bins-1,10), labels=np.linspace(y_edges.min(),y_edges.max(),10).astype(int))
|
167 |
-
plt.title('log(p(x,y))')
|
168 |
-
plt.colorbar()
|
169 |
-
plt.show()
|
170 |
-
|
171 |
-
if perturb == "shift":
|
172 |
-
mi_array, mse_array = shifts_entropies(t1, bins=bins)
|
173 |
-
plt.figure()
|
174 |
-
plt.plot(np.linspace(0,50,50), mi_array)
|
175 |
-
plt.xlabel("y shift")
|
176 |
-
plt.ylabel("Mutual Information")
|
177 |
-
plt.figure()
|
178 |
-
plt.plot(np.linspace(0,50,50), mse_array)
|
179 |
-
plt.xlabel("y shift")
|
180 |
-
plt.ylabel("Mean Squared Error")
|
181 |
-
plt.show()
|
182 |
-
|
183 |
-
if perturb == "gauss":
|
184 |
-
mi_array, mse_array = gaussian_noise_entropies(t1, bins= bins)
|
185 |
-
plt.figure()
|
186 |
-
plt.plot(np.linspace(0,100,201), mi_array)
|
187 |
-
plt.xlabel("sigma")
|
188 |
-
plt.ylabel("Mutual Information")
|
189 |
-
plt.figure()
|
190 |
-
plt.plot(np.linspace(0,100,201), mse_array)
|
191 |
-
plt.xlabel("sigma")
|
192 |
-
plt.ylabel("Mean Squared Error")
|
193 |
-
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/{pytorch_ssim.py β ssim.py}
RENAMED
File without changes
|