Update networks.py

networks.py CHANGED (+47 -35)

The updated region of TpsGridGen is shown below; lines added by this commit are marked "+", unchanged context lines are unmarked.

@@ -126,29 +126,33 @@ class TpsGridGen(nn.Module):
         self.out_h = out_h
         self.out_w = out_w
         self.grid_size = grid_size
         self.N = grid_size * grid_size
+
+        # Create regular grid of control points
+        axis_coords = np.linspace(-1, 1, grid_size)
         P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
+        P_X = torch.FloatTensor(P_X.reshape(-1, 1))  # (N,1)
+        P_Y = torch.FloatTensor(P_Y.reshape(-1, 1))  # (N,1)
+        self.register_buffer('P_X', P_X)
+        self.register_buffer('P_Y', P_Y)

+        # Compute inverse matrix L^-1
+        self.register_buffer('Li', self.compute_L_inverse(P_X, P_Y))

+        # Create sampling grid
         grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
+        self.register_buffer('grid_X', torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3))  # (1,H,W,1)
+        self.register_buffer('grid_Y', torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3))  # (1,H,W,1)

     def compute_L_inverse(self, X, Y):
+        N = X.size(0)
         Xmat = X.expand(N, N)
         Ymat = Y.expand(N, N)
         P_dist_squared = torch.pow(Xmat - Xmat.transpose(0, 1), 2) + torch.pow(Ymat - Ymat.transpose(0, 1), 2)
+        P_dist_squared[P_dist_squared == 0] = 1  # Avoid log(0)
         K = torch.mul(P_dist_squared, torch.log(P_dist_squared))

+        # Construct L matrix
         O = torch.FloatTensor(N, 1).fill_(1)
         Z = torch.FloatTensor(3, 3).fill_(0)
         P = torch.cat((O, X, Y), 1)
@@ -156,43 +160,51 @@ class TpsGridGen(nn.Module):
         return torch.inverse(L)

     def forward(self, theta):
+        batch_size = theta.size(0)
+        device = theta.device

+        # Split theta into x and y components
         Q_X = theta[:, :self.N].contiguous().view(batch_size, self.N, 1)
         Q_Y = theta[:, self.N:].contiguous().view(batch_size, self.N, 1)
+        Q_X = Q_X + self.P_X.expand_as(Q_X)
+        Q_Y = Q_Y + self.P_Y.expand_as(Q_Y)

         # Compute weights
+        W_X = torch.bmm(self.Li[:, :self.N, :self.N].expand(batch_size, -1, -1), Q_X)
+        W_Y = torch.bmm(self.Li[:, :self.N, :self.N].expand(batch_size, -1, -1), Q_Y)
+
+        # Repeat grid for batch processing
+        grid_X = self.grid_X.expand(batch_size, -1, -1, -1).to(device)
+        grid_Y = self.grid_Y.expand(batch_size, -1, -1, -1).to(device)

+        # Compute transformed coordinates
+        points_X = self.transform_points(grid_X, W_X, Q_X)
+        points_Y = self.transform_points(grid_Y, W_Y, Q_Y)

         return torch.cat((points_X, points_Y), 3)

+    def transform_points(self, grid, W, Q):
+        batch_size, h, w, _ = grid.size()
+
+        # Compute distance between grid points and control points
+        grid_flat = grid.view(batch_size, -1, 1)
+        P = torch.cat([self.P_X, self.P_Y], 1).unsqueeze(0).expand(batch_size, -1, -1).to(grid.device)
+        delta = grid_flat - P

+        # Compute U (radial basis function)
+        dist_squared = torch.sum(torch.pow(delta, 2), 2, keepdim=True)
+        dist_squared[dist_squared == 0] = 1  # Avoid log(0)
         U = torch.mul(dist_squared, torch.log(dist_squared))

+        # Compute affine + non-affine transformation
+        A = torch.cat([
+            torch.ones(batch_size, h*w, 1, device=grid.device),
+            grid_flat[:, :, 0:1],
+            grid_flat[:, :, 1:2]
+        ], 2)

+        points = torch.bmm(A, Q.view(batch_size, 3, -1)) + torch.bmm(U, W.view(batch_size, self.N, -1))
+        return points.view(batch_size, h, w, 1)

 class GMM(nn.Module):
     def __init__(self, opt=None):
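
For reference, here is a minimal, self-contained sketch of the thin-plate-spline warp that TpsGridGen is meant to produce, written against plain PyTorch (1.10+ assumed for meshgrid's indexing argument). The function name tps_warp_sketch, the default grid_size/out_h/out_w values, and the smoke test at the bottom are illustrative assumptions, not part of this commit; only the theta layout (x-offsets followed by y-offsets) and the r^2 * log(r^2) radial kernel follow the code above. It builds the same kind of L matrix as compute_L_inverse, solves for the spline coefficients, and evaluates the mapping on a dense grid that can be fed to F.grid_sample:

import torch
import torch.nn.functional as F


def tps_warp_sketch(theta, grid_size=3, out_h=256, out_w=192):
    """Evaluate a TPS warp parameterised by control-point offsets.

    theta: (B, 2*N) with the first N entries the x-offsets and the last N the
    y-offsets of the N = grid_size**2 control points.
    """
    B = theta.size(0)
    N = grid_size * grid_size

    # Regular grid of base control points in [-1, 1] x [-1, 1].
    ax = torch.linspace(-1, 1, grid_size)
    P_Y, P_X = torch.meshgrid(ax, ax, indexing="ij")
    P = torch.stack([P_X.reshape(-1), P_Y.reshape(-1)], dim=1)          # (N, 2)

    # Target control points = base grid + predicted offsets.
    Q = P.unsqueeze(0) + torch.stack([theta[:, :N], theta[:, N:]], 2)   # (B, N, 2)

    # TPS system matrix L = [[K, P1], [P1^T, 0]] with K_ij = r^2 * log(r^2).
    d2 = torch.cdist(P, P).pow(2)
    d2[d2 == 0] = 1                                                     # avoid log(0)
    K = d2 * torch.log(d2)
    P1 = torch.cat([torch.ones(N, 1), P], dim=1)                        # (N, 3)
    L = torch.cat([torch.cat([K, P1], 1),
                   torch.cat([P1.t(), torch.zeros(3, 3)], 1)], 0)       # (N+3, N+3)
    Li = torch.inverse(L)

    # Spline coefficients: non-affine weights W and affine part A.
    rhs = torch.cat([Q, torch.zeros(B, 3, 2)], dim=1)                   # (B, N+3, 2)
    coeff = torch.matmul(Li.unsqueeze(0), rhs)                          # (B, N+3, 2)
    W, A = coeff[:, :N], coeff[:, N:]

    # Evaluate the mapping on the dense out_h x out_w sampling grid.
    gy, gx = torch.meshgrid(torch.linspace(-1, 1, out_h),
                            torch.linspace(-1, 1, out_w), indexing="ij")
    pts = torch.stack([gx.reshape(-1), gy.reshape(-1)], dim=1)          # (H*W, 2)
    d2 = torch.cdist(pts, P).pow(2)
    d2[d2 == 0] = 1
    U = d2 * torch.log(d2)                                              # (H*W, N)
    ones = torch.ones(pts.size(0), 1)
    grid = (torch.matmul(U.unsqueeze(0), W) +
            torch.matmul(torch.cat([ones, pts], 1).unsqueeze(0), A))    # (B, H*W, 2)
    return grid.view(B, out_h, out_w, 2)                                # for F.grid_sample


if __name__ == "__main__":
    # Hypothetical smoke test: warp a random image batch with small offsets.
    img = torch.rand(2, 3, 256, 192)
    theta = 0.1 * torch.randn(2, 2 * 3 * 3)
    warped = F.grid_sample(img, tps_warp_sketch(theta), align_corners=True)
    print(warped.shape)  # torch.Size([2, 3, 256, 192])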