File size: 6,079 Bytes
74c6a32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import numpy as np
import tensorflow as tf


class ThinPlateSplines:
    """Thin-plate-spline (TPS) interpolator fitted to one set of control points.

    On construction the TPS linear system is assembled and solved for the
    coefficients that map ``ctrl_pts`` onto ``target_pts``. Afterwards
    :meth:`interpolate` applies the fitted mapping to arbitrary points.

    The radial kernel depends on the point dimension ``d``:
    ``U(r) = r^2 * log(r^2)`` when ``d == 2`` and ``U(r) = r`` otherwise.
    """

    def __init__(self, ctrl_pts: tf.Tensor, target_pts: tf.Tensor, reg=0.0):
        """

        :param ctrl_pts: [N, d] tensor of control d-dimensional points
        :param target_pts: [N, d] tensor of target d-dimensional points
        :param reg: regularization coefficient; scales the identity term added to the kernel matrix
        """
        self.__ctrl_pts = ctrl_pts
        self.__target_pts = target_pts
        self.__reg = reg
        self.__num_ctrl_pts = ctrl_pts.shape[0]
        self.__dim = ctrl_pts.shape[1]

        self.__compute_coeffs()
        # The first N rows of the solution are the non-affine (kernel) weights;
        # the remaining d + 1 rows are the affine parameters.
        self.__non_aff_params = self.__coeffs[:self.__num_ctrl_pts, ...]

    def __compute_coeffs(self):
        """Solve the TPS system and store the [N + d + 1, d] coefficients as float32."""
        # Right-hand side: the targets followed by d + 1 zero rows that match
        # the side-condition rows (P^T w = 0) of the system matrix.
        target_pts_aug = tf.concat([self.__target_pts,
                                    tf.zeros([self.__dim + 1, self.__dim], dtype=self.__target_pts.dtype)],
                                   axis=0)

        T_i = tf.cast(tf.linalg.inv(self.__make_T()), target_pts_aug.dtype)
        self.__coeffs = tf.cast(tf.matmul(T_i, target_pts_aug), tf.float32)

    def __make_T(self):
        """Assemble the [(N + d + 1) x (N + d + 1)] TPS system matrix.

        Layout::

            | K + reg * alfa^2 * I    P |
            | P^T                     0 |

        where ``K`` is the kernel matrix of the control points, ``P = [1, ctrl_pts]``
        and ``alfa`` is the mean of the entries of ``K``.
        """
        P = tf.concat([tf.ones([self.__num_ctrl_pts, 1], dtype=tf.float32), self.__ctrl_pts], axis=1)
        # Built with tf.zeros (as in __compute_coeffs) instead of the removed
        # ``np.float`` alias, which raises AttributeError on NumPy >= 1.24.
        zeros = tf.zeros([self.__dim + 1, self.__dim + 1], dtype=tf.float32)

        # The unregularized kernel is kept for the bending energy.
        self.__K = self.__U_dist(self.__ctrl_pts)
        alfa = tf.reduce_mean(self.__K)

        # The regularizer must sit on the diagonal. A constant matrix
        # (c * ones) has no effect: the bottom rows enforce P^T w = 0, hence
        # 1^T w = 0, so (K + c * 1 1^T) w == K w and `reg` would be ignored.
        # tf.shape is used for the size so this also works when the static
        # shape entry is not a plain Python int.
        identity = tf.eye(tf.shape(self.__K)[0], dtype=self.__K.dtype)
        K_reg = self.__K + self.__reg * tf.pow(alfa, 2) * identity

        top = tf.concat([K_reg, P], axis=1)
        bottom = tf.concat([tf.transpose(P), zeros], axis=1)
        return tf.concat([top, bottom], axis=0)

    def __U_dist(self, ctrl_pts, int_pts=None):
        """Evaluate the TPS radial kernel on pairwise distances.

        :param ctrl_pts: [A, d] points indexing the rows of the result
        :param int_pts: optional [B, d] points indexing the columns; when omitted the kernel is evaluated
                        between ``ctrl_pts`` and itself
        :return: [A, B] (or [A, A]) float32 kernel matrix
        """
        if int_pts is None:
            dist = self.__pairwise_distance_equal(ctrl_pts)  # Already squared!
        else:
            dist = self.__pairwise_distance_different(ctrl_pts, int_pts)  # Already squared!

        if ctrl_pts.shape[-1] == 2:
            # U(r) = r^2 * log(r^2); the epsilon keeps log() finite at r = 0.
            u_dist = dist * tf.math.log(dist + 1e-6)
        else:
            # Src: https://github.com/vaipatel/morphops/blob/master/morphops/tps.py
            # In particular, if k = 2, then  U(r) = r^2 * log(r^2), else U(r) = r
            # NOTE(review): the derivative of sqrt is unbounded at 0, so gradients
            # through zero distances (e.g. the diagonal) may be non-finite - confirm
            # whether this path is differentiated.
            u_dist = tf.sqrt(dist)

        return u_dist

    def __pairwise_distance_sq(self, pts_a, pts_b):
        """Squared pairwise distances between ``pts_a`` [A, d] and ``pts_b`` [B, d] as float32.

        Uses the direct-difference form when both arguments are the same tensor
        object and the expanded form otherwise.
        """
        with tf.name_scope('pairwise_distance'):
            if pts_a is pts_b:
                return self.__pairwise_distance_equal(pts_a)
            return self.__pairwise_distance_different(pts_a, pts_b)

    @staticmethod
    def __pairwise_distance_equal(pts):
        """Squared pairwise distances within a single [N, d] point set, as [N, N] float32.

        Differences are formed directly, which gives exact zeros on the diagonal.
        """
        # expand_dims instead of reshape(..., 3) so any point dimension d works.
        diff = tf.expand_dims(pts, 1) - tf.expand_dims(pts, 0)  # [N, N, d]
        dist = tf.reduce_sum(tf.square(diff), 2)  # squared pairwise distance
        return tf.cast(dist, tf.float32)

    @staticmethod
    def __pairwise_distance_different(pts_a, pts_b):
        """Squared pairwise distances between [A, d] and [B, d] point sets, as [A, B] float32."""
        # PwD^2 = A_norm^2 - 2*A*B' + B_norm^2
        sq_norm_a = tf.expand_dims(tf.reduce_sum(tf.square(pts_a), 1), 1)
        sq_norm_b = tf.expand_dims(tf.reduce_sum(tf.square(pts_b), 1), 0)
        cross = tf.matmul(pts_a, pts_b, adjoint_b=True)

        dist = sq_norm_a - 2 * cross + sq_norm_b
        # Cancellation can yield tiny negative values for (near-)coincident
        # points; clamp so the later sqrt()/log() never produce NaN.
        dist = tf.maximum(dist, 0.0)
        return tf.cast(dist, tf.float32)

    def __lift_pts(self, int_pts: tf.Tensor, num_pts):
        """Lift input points into the TPS basis.

        :param int_pts: [M, d] input points
        :param num_pts: number of input points M
        :return: [M, N + 1 + d] tensor ``[U(int_pts, ctrl_pts), 1, int_pts]``, whose column order matches the row
                 order of the coefficient matrix
        """
        int_pts_lift = tf.concat([self.__U_dist(int_pts, self.__ctrl_pts),
                                  tf.ones([num_pts, 1], dtype=tf.float32),
                                  int_pts], axis=1)
        return int_pts_lift

    @property
    def bending_energy(self):
        """[d, d] matrix ``W^T K W``, with ``W`` the non-affine weights and ``K`` the control-point kernel."""
        aux = tf.matmul(self.__non_aff_params, self.__K, transpose_a=True)
        return tf.matmul(aux, self.__non_aff_params)

    def interpolate(self, int_points):
        """
        Apply the fitted spline to a set of points.

        :param int_points: [K, d] flattened d-points of a mesh
        :return: [K, d] tensor with the interpolated points
        """
        num_pts = tf.shape(int_points)[0]
        int_points_lift = self.__lift_pts(int_points, num_pts)

        return tf.matmul(int_points_lift, self.__coeffs)

    def __call__(self, int_points, num_pts=None, **kwargs):
        """Alias of :meth:`interpolate`.

        ``num_pts`` and any extra keyword arguments are accepted for backward
        compatibility and ignored; the point count is read from ``int_points``.
        """
        return self.interpolate(int_points)


def thin_plate_splines_batch(ctrl_pts: tf.Tensor, target_pts: tf.Tensor, int_pts: tf.Tensor, reg=0.0):
    """
    Fit and evaluate an independent thin-plate spline for every sample of a batch.

    For each index along the leading axis a ``ThinPlateSplines`` is built from
    the matching control/target points and evaluated on the matching
    interpolation points.

    :param ctrl_pts: batched control points; the leading axis is the batch
    :param target_pts: batched target points; the leading axis is the batch
    :param int_pts: batched points to interpolate; the leading axis is the batch
    :param reg: regularization coefficient forwarded to every spline
    :return: float32 tensor stacking the interpolated points of every sample
    """
    # The former debug print was dropped: it relied on tf.get_default_session,
    # which is only available in the TF1 API, and wrote to stdout from library code.

    def tps_sample(in_data):
        cp, tp, ip = in_data
        return ThinPlateSplines(cp, tp, reg).interpolate(ip)

    return tf.map_fn(tps_sample, elems=(ctrl_pts, target_pts, int_pts), dtype=tf.float32)