Commit 712050d · willis committed · Parent: 0b2696b

cleanup

Files changed:
- model.py +8 -9
- processing/pipeline_torch.py +41 -6
- train.py +93 -132
- utils/base.py +65 -24
model.py
CHANGED

@@ -3,7 +3,7 @@ from collections import defaultdict
 
 import torch
 import torch.optim
-from torchvision.models import resnet18
+from torchvision.models import resnet18, resnet34, resnet50
 from torchvision.utils import make_grid, save_image
 import torch.nn.functional as F
 
@@ -12,10 +12,13 @@ import pytorch_lightning as pl
 
 import mlflow.pytorch
 
 
-def resnet_model(model=resnet18, pretrained=True, in_channels=3, fc_out_features=2):
-
-
-
+def resnet_model(model='resnet18', pretrained=True, in_channels=3, fc_out_features=2):
+    if model.lower() == 'resnet18':
+        resnet = resnet18(pretrained=pretrained)
+    if model.lower() == 'resnet34':
+        resnet = resnet34(pretrained=pretrained)
+    if model.lower() == 'resnet50':
+        resnet = resnet50(pretrained=pretrained)
     resnet.fc = torch.nn.Linear(in_features=512, out_features=fc_out_features, bias=True)
     return resnet
 
@@ -81,10 +84,6 @@ class LitModel(pl.LightningModule):
 
     def update_step(self, batch, step_name):
         x, y = batch
-        # debug(self.processor)
-        # debug(self.processor.parameters())
-        # debug.pause()
-        # print('type', type(self.processor).__name__)
 
         logits = self(x)
 
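The refactored resnet_model selects the backbone from a string instead of taking a torchvision constructor, so the CLI choice can be passed straight through. A minimal usage sketch (not part of the commit; argument values are illustrative):

    from model import resnet_model

    # builds a torchvision resnet34 and swaps the head for a 2-class classifier
    classifier = resnet_model(model='resnet34', pretrained=False, fc_out_features=2)

One caveat: torchvision's resnet50 ends in a Linear with in_features=2048, so the hard-coded in_features=512 in the diff only fits resnet18/resnet34; reading resnet.fc.in_features before replacing the head would cover all three choices.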
processing/pipeline_torch.py
CHANGED

@@ -43,6 +43,18 @@ DEFAULT_CAMERA_PARAMS = (
 
 
 class RawToRGB(nn.Module):
+    """transforms a raw image with 1 channel to rgb with 3 channels
+
+    Args:
+        reduce_size (bool, optional): if False, the output image will have the same height and width
+            as the raw input, i.e. (B, C, H, W), empty values are filled with zeros.
+            if True, the output dimensions are reduced by half (B, C, H//2, W//2),
+        out_channels (int, optional): number of output channels. One of {3, 4}.
+            for 3 channels, the two green channels are averaged.
+        track_stages (bool, optional): whether or not to retain intermediary steps in processing
+        normalize_mosaic (function, optional): applies normalization transformation to rgb image
+    """
+
     def __init__(self, reduce_size=True, out_channels=3, track_stages=False, normalize_mosaic=None):
         super().__init__()
         self.stages = None
@@ -71,6 +83,14 @@ class RawToRGB(nn.Module):
 
 
 class NNProcessing(nn.Module):
+    """Transforms raw images to processed rgb via a segmentation Unet
+
+    Args:
+        track_stages (bool, optional): whether or not to retain intermediary steps in processing
+        normalize_mosaic (function, optional): applies normalization transformation to rgb image
+        batch_norm_output (bool, optional): adds a BatchNorm layer to the end of the processing
+    """
+
     def __init__(self, track_stages=False, normalize_mosaic=None, batch_norm_output=True):
         super().__init__()
         self.stages = None
@@ -89,7 +109,7 @@ class NNProcessing(nn.Module):
     def forward(self, raw):
         self.stages = {}
         self.buffer = {}
-
+
         rgb = raw2rgb(raw)
         if self.normalize_mosaic:
             rgb = self.normalize_mosaic(rgb)
@@ -108,18 +128,29 @@ class NNProcessing(nn.Module):
         return rgb
 
 
-def
+def append_additive_layer(processor):
     processor.additive_layer = nn.Parameter(torch.zeros((1, 3, 256, 256)))
     # processor.additive_layer = nn.Parameter(0.001 * torch.randn((1, 3, 256, 256)))
 
 
 class ParametrizedProcessing(nn.Module):
-
+    """Differentiable processing pipeline via torch transformations
+
+    Args:
+        camera_parameters (tuple(list), optional): applies given camera parameters in processing
+        track_stages (bool, optional): whether or not to retain intermediary steps in processing
+        batch_norm_output (bool, optional): adds a BatchNorm layer to the end of the processing
+    """
+
+    def __init__(self, camera_parameters=None, track_stages=False, batch_norm_output=True):
         super().__init__()
         self.stages = None
         self.buffer = None
         self.track_stages = track_stages
 
+        if camera_parameters is None:
+            camera_parameters = DEFAULT_CAMERA_PARAMS
+
         black_level, white_balance, colour_matrix = camera_parameters
 
         self.black_level = nn.Parameter(torch.as_tensor(black_level))
@@ -197,8 +228,11 @@ class ParametrizedProcessing(nn.Module):
 
 
 class Debayer(nn.Conv2d):
+    """Separates the mosaiced raw image into its channels and interpolates bilinearly. Output is of same size as input.
+    """
+
     def __init__(self):
-        super().__init__(3, 3, kernel_size=3, padding=1, padding_mode='reflect', bias=False)  #
+        super().__init__(3, 3, kernel_size=3, padding=1, padding_mode='reflect', bias=False)  # pipeline_numpy uses 'replicate'
        self.weight.data.fill_(0)
         self.weight.data[0, 0] = K_RB.clone()
         self.weight.data[1, 1] = K_G.clone()
@@ -206,15 +240,16 @@ class Debayer(nn.Conv2d):
 
 
 def raw2rgb(raw, black_level=None, reduce_size=True, out_channels=3):
-    """
+    """Transforms a raw image with 1 channel to rgb with 3 channels
+
     Args:
         raw (Tensor): raw Tensor of shape (B, H, W)
         black_level (iterable, optional): RGGB black level values to subtract
         reduce_size (bool, optional): if False, the output image will have the same height and width
             as the raw input, i.e. (B, C, H, W), empty values are filled with zeros.
             if True, the output dimensions are reduced by half (B, C, H//2, W//2),
-            the two green channels are averaged.
         out_channels (int, optional): number of output channels. One of {3, 4}.
+            The two green channels are averaged if out_channels == 3.
     """
     assert out_channels in [3, 4]
     if black_level is None:
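The new docstrings pin down the raw2rgb shape contract. A quick sanity sketch (not from the commit; shapes are taken from the docstring, and RawToRGB.forward is assumed to follow it):

    import torch
    from processing.pipeline_torch import RawToRGB, raw2rgb

    raw = torch.rand(4, 256, 256)               # (B, H, W) mosaiced input
    rgb_half = raw2rgb(raw)                     # reduce_size=True: (4, 3, 128, 128), greens averaged
    rgb_full = raw2rgb(raw, reduce_size=False)  # (4, 3, 256, 256), empty mosaic positions zero-filled
    rggb = raw2rgb(raw, out_channels=4)         # keeps both green channels separate

With camera_parameters defaulting to None, ParametrizedProcessing now falls back to DEFAULT_CAMERA_PARAMS, so the module can be constructed without an explicit camera profile.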
train.py
CHANGED

@@ -15,14 +15,14 @@ from pytorch_lightning.metrics.functional import accuracy
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import ModelCheckpoint
 
-from utils.base import display_mlflow_run_info, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
+from utils.base import AuxLoss, WeightedLoss, display_mlflow_run_info, l2_regularization, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
 from utils.debug import debug
 from utils.dataset_utils import k_fold
 from utils.augmentation import get_augmentation
 from dataset import Subset, get_dataset
 
 from processing.pipeline_numpy import RawProcessingPipeline
-from processing.pipeline_torch import
+from processing.pipeline_torch import append_additive_layer, raw2rgb, RawToRGB, ParametrizedProcessing, NNProcessing
 
 from model import log_tensor, resnet_model, LitModel, TrackImagesCallback
 
@@ -31,89 +31,88 @@ import segmentation_models_pytorch as smp
 from utils.ssim import SSIM
 
 # args to set up task
-parser = argparse.ArgumentParser(description=
+parser = argparse.ArgumentParser(description='classification_task')
-parser.add_argument(
-                    default=
+parser.add_argument('--tracking_uri', type=str,
+                    default='http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com', help='URI of the mlflow server on AWS')
-parser.add_argument(
+parser.add_argument('--processor_uri', type=str, default=None,
                     help='URI of the processing model (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/processing-model)')
-parser.add_argument(
+parser.add_argument('--classifier_uri', type=str, default=None,
                     help='URI of the net (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/prediction-model)')
-parser.add_argument(
+parser.add_argument('--state_dict_uri', type=str,
                     default=None, help='URI of the indices you want to load (e.g. s3://mlflow-artifacts-601883093460/7/4326da05aca54107be8c554de0674a14/artifacts/training')
 
-parser.add_argument(
+parser.add_argument('--experiment_name', type=str,
                     default='classification learnable pipeline', help='Specify the experiment you are running, e.g. end2end segmentation')
-parser.add_argument(
+parser.add_argument('--run_name', type=str,
                     default='test run', help='Specify the name of your run')
 
-parser.add_argument(
-parser.add_argument(
+parser.add_argument('--log_model', type=str2bool, default=True, help='Enables model logging')
+parser.add_argument('--save_locally', action='store_true',
                     help='Model will be saved locally if action is taken')  # TODO: bypass mlflow
 
-parser.add_argument(
+parser.add_argument('--track_processing', action='store_true',
                     help='Save images after each transformation of the pipeline for the test set')
-parser.add_argument(
+parser.add_argument('--track_processing_gradients', action='store_true',
                     help='Save images of gradients after each transformation of the pipeline for the test set')
-parser.add_argument(
+parser.add_argument('--track_save_tensors', action='store_true',
                     help='Save the torch tensors after each transformation of the pipeline for the test set')
-parser.add_argument(
+parser.add_argument('--track_predictions', action='store_true',
                     help='Save images after each transformation of the pipeline for the test set + input gradient')
-parser.add_argument(
+parser.add_argument('--track_n_images', default=5,
                     help='Track the n first elements of dataset. Only used for args.track_processing=True')
-parser.add_argument(
+parser.add_argument('--track_every_epoch', action='store_true', help='Track images every epoch or once after training')
 
 # args to create dataset
-parser.add_argument(
-parser.add_argument(
-                    choices=[
+parser.add_argument('--seed', type=int, default=1, help='Global seed')
+parser.add_argument('--dataset', type=str, default='Microscopy',
+                    choices=['Drone', 'DroneSegmentation', 'Microscopy'], help='Select dataset')
 
-parser.add_argument(
-parser.add_argument(
+parser.add_argument('--n_splits', type=int, default=1, help='Number of splits used for training')
+parser.add_argument('--train_size', type=float, default=0.8, help='Fraction of training points in dataset')
 
 # args for training
-parser.add_argument(
-parser.add_argument(
-parser.add_argument(
-parser.add_argument(
-                    choices=[
-parser.add_argument(
-                    help='Track images every epoch or once after training')  # TODO: implement, actually should be disabled by default for 'val' and 'test
-parser.add_argument("--check_val_every_n_epoch", type=int, default=1)
+parser.add_argument('--lr', type=float, default=1e-5, help='learning rate used for training')
+parser.add_argument('--epochs', type=int, default=3, help='number of epochs')
+parser.add_argument('--batch_size', type=int, default=32, help='Training batch size')
+parser.add_argument('--augmentation', type=str, default='none',
+                    choices=['none', 'weak', 'strong'], help='Applies augmentation to training')
+parser.add_argument('--check_val_every_n_epoch', type=int, default=1)
 
 # args to specify the processing
-parser.add_argument(
-                    choices=[
-                    help=
+parser.add_argument('--processing_mode', type=str, default='parametrized',
+                    choices=['parametrized', 'static', 'neural_network', 'none'],
+                    help='Which type of raw to rgb processing should be used')
 
 # args to specify model
-parser.add_argument(
-                    help='Type of pretrained network')
+parser.add_argument('--classifier_network', type=str, default='ResNet18', choices=['ResNet18', 'ResNet34', 'ResNet50'],
+                    help='Type of pretrained network')
-parser.add_argument(
+parser.add_argument('--classifier_pretrained', action='store_true',
                     help='Whether to use a pre-trained model or not')
-parser.add_argument(
+parser.add_argument('--smp_encoder', type=str, default='resnet34', help='segmentation models pytorch encoder')
 
-parser.add_argument(
-parser.add_argument(
+parser.add_argument('--freeze_processor', action='store_true', help='Freeze raw to rgb processing model weights')
+parser.add_argument('--freeze_classifier', action='store_true', help='Freeze classification model weights')
 
 # args to specify static pipeline transformations
-parser.add_argument(
-                    choices=['bilinear', 'malvar2004', 'menon2007'], help=
+parser.add_argument('--sp_debayer', type=str, default='bilinear',
+                    choices=['bilinear', 'malvar2004', 'menon2007'], help='Specify algorithm used as debayer')
-parser.add_argument(
-                    choices=['sharpening_filter', 'unsharp_masking'], help=
+parser.add_argument('--sp_sharpening', type=str, default='sharpening_filter',
+                    choices=['sharpening_filter', 'unsharp_masking'], help='Specify algorithm used for sharpening')
-parser.add_argument(
-                    choices=['gaussian_denoising', 'median_denoising', 'fft_denoising'], help=
+parser.add_argument('--sp_denoising', type=str, default='gaussian_denoising',
+                    choices=['gaussian_denoising', 'median_denoising', 'fft_denoising'], help='Specify algorithm used for denoising')
 
 # args to choose training mode
-parser.add_argument(
-parser.add_argument(
-parser.add_argument(
-                    help=
-parser.add_argument(
-parser.add_argument(
+parser.add_argument('--adv_training', action='store_true', help='Enable adversarial training')
+parser.add_argument('--adv_aux_weight', type=float, default=1, help='Weighting of the adversarial auxiliary loss')
+parser.add_argument('--adv_aux_loss', type=str, default='ssim', choices=['l2', 'ssim'],
+                    help='Type of adversarial auxiliary regularization loss')
+parser.add_argument('--adv_noise_layer', action='store_true', help='Adds an additive layer to Parametrized Processing')
+parser.add_argument('--adv_track_differences', action='store_true', help='Save difference to default pipeline')
 parser.add_argument('--adv_parameters', choices=['all', 'black_level', 'white_balance',
-                    'colour_correction', 'gamma_correct', 'sharpening_filter', 'gaussian_blur', 'additive_layer']
+                                                 'colour_correction', 'gamma_correct', 'sharpening_filter', 'gaussian_blur', 'additive_layer'],
+                    help='Target individual parameters for adversarial training.')
 
-parser.add_argument(
+parser.add_argument('--cache_downloaded_models', type=str2bool, default=True)
 
 parser.add_argument('--test_run', action='store_true')
 
@@ -133,10 +132,8 @@ def run_train(args):
     # set tracking uri, this is the address of the mlflow server where light experimental data will be stored
     mlflow.set_tracking_uri(args.tracking_uri)
     mlflow.set_experiment(args.experiment_name)
-    os.environ[
-    os.environ[
-
-    # dataset
+    os.environ['AWS_ACCESS_KEY_ID'] = '#TODO: fill in your aws access key id for mlflow server here'
+    os.environ['AWS_SECRET_ACCESS_KEY'] = '#TODO: fill in your aws secret access key for mlflow server here'
 
     dataset = get_dataset(args.dataset)
 
@@ -147,52 +144,57 @@ def run_train(args):
     pl.seed_everything(args.seed)
     idxs_kfold = k_fold(dataset, n_splits=args.n_splits, seed=args.seed, train_size=args.train_size)
 
+    # start mlflow parent run for k-fold validation (optional)
     with mlflow.start_run(run_name=args.run_name) as parent_run:
 
-        for k_iter, idxs in enumerate(idxs_kfold):
+        # start mlflow child run
+        for k_iter, (train_indices, valid_indices) in enumerate(idxs_kfold):
 
-            print(f
+            print(f'K_fold subset: {k_iter+1}/{args.n_splits}')
 
             if args.processing_mode == 'static':
-                if
+                # only needed if processor outputs should be normalized (might help for classifier training / testing against torch pipeline)
+                if args.dataset == 'Drone' or args.dataset == 'DroneSegmentation':
                     mean = torch.tensor([0.35, 0.36, 0.35])
                     std = torch.tensor([0.12, 0.11, 0.12])
-                elif args.dataset ==
+                elif args.dataset == 'Microscopy':
                     mean = torch.tensor([0.91, 0.84, 0.94])
                     std = torch.tensor([0.08, 0.12, 0.05])
 
+                # numpy pipeline doesn't use torch batched transformations. Transformations are applied individually to dataloader
                 dataset.transform = T.Compose([RawProcessingPipeline(
                     camera_parameters=dataset.camera_parameters,
                     debayer=args.sp_debayer,
                     sharpening=args.sp_sharpening,
                     denoising=args.sp_denoising,
-                ),
-
+                ),
+                    T.Normalize(mean, std)
+                ])
 
                 processor = nn.Identity()
 
+            # fetch processor from mlflow
             if args.processor_uri is not None and args.processing_mode != 'none':
                 print('Fetching processor: ', end='')
-
-
-                for param in processor.parameters():
-                    param.requires_grad = True
-                model.processor = None
-                del model
+                processor = fetch_from_mlflow(args.processor_uri, type='processor',
+                                              use_cache=args.cache_downloaded_models)
            else:
                 print(f'processing_mode: {args.processing_mode}')
-                normalize_mosaic = None  # normalize after raw has been
-
+                normalize_mosaic = None  # normalize after raw has been transformed to rgb image via raw2rgb
+                # not strictly necessary, but for processing_mode=='none' this will ensure normalized outputs for the classifier
+                # and for processing_mode=='neural_network', the processing segmentation model receives normalized inputs
+                # could be evaded via an additional batchnorm!
+                # XXX
+                if args.dataset == 'Microscopy':
                    mosaic_mean = [0.5663, 0.1401, 0.0731]
                    mosaic_std = [0.097, 0.0423, 0.008]
                    normalize_mosaic = T.Normalize(mosaic_mean, mosaic_std)
 
+                # track individual processing steps for visualization
                track_stages = args.track_processing or args.track_processing_gradients
                if args.processing_mode == 'parametrized':
                    processor = ParametrizedProcessing(
-                        camera_parameters=dataset.camera_parameters, track_stages=track_stages, batch_norm_output=True
-                        # noise_layer=args.adv_noise_layer,  # this has to be added manually afterwards for when a model is loaded that doesn't have one yet
-                    )
+                        camera_parameters=dataset.camera_parameters, track_stages=track_stages, batch_norm_output=True)
 
                elif args.processing_mode == 'neural_network':
                    processor = NNProcessing(track_stages=track_stages,
@@ -201,22 +203,19 @@ def run_train(args):
                    processor = RawToRGB(reduce_size=True, out_channels=3, track_stages=track_stages,
                                         normalize_mosaic=normalize_mosaic)
 
-            if args.classifier_uri:  # fetch classifier
+            if args.classifier_uri:  # fetch classifier from mlflow
                print('Fetching classifier: ', end='')
-
-
-                model.classifier = None
-                del model
+                classifier = fetch_from_mlflow(args.classifier_uri, type='classifier',
+                                               use_cache=args.cache_downloaded_models)
            else:
                if dataset.task == 'classification':
                    classifier = resnet_model(
-                        model=
+                        model=args.classifier_network,
                        pretrained=args.classifier_pretrained,
                        in_channels=3,
                        fc_out_features=len(dataset.classes)
                    )
                else:
-                    # XXX: add other network choices to args.smp_network (FPN) and args.network
                    classifier = smp.UnetPlusPlus(
                        encoder_name=args.smp_encoder,
                        encoder_depth=5,
@@ -240,26 +239,23 @@ def run_train(args):
 
            loss_aux = None
 
-            if args.adv_training:
+            if args.adv_training:  # setup for failure mode search
 
                assert args.processing_mode == 'parametrized', f"Processing mode ({args.processing_mode}) should be set to 'parametrized' for adversarial training"
-                assert args.freeze_classifier,
-                assert not args.freeze_processor,
+                assert args.freeze_classifier, 'Classifier should be frozen for adversarial training'
+                assert not args.freeze_processor, 'Processor should not be frozen for adversarial training'
 
+                # copy, so that regularization in rgb space between adversarial and original processor can be computed
                processor_default = copy.deepcopy(processor)
                processor_default.track_stages = args.track_processing
                processor_default.eval()
                processor_default.to(DEVICE)
-
+
                for p in processor_default.parameters():
                    p.requires_grad = False
 
-                if args.adv_noise_layer:
-
-
-                def l2_regularization(x, y):
-                    return ((x - y) ** 2).sum()
-                    # return (x - y).norm()
+                if args.adv_noise_layer:  # optional additional "noise" layer in processor
+                    append_additive_layer(processor)
 
                if args.adv_aux_loss == 'l2':
                    regularization = l2_regularization
@@ -268,34 +264,12 @@ def run_train(args):
                else:
                    NotImplementedError(args.adv_aux_loss)
 
-                class AuxLoss(nn.Module):
-                    def __init__(self, loss_aux, weight=1):
-                        super().__init__()
-                        self.loss_aux = loss_aux
-                        self.weight = weight
-
-                    def forward(self, x):
-                        with torch.no_grad():
-                            x_reference = processor_default(x)
-                        x_processed = processor.buffer['processed_rgb']
-                        return self.weight * self.loss_aux(x_reference, x_processed)
-
-                class WeightedLoss(nn.Module):
-                    def __init__(self, loss, weight=1):
-                        super().__init__()
-                        self.loss = loss
-                        self.weight = weight
-
-                    def forward(self, x, y):
-                        return self.weight * self.loss(x, y)
-
-                    def __repr__(self):
-                        return f'{self.weight} * {get_name(self.loss)}'
-
                loss = WeightedLoss(loss=loss, weight=-1)
-
+
                loss_aux = AuxLoss(
                    loss_aux=regularization,
+                    processor_adv=processor,
+                    processor_default=processor_default,
                    weight=args.adv_aux_weight,
                )
 
@@ -316,15 +290,13 @@ def run_train(args):
                freeze_processor=args.freeze_processor,
            )
 
+            state_dict = vars(args).copy()
+
            # get train_set_dict
            if args.state_dict_uri:
                state_dict = mlflow.pytorch.load_state_dict(args.state_dict_uri)
                train_indices = state_dict['train_indices']
                valid_indices = state_dict['valid_indices']
-            else:
-                train_indices = idxs[0]
-                valid_indices = idxs[1]
-                state_dict = vars(args).copy()
 
            track_indices = list(range(args.track_n_images))
 
@@ -350,8 +322,6 @@ def run_train(args):
 
            with mlflow.start_run(run_name=f"{args.run_name}_{k_iter}", nested=True) as child_run:
 
-                # mlflow.pytorch.autolog(silent=True)
-
                if k_iter == 0:
                    display_mlflow_run_info(child_run)
 
@@ -389,14 +359,6 @@ def run_train(args):
                                                 track_predictions=args.track_predictions,
                                                 save_tensors=args.track_save_tensors)]
 
-                # if True:  # args.save_best:
-                #     if dataset.task == 'classification':
-                #         checkpoint_callback = ModelCheckpoint(pathmonitor="val_accuracy", mode='max')
-                #         checkpoint_callback = ModelCheckpoint(dirpath=args.tracking_uri, save_top_k=1, verbose=True, monitor="val_accuracy", mode="max")  # dirpath=args.tracking_uri,
-                #     else:
-                #         checkpoint_callback = ModelCheckpoint(monitor="val_iou_score")
-                #     callbacks += [checkpoint_callback]
-
                trainer = pl.Trainer(
                    gpus=1 if DEVICE == 'cuda' else 0,
                    min_epochs=args.epochs,
@@ -404,7 +366,6 @@ def run_train(args):
                    logger=mlf_logger,
                    callbacks=callbacks,
                    check_val_every_n_epoch=args.check_val_every_n_epoch,
-                    # checkpoint_callback=True,
                )
 
                if args.log_model:
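For intuition on the adversarial branch: WeightedLoss(loss, weight=-1) flips the sign of the task loss so that gradient descent on the unfrozen processor maximizes classification error, while AuxLoss penalizes the distance (l2 or ssim) between the adversarial processor's output and the frozen default copy. A minimal sketch of the weighting mechanics (not from the commit; the wrapped loss instance is arbitrary):

    import torch.nn as nn
    from utils.base import WeightedLoss

    adv_task_loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=-1)
    print(adv_task_loss)  # prints "-1 * CrossEntropyLoss" via the custom __repr__

    # The training objective then combines adv_task_loss(logits, y), which is
    # ascended, with AuxLoss(x), which keeps the processed images close to the
    # default pipeline's output.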
utils/base.py
CHANGED

@@ -18,14 +18,7 @@ from b2sdk.v1 import *
 
 import argparse
 
-
-class SmartFormatter(argparse.HelpFormatter):
-
-    def _split_lines(self, text, width):
-        if text.startswith('R|'):
-            return text[2:].splitlines()
-        # this is the RawTextHelpFormatter._split_lines
-        return argparse.HelpFormatter._split_lines(self, text, width)
+from torch import nn
 
 
 def str2bool(string):
@@ -193,6 +186,7 @@ def b2_download_folder(b2_dir, local_dir, force_download=False, mirror_folder=True):
 def get_name(obj):
     return obj.__name__ if hasattr(obj, '__name__') else type(obj).__name__
 
+
 def get_mlflow_model_by_name(experiment_name, run_name,
                              tracking_uri="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
                              download_model=True):
@@ -234,6 +228,7 @@ def get_mlflow_model_by_name(experiment_name, run_name,
 
     return state_dict, model
 
+
 def data_loader_mean_and_std(data_loader, transform=None):
     means = []
     stds = []
@@ -244,23 +239,35 @@ def data_loader_mean_and_std(data_loader, transform=None):
         stds.append(x.std(dim=(0, 2, 3)).unsqueeze(0))
     return torch.cat(means).mean(dim=0), torch.cat(stds).mean(dim=0)
 
+
 def fetch_runs_list_mlflow(experiment):
     runs = mlflow.search_runs(experiment.experiment_id)
     runs.to_pickle('cache/runs_names.pkl')  # where to save it, usually as a .pkl
     return runs
 
-def fetch_from_mlflow(uri, use_cache=True, download_model=True):
+
+def fetch_from_mlflow(uri, type='', use_cache=True, download_model=True):
     cache_loc = os.path.join('cache', uri.split('//')[1]) + '.pt'
     if use_cache and os.path.exists(cache_loc):
         print(f'loading cached model from {cache_loc} ...')
-
+        model = torch.load(cache_loc)
     else:
         print(f'fetching model from {uri} ...')
         model = mlflow.pytorch.load_model(uri)
         os.makedirs(os.path.dirname(cache_loc), exist_ok=True)
         if download_model:
             torch.save(model, cache_loc, pickle_module=mlflow.pytorch.pickle_module)
-    return model
+    if type == 'processor':
+        processor = model.processor
+        model.processor = None
+        del model  # free up memory space
+        return processor
+    if type == 'classifier':
+        classifier = model.classifier
+        model.classifier = None
+        del model  # free up memory space
+        return classifier
+    return model
 
 
 def display_mlflow_run_info(run):
@@ -315,16 +322,50 @@ def get_train_test_indices_drone(df, frac, seed=None):
     return train_indices, test_indices
 
 
-def smp_get_loss(loss):
-    if loss == "Dice":
-        return smp.losses.DiceLoss(mode='binary', from_logits=True)
-    if loss == "BCE":
-        return nn.BCELoss()
-    elif loss == "BCEWithLogits":
-        return smp.losses.BCEWithLogitsLoss()
-    elif loss == "DicyBCE":
-        from pytorch_toolbelt import losses as ptbl
-        return ptbl.JointLoss(ptbl.DiceLoss(mode='binary', from_logits=False),
-                              nn.BCELoss(),
-                              first_weight=args.dice_weight,
-                              second_weight=args.bce_weight)
+# def smp_get_loss(loss):
+#     if loss == "Dice":
+#         return smp.losses.DiceLoss(mode='binary', from_logits=True)
+#     if loss == "BCE":
+#         return nn.BCELoss()
+#     elif loss == "BCEWithLogits":
+#         return smp.losses.BCEWithLogitsLoss()
+#     elif loss == "DicyBCE":
+#         from pytorch_toolbelt import losses as ptbl
+#         return ptbl.JointLoss(ptbl.DiceLoss(mode='binary', from_logits=False),
+#                               nn.BCELoss(),
+#                               first_weight=args.dice_weight,
+#                               second_weight=args.bce_weight)
+
+
+# Adversarial setup
+
+def l2_regularization(x, y):
+    return ((x - y) ** 2).sum()
+
+
+class AuxLoss(nn.Module):
+    def __init__(self, loss_aux, processor_adv, processor_default, weight=1):
+        super().__init__()
+        self.loss_aux = loss_aux
+        self.weight = weight
+        self.processor_adv = processor_adv
+        self.processor_default = processor_default
+
+    def forward(self, x):
+        with torch.no_grad():
+            x_reference = self.processor_default(x)
+        x_processed = self.processor_adv.buffer['processed_rgb']
+        return self.weight * self.loss_aux(x_reference, x_processed)
+
+
+class WeightedLoss(nn.Module):
+    def __init__(self, loss, weight=1):
+        super().__init__()
+        self.loss = loss
+        self.weight = weight
+
+    def forward(self, x, y):
+        return self.weight * self.loss(x, y)
+
+    def __repr__(self):
+        return f'{self.weight} * {get_name(self.loss)}'