jpdefrutos commited on
Commit
f42fb70
·
1 Parent(s): 0902b38

Rolling average dynamic weight updating

Browse files
DeepDeformationMapRegistration/callbacks.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import tensorflow.keras.backend as K
3
+
4
+
5
+ class RollingAverageWeighting(tf.keras.callbacks.Callback):
6
+ def __init__(self, weights: list, loss_names: list, ref_loss: str, epoch_update):
7
+ super(RollingAverageWeighting, self).__init__()
8
+ assert len(weights) == len(loss_names)
9
+ self.weights = weights
10
+ self.loss_weights = dict()
11
+ for name, w in zip(loss_names, weights):
12
+ self.loss_weights[name] = w
13
+ self.epoch_update = epoch_update - 1 # Epoch is zero based
14
+ self.rolling_avg = dict()
15
+ self.ref_loss = ref_loss
16
+ loss_names.append(ref_loss)
17
+ for name in loss_names:
18
+ self.rolling_avg[name] = 0
19
+
20
+ def on_epoch_end(self, epoch, logs=None):
21
+ # Get the average loss for each loss function
22
+ if epoch > self.epoch_update:
23
+ # Updated loss weights
24
+ for i, name in enumerate(self.rolling_avg.keys()):
25
+ # avg[n] = avg[n-1] + 1/n * (new_val - avg[n-1]), where n is the size of the rolling avg
26
+ self.rolling_avg[name] += (1 / self.epoch_update) * (logs.get(name) - self.rolling_avg[name])
27
+ else:
28
+ for i, name in enumerate(self.rolling_avg.keys()):
29
+ self.rolling_avg[name] += logs.get(name)
30
+ if epoch == self.epoch_update: # Time to start updating the weights!
31
+ self.rolling_avg[name] /= self.epoch_update
32
+
33
+ if not epoch % self.epoch_update:
34
+ self.update_weights()
35
+
36
+ def update_weights(self):
37
+ new_weights = list()
38
+ for name in self.loss_weights.keys():
39
+ K.set_value(self.loss_weights[name], self.rolling_avg[self.ref_loss] / self.rolling_avg[name])
40
+ new_weights.append(self.rolling_avg[self.ref_loss] / self.rolling_avg[name])
41
+
42
+ out_str = ''
43
+ for name, val in zip(self.loss_weights.keys(), new_weights):
44
+ out_str += '{}: {:7.2f}\t'.format(name, val)
45
+ print('WEIGHTS UPDATE: ' + out_str)