# Snapshot header from the hosting page (gregorkrzmanc, commit e75a247, 9.08 kB)
import tensorflow as tf
class Basic_OC_per_sample(object):
def __init__(self,
q_min,
s_b,
use_mean_x,
spect_supp=None, #None means same as noise
global_weight=False
):
self.q_min = q_min
self.s_b = s_b
self.use_mean_x = use_mean_x
self.global_weight = global_weight
if spect_supp is None:
spect_supp = s_b
self.spect_supp = spect_supp
self.valid=False #constants not created
    # helper: (re)build the object selection matrices for a truth assignment
    def create_Ms(self, truth_idx):
        """Create and store the per-object selection matrices Msel and Mnot.

        CreateMidx is a project helper defined elsewhere; it returns a
        triple whose third element is discarded here. Presumably Msel
        gathers the vertices belonging to each object and Mnot masks the
        vertices NOT belonging to it (calc_m_not=True) — confirm against
        the helper's definition.
        """
        self.Msel, self.Mnot, _ = CreateMidx(truth_idx, calc_m_not=True)
    def set_input(self,
                  beta,
                  x,
                  d,
                  pll,
                  truth_idx,
                  object_weight,
                  is_spectator_weight,
                  calc_Ms=True,
                  ):
        """
        Cache all per-vertex (V) and per-object (K) constants for one sample.

        Parameters (per-vertex tensors with V rows; exact shapes are not
        visible here — TODO confirm against the caller):
          beta: condensation strength per vertex
          x: clustering-space coordinates (V x C)
          d: per-vertex distance scale
          pll: per-vertex payload loss (V x P)
          truth_idx: truth object index; negative marks noise
          object_weight: per-vertex object weight
          is_spectator_weight: spectator weight, clipped into [0, 1] below
          calc_Ms: if False, reuse Msel/Mnot from a previous create_Ms call

        Sets self.valid = False and returns early when no objects exist.
        """
        self.valid=True
        #used for pll and q
        # beta/1.01 keeps atanh finite at beta == 1 (assumes beta <= 1; confirm)
        self.tanhsqbeta = tf.math.atanh(beta/(1.01))**2
        # check_numerics guards against NaN/Inf entering the loss terms
        self.beta_v = tf.debugging.check_numerics(beta,"OC: beta input")
        self.d_v = tf.debugging.check_numerics(d,"OC: d input")
        self.x_v = tf.debugging.check_numerics(x,"OC: x input")
        self.pll_v = tf.debugging.check_numerics(pll,"OC: pll input")
        self.sw_v = tf.debugging.check_numerics(is_spectator_weight,"OC: is_spectator_weight input")
        object_weight = tf.debugging.check_numerics(object_weight,"OC: object_weight input")
        # noise indicator: 1. where truth_idx < 0, else 0.
        self.isn_v = tf.where(truth_idx<0, tf.zeros_like(truth_idx,dtype='float32')+1., 0.)
        #spectators do not participate in the potential losses
        self.q_v = (self.tanhsqbeta + self.q_min)*tf.clip_by_value(1.-is_spectator_weight, 0., 1.)
        if calc_Ms:
            self.create_Ms(truth_idx)
        # NOTE(review): indentation reconstructed — this check is assumed to
        # sit at method level so it also guards Msel when calc_Ms is False
        if self.Msel is None:
            self.valid=False
            return
        #if self.Msel.shape[0] < 2:#less than two objects - can be dangerous
        #    self.valid=False
        #    return
        self.mask_k_m = SelectWithDefault(self.Msel, tf.zeros_like(beta)+1., 0.) #K x V-obj x 1
        self.beta_k_m = SelectWithDefault(self.Msel, self.beta_v, 0.) #K x V-obj x 1
        self.x_k_m = SelectWithDefault(self.Msel, self.x_v, 0.) #K x V-obj x C
        self.q_k_m = SelectWithDefault(self.Msel, self.q_v, 0.)#K x V-obj x 1
        self.d_k_m = SelectWithDefault(self.Msel, self.d_v, 0.)
        # condensation point alpha: the highest-q vertex of each object
        self.alpha_k = tf.argmax(self.q_k_m, axis=1)# high beta and not spectator -> large q
        self.beta_k = tf.gather_nd(self.beta_k_m, self.alpha_k, batch_dims=1) # K x 1
        # _create_x_alpha_k is not visible in this chunk; presumably it uses
        # self.use_mean_x to pick between alpha coordinates and a mean — confirm
        self.x_k = self._create_x_alpha_k() #K x C
        self.q_k = tf.gather_nd(self.q_k_m, self.alpha_k, batch_dims=1) # K x 1
        self.d_k = tf.gather_nd(self.d_k_m, self.alpha_k, batch_dims=1) # K x 1
        #just a temp
        ow_k_m = SelectWithDefault(self.Msel, object_weight, 0.)
        self.ow_k = tf.gather_nd(ow_k_m, self.alpha_k, batch_dims=1) # K x 1
### the following functions should not modify any of the constants and must only depend on them
#for override through inheriting
def att_func(self,dsq_k_m):
return tf.math.log(tf.math.exp(1.)*dsq_k_m/2. + 1.)
def V_att_k(self):
'''
'''
K = tf.reduce_sum(tf.ones_like(self.q_k))
N_k = tf.reduce_sum(self.mask_k_m, axis=1)
dsq_k_m = self.calc_dsq_att() #K x V-obj x 1
sigma = self.weighted_d_k_m(dsq_k_m) #create gradients for all
dsq_k_m = tf.math.divide_no_nan(dsq_k_m, sigma + 1e-4)
V_att = self.att_func(dsq_k_m) * self.q_k_m * self.mask_k_m #K x V-obj x 1
V_att = self.q_k * tf.reduce_sum( V_att ,axis=1) #K x 1
# if self.global_weight:
# N_full = tf.reduce_sum(tf.ones_like(self.beta_v))
# V_att = K * tf.math.divide_no_nan(V_att, N_full+1e-3) #K x 1
# else:
V_att = tf.math.divide_no_nan(V_att, N_k+1e-3) #K x 1
#print(tf.reduce_mean(self.d_v),tf.reduce_max(self.d_v))
return V_att
def rep_func(self,dsq_k_v):
return tf.math.exp(-dsq_k_v/2.)
    def weighted_d_k_m(self, dsq):  # dsq: K x V x 1
        """Distance scale used to normalise squared distances.

        The base class returns the condensation-point scale d_k broadcast
        to K x 1 x 1; the dsq argument is unused here but kept so
        subclasses can implement distance-weighted scales.
        """
        return tf.expand_dims(self.d_k, axis=1) # K x 1 x 1
def calc_dsq_att(self):
x_k_e = tf.expand_dims(self.x_k,axis=1)
dsq_k_m = tf.reduce_sum((self.x_k_m - x_k_e)**2, axis=-1, keepdims=True) #K x V-obj x 1
return dsq_k_m
def calc_dsq_rep(self):
dsq = tf.expand_dims(self.x_k, axis=1) - tf.expand_dims(self.x_v, axis=0) #K x V x C
dsq = tf.reduce_sum(dsq**2, axis=-1, keepdims=True) #K x V x 1
return dsq
def V_rep_k(self):
K = tf.reduce_sum(tf.ones_like(self.q_k))
N_notk = tf.reduce_sum(self.Mnot, axis=1)
#future remark: if this gets too large, one could use a kNN here
dsq = self.calc_dsq_rep()
# nogradbeta = tf.stop_gradient(self.beta_k_m)
#weight. tf.reduce_sum( tf.exp(-dsq) * d_v_e, , axis=1) / tf.reduce_sum( tf.exp(-dsq) )
sigma = self.weighted_d_k_m(dsq) #create gradients for all, but prefer k vertex
dsq = tf.math.divide_no_nan(dsq, sigma + 1e-4) #K x V x 1
V_rep = self.rep_func(dsq) * self.Mnot * tf.expand_dims(self.q_v,axis=0) #K x V x 1
V_rep = self.q_k * tf.reduce_sum(V_rep, axis=1) #K x 1
if self.global_weight:
N_full = tf.reduce_sum(tf.ones_like(self.beta_v))
V_rep = K * tf.math.divide_no_nan(V_rep, N_full+1e-3) #K x 1
else:
V_rep = tf.math.divide_no_nan(V_rep, N_notk+1e-3) #K x 1
return V_rep
def Pll_k(self):
tanhsqbeta = self.beta_v**2 #softer here
tanhsqbeta = tf.debugging.check_numerics(tanhsqbeta, "OC: pw b**2")
pw = tanhsqbeta * tf.clip_by_value((1.-tf.clip_by_value(self.isn_v+self.sw_v,0.,1.)),0.,1.) + 1e-6
pw = tf.debugging.check_numerics(pw, "OC: pw")
pll_k_m = SelectWithDefault(self.Msel, self.pll_v, 0.) #K x V_perobj x P
pw_k_m = SelectWithDefault(self.Msel, pw, 0.) #K x V-obj x P
pw_k_sum = tf.reduce_sum(pw_k_m, axis=1)
pw_k_sum = tf.where(pw_k_sum <= 0., 1e-2, pw_k_sum)
pll_k = tf.math.divide_no_nan(tf.reduce_sum(pll_k_m * pw_k_m, axis=1),
pw_k_sum )#K x P
return pll_k
def Beta_pen_k(self):
#use continuous max approximation through LSE
eps = 1e-3
beta_pen = 1. - eps * tf.reduce_logsumexp(self.beta_k_m/eps, axis=1)#sum over m
#for faster convergence
beta_pen += 1. - tf.clip_by_value(tf.reduce_sum(self.beta_k_m, axis=1), 0., 1)
beta_pen = tf.debugging.check_numerics(beta_pen, "OC: beta pen")
return beta_pen
def Noise_pen(self):
nsupp_v = self.beta_v * self.isn_v
nsupp = tf.math.divide_no_nan(tf.reduce_sum(nsupp_v),
tf.reduce_sum(self.isn_v)+1e-3) # nodim
specsupp_v = self.beta_v * self.sw_v
specsupp = tf.math.divide_no_nan(tf.reduce_sum(specsupp_v),
tf.reduce_sum(self.sw_v)+1e-3) # nodim
return self.s_b * nsupp + self.spect_supp * specsupp
    # doesn't do anything in this implementation; hook for subclasses
    def high_B_pen_k(self):
        """Penalty on high beta values; identically zero here, K x 1.

        Written as 0.*beta_k (not a constant) so the result keeps beta_k's
        shape/dtype and stays connected to the graph.
        """
        return 0.* self.beta_k
    # override with a more complex weighting through inheritance
    def pll_weight_k(self, ow_k, vatt_k, vrep_k):
        """Per-object weight for the payload loss; base class returns ow_k.

        vatt_k / vrep_k are accepted (and ignored here) so subclasses can
        weight the payload by the attractive/repulsive potentials.
        """
        return ow_k
def add_to_terms(self,
V_att,
V_rep,
Noise_pen,
B_pen,
pll,
high_B_pen
):
zero_tensor = tf.zeros_like(tf.reduce_mean(self.q_v,axis=0))
if not self.valid: # no objects
zero_payload = tf.zeros_like(tf.reduce_mean(self.pll_v,axis=0))
print('WARNING: no objects in sample, continue to next')
return zero_tensor, zero_tensor, zero_tensor, zero_tensor, zero_payload, zero_tensor
K = tf.reduce_sum(tf.ones_like(self.q_k)) # > 0
V_att_k = self.V_att_k()
V_rep_k = self.V_rep_k()
V_att += tf.reduce_sum(self.ow_k * V_att_k)/K
V_rep += tf.reduce_sum(self.ow_k * V_rep_k)/K
Noise_pen += self.Noise_pen()
B_pen += tf.reduce_sum(self.ow_k * self.Beta_pen_k())/K
pl_ow_k = self.pll_weight_k(self.ow_k, V_att_k, V_rep_k)
pll += tf.reduce_sum(pl_ow_k * self.Pll_k(),axis=0)/K
high_B_pen += tf.reduce_sum(self.ow_k *self.high_B_pen_k())/K
return V_att, V_rep, Noise_pen, B_pen, pll, high_B_pen