File size: 6,079 Bytes
74c6a32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import numpy as np
import tensorflow as tf
class ThinPlateSplines:
    """Thin-plate-spline (TPS) mapping fitted from control points to target points.

    The coefficients solve the TPS linear system

        [[K + reg * alpha^2 * I, P], [P^T, 0]] @ [w; a] = [target_pts; 0]

    where K[i, j] = U(dist(c_i, c_j)), P = [1, ctrl_pts], w are the N non-affine
    weights and a the d + 1 affine parameters. U(r) = r^2 * log(r^2) for 2-D
    points and U(r) = r otherwise.
    """

    def __init__(self, ctrl_pts: tf.Tensor, target_pts: tf.Tensor, reg=0.0):
        """
        :param ctrl_pts: [N, d] tensor of control d-dimensional points
        :param target_pts: [N, d] tensor of target d-dimensional points
        :param reg: regularization coefficient; reg * alpha^2 is added to the
            diagonal of K, where alpha is the mean of the entries of K.
            0.0 gives exact interpolation of the control points.
        """
        self.__ctrl_pts = ctrl_pts
        self.__target_pts = target_pts
        self.__reg = reg
        self.__num_ctrl_pts = ctrl_pts.shape[0]
        self.__dim = ctrl_pts.shape[1]
        self.__compute_coeffs()
        # Coefficient rows [0, N) are the non-affine weights w; the remaining
        # d + 1 rows are the affine parameters of the TPS.
        self.__non_aff_paramms = self.__coeffs[:self.__num_ctrl_pts, ...]

    def __compute_coeffs(self):
        """Solve the TPS system and store the [N + d + 1, d] float32 coefficients."""
        target_pts_aug = tf.concat(
            [self.__target_pts,
             tf.zeros([self.__dim + 1, self.__dim], dtype=self.__target_pts.dtype)],
            axis=0)
        T = tf.cast(self.__make_T(), target_pts_aug.dtype)
        # Solving the system directly is better conditioned than forming inv(T).
        self.__coeffs = tf.cast(tf.linalg.solve(T, target_pts_aug), tf.float32)

    def __make_T(self):
        """Build the (N + d + 1) x (N + d + 1) float32 system matrix [[K + R, P], [P^T, 0]].

        Also stores the unregularized kernel matrix K in ``self.__K`` for
        ``bending_energy``.
        """
        # K is always float32 (see the distance helpers), so P must be float32 too.
        ctrl_pts = tf.cast(self.__ctrl_pts, tf.float32)
        P = tf.concat([tf.ones([self.__num_ctrl_pts, 1], dtype=tf.float32), ctrl_pts], axis=1)
        zeros = tf.zeros([self.__dim + 1, self.__dim + 1], dtype=tf.float32)

        K = self.__U_dist(self.__ctrl_pts)
        self.__K = K
        alfa = tf.reduce_mean(K)
        # Regularize on the diagonal. Adding a constant to *every* entry of K is
        # a no-op: the constraint P^T w = 0 forces sum(w) = 0, so (K + c*1*1^T) w == K w.
        K_reg = K + tf.eye(tf.shape(K)[0], dtype=K.dtype) * tf.pow(alfa, 2) * self.__reg

        top = tf.concat([K_reg, P], axis=1)
        bottom = tf.concat([tf.transpose(P), zeros], axis=1)
        return tf.concat([top, bottom], axis=0)

    def __U_dist(self, ctrl_pts, int_pts=None):
        """Evaluate the TPS radial kernel U on pairwise distances.

        :param ctrl_pts: [M, d] points indexing the rows of the result
        :param int_pts: optional [N, d] points indexing the columns; when None,
            distances are taken within ``ctrl_pts`` itself
        :return: [M, N] float32 tensor of U values
        """
        if int_pts is None:
            dist = self.__pairwise_distance_equal(ctrl_pts)  # Already squared!
        else:
            dist = self.__pairwise_distance_different(ctrl_pts, int_pts)  # Already squared!
        # Round-off in |a|^2 - 2ab + |b|^2 can yield tiny negatives -> NaN in sqrt/log.
        dist = tf.maximum(dist, 0.0)
        if ctrl_pts.shape[-1] == 2:
            # U(r) = r^2 * log(r^2); the epsilon keeps log finite at r = 0.
            return dist * tf.math.log(dist + 1e-6)
        # Src: https://github.com/vaipatel/morphops/blob/master/morphops/tps.py
        # For d != 2 the kernel is U(r) = r.
        return tf.sqrt(dist)

    def __pairwise_distance_sq(self, pts_a, pts_b):
        """Squared pairwise distances between two point sets (not used internally).

        Dispatches to the same-set formulation when the inputs compare equal.
        """
        if np.all(pts_a == pts_b):
            return self.__pairwise_distance_equal(pts_a)
        return self.__pairwise_distance_different(pts_a, pts_b)

    @staticmethod
    def __pairwise_distance_equal(pts):
        """Squared pairwise distances within one [N, d] point set -> [N, N] float32."""
        # Broadcast [N, 1, d] - [1, N, d]; works for any dimension d.
        diff = tf.expand_dims(pts, 1) - tf.expand_dims(pts, 0)
        return tf.cast(tf.reduce_sum(tf.square(diff), 2), tf.float32)

    @staticmethod
    def __pairwise_distance_different(pts_a, pts_b):
        """Squared pairwise distances [M, d] x [N, d] -> [M, N] float32.

        Uses the expansion |a - b|^2 = |a|^2 - 2 a.b + |b|^2.
        """
        sq_a = tf.expand_dims(tf.reduce_sum(tf.square(pts_a), 1), 1)
        sq_b = tf.expand_dims(tf.reduce_sum(tf.square(pts_b), 1), 0)
        cross = tf.matmul(pts_a, pts_b, adjoint_b=True)
        return tf.cast(sq_a - 2 * cross + sq_b, tf.float32)

    def __lift_pts(self, int_pts: tf.Tensor, num_pts):
        """Lift [M, d] points to [M, N + d + 1] rows [U(x, c_1..c_N), 1, x].

        The column layout matches the row layout of the coefficient matrix.
        """
        return tf.concat([self.__U_dist(int_pts, self.__ctrl_pts),
                          tf.ones([num_pts, 1], dtype=tf.float32),
                          tf.cast(int_pts, tf.float32)], axis=1)

    @property
    def bending_energy(self):
        """Return the [d, d] matrix w^T K w of the non-affine weights w."""
        aux = tf.matmul(self.__non_aff_paramms, self.__K, transpose_a=True)
        return tf.matmul(aux, self.__non_aff_paramms)

    def interpolate(self, int_points):
        """
        :param int_points: [K, d] flattened d-points of a mesh
        :return: [K, d] float32 tensor of the points mapped by the TPS
        """
        num_pts = tf.shape(int_points)[0]
        int_points_lift = self.__lift_pts(int_points, num_pts)
        return tf.matmul(int_points_lift, self.__coeffs)

    def __call__(self, int_points, num_pts=None, **kwargs):
        """Alias for ``interpolate``; ``num_pts`` and ``kwargs`` are accepted but ignored."""
        return self.interpolate(int_points)
def thin_plate_splines_batch(ctrl_pts: tf.Tensor, target_pts: tf.Tensor, int_pts: tf.Tensor, reg=0.0):
    """Fit and evaluate one thin-plate spline per batch element.

    :param ctrl_pts: [B, N, d] control points
    :param target_pts: [B, N, d] target points
    :param int_pts: [B, M, d] points to warp with each element's TPS
    :param reg: regularization coefficient forwarded to ThinPlateSplines
    :return: [B, M, d] float32 tensor of interpolated points
    """
    def tps_sample(in_data):
        # Each mapped element is one (ctrl, target, int) triple for a batch index.
        cp, tp, ip = in_data
        return ThinPlateSplines(cp, tp, reg).interpolate(ip)

    return tf.map_fn(tps_sample, elems=(ctrl_pts, target_pts, int_pts), dtype=tf.float32)
|