# DDMR / ddmr/layers.py
# Commit a27d55f (andreped): "Renamed module to ddmr"
import os, sys
# Directory of this file and its parent; the parent is presumably the directory
# that contains the `ddmr` package (this file lives at ddmr/layers.py).
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir) # Make `ddmr` importable when run directly (Python 3 has no implicit relative imports); must precede the `ddmr` import below.
# True only when the PYCHARM_EXEC environment variable is exactly the string 'True'.
# NOTE(review): not referenced elsewhere in this module as shown.
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
import tensorflow.keras.layers as kl
import tensorflow.keras.backend as K
import tensorflow as tf
import numpy as np
from ddmr.utils.operators import soft_threshold
class UncertaintyWeighting(kl.Layer):
    """Combine several loss and regularization terms with trainable weights.

    One trainable log-weight is created per loss term (and per regularization
    term, if any). Loss weights are the softmax of the loss log-weights.
    Regularization weights are ``exp`` of the log-weight when there is a single
    regularization term, and a softmax otherwise. The weighted sum is
    registered with ``add_loss``, and every per-term weight and value is
    reported with ``add_metric``. The layer's output is only a placeholder
    (the concatenation of its inputs) so the layer can be wired into a Keras
    graph.

    Args:
        num_loss_fns: Number of loss terms.
        num_reg_fns: Number of regularization terms.
        loss_fns: Callables ``fn(y_true, y_pred)``. A single callable is reused
            for every loss term. ``None`` means
            ``[tf.keras.losses.mean_squared_error]``.
        reg_fns: Callables ``fn(reg_true, reg_pred)``. A single callable is
            reused for every regularization term. ``None`` means ``[]``.
        prior_loss_w: Per-term priors; the loss log-weight variables are
            initialised to ``log(prior)``. A single value is broadcast.
            ``None`` means ``[1.]``.
        manual_loss_w: Fixed multiplier applied to each loss term. A single
            value is broadcast. ``None`` means ``[1.]``.
        prior_reg_w: Same as ``prior_loss_w``, for the regularization terms.
        manual_reg_w: Same as ``manual_loss_w``, for the regularization terms.
        **kwargs: Forwarded to the Keras ``Layer`` constructor.
    """

    def __init__(self, num_loss_fns=1, num_reg_fns=0, loss_fns: list = None, reg_fns: list = None,
                 prior_loss_w=None, manual_loss_w=None, prior_reg_w=None, manual_reg_w=None,
                 **kwargs):
        # Initialise the Keras base layer before assigning any attribute.
        super(UncertaintyWeighting, self).__init__(**kwargs)
        # None sentinels replace shared mutable list defaults; the effective defaults are unchanged.
        if loss_fns is None:
            loss_fns = [tf.keras.losses.mean_squared_error]
        if reg_fns is None:
            reg_fns = list()
        if prior_loss_w is None:
            prior_loss_w = [1.]
        if manual_loss_w is None:
            manual_loss_w = [1.]
        if prior_reg_w is None:
            prior_reg_w = [1.]
        if manual_reg_w is None:
            manual_reg_w = [1.]

        assert isinstance(loss_fns, list) and (num_loss_fns == len(loss_fns) or len(loss_fns) == 1)
        # As with loss_fns, a single regularization function may be shared by every regularization term.
        assert isinstance(reg_fns, list) and (num_reg_fns == len(reg_fns) or (num_reg_fns > 0 and len(reg_fns) == 1))

        self.num_loss = num_loss_fns
        self.loss_fns = self._broadcast(loss_fns, self.num_loss)
        # The trainable variables hold log-weights, so the priors are converted to log space once here.
        self.prior_loss_w = np.log(self._broadcast(prior_loss_w, self.num_loss))
        self.manual_loss_w = self._broadcast(manual_loss_w, self.num_loss)

        self.num_reg = num_reg_fns
        self.is_placeholder = True
        if self.num_reg != 0:
            self.reg_fns = self._broadcast(reg_fns, self.num_reg)
            self.prior_reg_w = np.log(self._broadcast(prior_reg_w, self.num_reg))
            self.manual_reg_w = self._broadcast(manual_reg_w, self.num_reg)
        else:
            self.reg_fns = list()
            self.prior_reg_w = list()
            self.manual_reg_w = list()

    @staticmethod
    def _broadcast(values, count):
        """Return ``values`` as a new list, repeating a single element ``count`` times."""
        values = list(values)
        return values * count if len(values) == 1 else values

    def _update_weights(self):
        """(Re)compute ``self.loss_weights`` and ``self.reg_weights`` from the log-weight variables."""
        self.loss_weights = tf.math.softmax(self.log_loss_vars, name='SM_loss_weights')
        if self.num_reg != 0:
            if self.num_reg == 1:
                # A softmax over one element is always 1, so exp keeps the single weight trainable.
                self.reg_weights = tf.math.exp(self.log_reg_vars, name='EXP_reg_weights')
            else:
                self.reg_weights = tf.math.softmax(self.log_reg_vars, name='SM_reg_weights')

    def build(self, input_shape=None):
        """Create the trainable log-weight variables, initialised from the log-priors."""
        self.log_loss_vars = self.add_weight(name='loss_log_vars', shape=(self.num_loss,),
                                             initializer=tf.keras.initializers.Constant(self.prior_loss_w),
                                             trainable=True)
        if self.num_reg != 0:
            self.log_reg_vars = self.add_weight(name='loss_reg_vars', shape=(self.num_reg,),
                                                initializer=tf.keras.initializers.Constant(self.prior_reg_w),
                                                trainable=True)
        self._update_weights()
        super(UncertaintyWeighting, self).build(input_shape)

    def multi_loss(self, ys_true, ys_pred, regs_true, regs_pred):
        """Return the weighted sum of all loss and regularization terms.

        Each term is ``mean(manual_w * fn(true, pred))``. The per-term weights
        and values are also registered as metrics named
        ``LOSS_WEIGHT_<i>_<fn>``, ``LOSS_VALUE_<i>_<fn>``,
        ``REG_WEIGHT_<i>_<fn>`` and ``REG_VALUE_<i>_<fn>``.
        """
        # Rebuild the weights from the variables in the graph currently being
        # traced, rather than reusing tensors created once in build().
        self._update_weights()

        loss_values = list()
        loss_names_loss = list()
        for y_true, y_pred, loss_fn, man_w in zip(ys_true, ys_pred, self.loss_fns, self.manual_loss_w):
            loss_values.append(K.mean(man_w * loss_fn(y_true, y_pred)))
            loss_names_loss.append(loss_fn.__name__)
        loss_values = tf.convert_to_tensor(loss_values, dtype=tf.float32, name="step_loss_values")
        loss = tf.math.multiply(self.loss_weights, loss_values, name='step_weighted_loss')

        reg_values = None
        loss_names_reg = list()
        if self.num_reg != 0:
            loss_reg = list()
            for reg_true, reg_pred, reg_fn, man_w in zip(regs_true, regs_pred, self.reg_fns, self.manual_reg_w):
                loss_reg.append(K.mean(man_w * reg_fn(reg_true, reg_pred)))
                loss_names_reg.append(reg_fn.__name__)
            reg_values = tf.convert_to_tensor(loss_reg, dtype=tf.float32, name="step_reg_values")
            loss = loss + tf.math.multiply(self.reg_weights, reg_values, name='step_weighted_reg')

        for i, loss_name in enumerate(loss_names_loss):
            self.add_metric(tf.slice(self.loss_weights, [i], [1]), name='LOSS_WEIGHT_{}_{}'.format(i, loss_name),
                            aggregation='mean')
            self.add_metric(tf.slice(loss_values, [i], [1]), name='LOSS_VALUE_{}_{}'.format(i, loss_name),
                            aggregation='mean')
        if self.num_reg != 0:
            for i, loss_name in enumerate(loss_names_reg):
                self.add_metric(tf.slice(self.reg_weights, [i], [1]), name='REG_WEIGHT_{}_{}'.format(i, loss_name),
                                aggregation='mean')
                self.add_metric(tf.slice(reg_values, [i], [1]), name='REG_VALUE_{}_{}'.format(i, loss_name),
                                aggregation='mean')
        return K.sum(loss)

    def call(self, inputs):
        """Register the combined loss and return a placeholder output.

        ``inputs`` is a list laid out as ``num_loss`` ground truths, then
        ``num_loss`` predictions, and, as its last ``2 * num_reg`` entries,
        ``num_reg`` regularization targets followed by ``num_reg``
        regularization predictions.
        """
        ys_true = inputs[:self.num_loss]
        ys_pred = inputs[self.num_loss:self.num_loss * 2]
        if self.num_reg != 0:
            reg_true = inputs[-self.num_reg * 2:-self.num_reg]
            reg_pred = inputs[-self.num_reg:]
        else:
            # With num_reg == 0, inputs[-0:] would be the whole list, not an empty one.
            reg_true, reg_pred = list(), list()
        loss = self.multi_loss(ys_true, ys_pred, reg_true, reg_pred)
        self.add_loss(loss, inputs=inputs)
        # The output is not used, but Keras needs a tensor to connect the graph.
        return K.concatenate(inputs, -1)

    def get_config(self):
        """Serialise the number of loss and regularization terms on top of the base config."""
        base_config = super(UncertaintyWeighting, self).get_config()
        base_config['num_loss_fns'] = self.num_loss
        base_config['num_reg_fns'] = self.num_reg
        return base_config
def distance_map(coord1, coord2, dist, img_shape_w_channel=(64, 64, 1)):
    """Build a pair of normalised distance maps for one projection of a volume.

    For each voxel at pixel ``(coord1[n], coord2[n])`` of the projection plane,
    with depth ``dist[n]`` along the projected axis:

    - ``dm_p`` keeps the smallest depth seen at that pixel.
    - ``dm_n`` keeps the smallest ``max_dist - dist``, i.e. the depth measured
      from the opposite side.

    Pixels without voxels keep ``max_dist``. Only channel 0 is written.

    Args:
        coord1: Integer pixel coordinates along the first axis of the plane.
        coord2: Integer pixel coordinates along the second axis of the plane.
        dist: Depth of each voxel along the projected axis.
        img_shape_w_channel: Shape ``(H, W, C)`` of the output maps. Its
            largest entry is used as ``max_dist``.

    Returns:
        ``(dm_p / max_dist, dm_n / max_dist)`` as float32 arrays of shape
        ``img_shape_w_channel``.
    """
    # A Python float keeps the float32 maps float32 under any NumPy promotion rules.
    max_dist = float(np.max(img_shape_w_channel))
    coord1 = np.asarray(coord1, dtype=np.intp)
    coord2 = np.asarray(coord2, dtype=np.intp)
    dist = np.asarray(dist, dtype=np.float32)

    dm_p = np.full(img_shape_w_channel, max_dist, dtype=np.float32)
    dm_n = np.full(img_shape_w_channel, max_dist, dtype=np.float32)
    # Unbuffered in-place minimum: repeated pixels keep their smallest value.
    np.minimum.at(dm_p[..., 0], (coord1, coord2), dist)
    # Depth from the far side. NOTE(review): assumes the projected axis has length
    # max_dist, which holds for cubic volumes such as the 64^3 default.
    np.minimum.at(dm_n[..., 0], (coord1, coord2), max_dist - dist)
    return dm_p / max_dist, dm_n / max_dist
def volume_to_ov_and_dm(in_volume: tf.Tensor):
    """Compute orthogonal-view (OV) projections and distance maps of a volume.

    Meant to run as a preprocessing step.

    Args:
        in_volume: A single volume with four axes, presumably ``(I, J, K, C)``
            (TODO confirm), or a batch of such volumes with one extra leading
            axis. A batch is processed element-wise with ``tf.map_fn``.

    Returns:
        ``([front, right, top], [front_p, front_n, top_p, top_n, right_p, right_n])``.
        ``top``, ``right`` and ``front`` are ``tf.sign`` of the volume summed
        along axes 0, 1 and 2. Each ``*_p``/``*_n`` pair is the output of
        ``distance_map`` for the voxels > 0. It uses the two remaining voxel
        coordinates as the pixel position and the coordinate along the summed
        axis as the distance. The maps have the projection's shape.
    """
    def _np_distance_maps(coord1, coord2, dist, shape):
        # tf.py_func requires exactly the dtypes declared in Tout, so cast explicitly.
        dm_p, dm_n = distance_map(coord1, coord2, dist, shape)
        return np.asarray(dm_p, dtype=np.float32), np.asarray(dm_n, dtype=np.float32)

    def _distance_maps(coord1, coord2, dist, projection):
        # distance_map returns two arrays, so Tout lists two dtypes. The projection's
        # shape is passed so the maps match it rather than a fixed (64, 64, 1).
        dm_p, dm_n = tf.py_func(_np_distance_maps, [coord1, coord2, dist, tf.shape(projection)],
                                [tf.float32, tf.float32])
        # py_func drops static shape information; restore it from the projection.
        dm_p.set_shape(projection.shape)
        dm_n.set_shape(projection.shape)
        return dm_p, dm_n

    def get_ov_projections_and_dm(volume):
        # tf.where returns an (N, 4) tensor of voxel coordinates; split it column-wise
        # (plain tuple-unpacking would iterate over the N rows instead).
        i, j, k, _ = tf.unstack(tf.where(volume > 0.0), num=4, axis=1)
        # tf.sign returns -1, 0, 1 depending on the sign of the elements of the input (negative, zero, positive)
        top = tf.sign(tf.reduce_sum(volume, axis=0), name='ov_top')
        right = tf.sign(tf.reduce_sum(volume, axis=1), name='ov_right')
        front = tf.sign(tf.reduce_sum(volume, axis=2), name='ov_front')
        top_p, top_n = _distance_maps(j, k, i, top)
        right_p, right_n = _distance_maps(i, k, j, right)
        front_p, front_n = _distance_maps(i, j, k, front)
        return [front, right, top], [front_p, front_n, top_p, top_n, right_p, right_n]

    if len(in_volume.shape.as_list()) > 4:
        # map_fn's dtype must mirror the nested (list, list) structure returned above.
        return tf.map_fn(get_ov_projections_and_dm, in_volume,
                         dtype=([tf.float32] * 3, [tf.float32] * 6))
    return get_ov_projections_and_dm(in_volume)
def ov_and_dm_to_volume(ov_projections):
    """Rebuild a volume from its three orthogonal projections.

    Args:
        ov_projections: ``[front, right, top]``. This matches the first output
            of ``volume_to_ov_and_dm``, where front, right and top are the
            projections along axes 2, 1 and 0. For an ``(I, J, K, C)`` volume
            (assumed layout, TODO confirm) their shapes are ``(I, J, C)``,
            ``(I, K, C)`` and ``(J, K, C)``. Inputs with an extra leading batch
            axis are processed element-wise with ``tf.map_fn``.

    Returns:
        ``soft_threshold(front + right + top, 2., 'get_volume')``, where each
        projection is broadcast back over the axis it was summed along.
        NOTE(review): the exact thresholding semantics depend on
        ``ddmr.utils.operators.soft_threshold``.
    """
    front, right, top = ov_projections

    def get_volume(front: tf.Tensor, right: tf.Tensor, top: tf.Tensor):
        # Re-insert each projection's summed axis and let broadcasting expand it:
        # front (I, J, C) -> (I, J, 1, C), right (I, K, C) -> (I, 1, K, C),
        # top (J, K, C) -> (1, J, K, C). Every extent comes from the projections
        # themselves, so non-cubic volumes line up correctly.
        front_vol = tf.expand_dims(front, 2)
        right_vol = tf.expand_dims(right, 1)
        top_vol = tf.expand_dims(top, 0)
        summed = tf.add(tf.add(front_vol, right_vol), top_vol)
        return soft_threshold(summed, 2., 'get_volume')

    if len(front.shape.as_list()) > 3:
        # Batched input: rebuild each volume independently.
        return tf.map_fn(lambda x: get_volume(x[0], x[1], x[2]), ov_projections, dtype=tf.float32)
    return get_volume(front, right, top)
# TODO: Recover the voxel coordinates from the distance maps to prevent artifacts.
# Open question: will gradients be backpropagated through that recovery step?