gaur3009 commited on
Commit
bfbe0c2
·
verified ·
1 Parent(s): 5bff76a

Update networks.py

Browse files
Files changed (1) hide show
  1. networks.py +47 -35
networks.py CHANGED
@@ -126,29 +126,33 @@ class TpsGridGen(nn.Module):
126
  self.out_h = out_h
127
  self.out_w = out_w
128
  self.grid_size = grid_size
129
-
130
- # Create grid
131
- axis_coords = np.linspace(-1, 1, grid_size)
132
  self.N = grid_size * grid_size
 
 
 
133
  P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
 
 
 
 
134
 
135
- self.P_X_base = torch.FloatTensor(P_X.reshape(-1, 1))
136
- self.P_Y_base = torch.FloatTensor(P_Y.reshape(-1, 1))
137
- self.Li = self.compute_L_inverse(self.P_X_base, self.P_Y_base).unsqueeze(0)
138
 
139
- # Grid for interpolation
140
  grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
141
- self.grid_X = torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3)
142
- self.grid_Y = torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3)
143
 
144
  def compute_L_inverse(self, X, Y):
145
- N = X.size()[0]
146
  Xmat = X.expand(N, N)
147
  Ymat = Y.expand(N, N)
148
  P_dist_squared = torch.pow(Xmat - Xmat.transpose(0, 1), 2) + torch.pow(Ymat - Ymat.transpose(0, 1), 2)
149
- P_dist_squared[P_dist_squared == 0] = 1
150
  K = torch.mul(P_dist_squared, torch.log(P_dist_squared))
151
 
 
152
  O = torch.FloatTensor(N, 1).fill_(1)
153
  Z = torch.FloatTensor(3, 3).fill_(0)
154
  P = torch.cat((O, X, Y), 1)
@@ -156,43 +160,51 @@ class TpsGridGen(nn.Module):
156
  return torch.inverse(L)
157
 
158
  def forward(self, theta):
159
- batch_size = theta.size()[0]
160
- theta = theta.contiguous()
161
 
 
162
  Q_X = theta[:, :self.N].contiguous().view(batch_size, self.N, 1)
163
  Q_Y = theta[:, self.N:].contiguous().view(batch_size, self.N, 1)
164
- Q_X = Q_X + self.P_X_base.expand_as(Q_X)
165
- Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)
166
 
167
  # Compute weights
168
- W_X = torch.bmm(self.Li[:, :self.N, :self.N].expand(batch_size, self.N, self.N), Q_X)
169
- W_Y = torch.bmm(self.Li[:, :self.N, :self.N].expand(batch_size, self.N, self.N), Q_Y)
 
 
 
 
170
 
171
- # Transform points
172
- points_X = self.apply_transformation(self.grid_X, W_X, Q_X)
173
- points_Y = self.apply_transformation(self.grid_Y, W_Y, Q_Y)
174
 
175
  return torch.cat((points_X, points_Y), 3)
176
 
177
- def apply_transformation(self, grid, W, Q):
178
- batch_size = W.size()[0]
179
- P = torch.cat([
180
- self.P_X_base.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4),
181
- self.P_Y_base.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4)
182
- ], 1)
 
183
 
184
- delta = grid.expand(batch_size, 1, self.out_h, self.out_w, 1, self.N) - P.expand(batch_size, 1, self.out_h, self.out_w, 1, self.N)
185
- dist_squared = torch.pow(delta[:,0], 2) + torch.pow(delta[:,1], 2)
186
- dist_squared[dist_squared == 0] = 1
187
  U = torch.mul(dist_squared, torch.log(dist_squared))
188
 
189
- points = torch.sum(torch.mul(W.expand(batch_size, 1, self.out_h, self.out_w, 1, self.N), U.unsqueeze(4)), 5)
190
- points += torch.sum(Q.expand(batch_size, 1, self.out_h, self.out_w, 1, 3) *
191
- torch.cat([grid.new_ones(batch_size, 1, self.out_h, self.out_w, 1),
192
- grid.expand(batch_size, 1, self.out_h, self.out_w, 1),
193
- grid.transpose(3,4).expand(batch_size, 1, self.out_h, self.out_w, 1)], 4), 5)
 
194
 
195
- return points.squeeze(4)
 
196
 
197
  class GMM(nn.Module):
198
  def __init__(self, opt=None):
 
126
  self.out_h = out_h
127
  self.out_w = out_w
128
  self.grid_size = grid_size
 
 
 
129
  self.N = grid_size * grid_size
130
+
131
+ # Create regular grid of control points
132
+ axis_coords = np.linspace(-1, 1, grid_size)
133
  P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
134
+ P_X = torch.FloatTensor(P_X.reshape(-1, 1)) # (N,1)
135
+ P_Y = torch.FloatTensor(P_Y.reshape(-1, 1)) # (N,1)
136
+ self.register_buffer('P_X', P_X)
137
+ self.register_buffer('P_Y', P_Y)
138
 
139
+ # Compute inverse matrix L^-1
140
+ self.register_buffer('Li', self.compute_L_inverse(P_X, P_Y))
 
141
 
142
+ # Create sampling grid
143
  grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
144
+ self.register_buffer('grid_X', torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3)) # (1,H,W,1)
145
+ self.register_buffer('grid_Y', torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3)) # (1,H,W,1)
146
 
147
  def compute_L_inverse(self, X, Y):
148
+ N = X.size(0)
149
  Xmat = X.expand(N, N)
150
  Ymat = Y.expand(N, N)
151
  P_dist_squared = torch.pow(Xmat - Xmat.transpose(0, 1), 2) + torch.pow(Ymat - Ymat.transpose(0, 1), 2)
152
+ P_dist_squared[P_dist_squared == 0] = 1 # Avoid log(0)
153
  K = torch.mul(P_dist_squared, torch.log(P_dist_squared))
154
 
155
+ # Construct L matrix
156
  O = torch.FloatTensor(N, 1).fill_(1)
157
  Z = torch.FloatTensor(3, 3).fill_(0)
158
  P = torch.cat((O, X, Y), 1)
 
160
  return torch.inverse(L)
161
 
162
  def forward(self, theta):
163
+ batch_size = theta.size(0)
164
+ device = theta.device
165
 
166
+ # Split theta into x and y components
167
  Q_X = theta[:, :self.N].contiguous().view(batch_size, self.N, 1)
168
  Q_Y = theta[:, self.N:].contiguous().view(batch_size, self.N, 1)
169
+ Q_X = Q_X + self.P_X.expand_as(Q_X)
170
+ Q_Y = Q_Y + self.P_Y.expand_as(Q_Y)
171
 
172
  # Compute weights
173
+ W_X = torch.bmm(self.Li[:, :self.N, :self.N].expand(batch_size, -1, -1), Q_X)
174
+ W_Y = torch.bmm(self.Li[:, :self.N, :self.N].expand(batch_size, -1, -1), Q_Y)
175
+
176
+ # Repeat grid for batch processing
177
+ grid_X = self.grid_X.expand(batch_size, -1, -1, -1).to(device)
178
+ grid_Y = self.grid_Y.expand(batch_size, -1, -1, -1).to(device)
179
 
180
+ # Compute transformed coordinates
181
+ points_X = self.transform_points(grid_X, W_X, Q_X)
182
+ points_Y = self.transform_points(grid_Y, W_Y, Q_Y)
183
 
184
  return torch.cat((points_X, points_Y), 3)
185
 
186
+ def transform_points(self, grid, W, Q):
187
+ batch_size, h, w, _ = grid.size()
188
+
189
+ # Compute distance between grid points and control points
190
+ grid_flat = grid.view(batch_size, -1, 1)
191
+ P = torch.cat([self.P_X, self.P_Y], 1).unsqueeze(0).expand(batch_size, -1, -1).to(grid.device)
192
+ delta = grid_flat - P
193
 
194
+ # Compute U (radial basis function)
195
+ dist_squared = torch.sum(torch.pow(delta, 2), 2, keepdim=True)
196
+ dist_squared[dist_squared == 0] = 1 # Avoid log(0)
197
  U = torch.mul(dist_squared, torch.log(dist_squared))
198
 
199
+ # Compute affine + non-affine transformation
200
+ A = torch.cat([
201
+ torch.ones(batch_size, h*w, 1, device=grid.device),
202
+ grid_flat[:, :, 0:1],
203
+ grid_flat[:, :, 1:2]
204
+ ], 2)
205
 
206
+ points = torch.bmm(A, Q.view(batch_size, 3, -1)) + torch.bmm(U, W.view(batch_size, self.N, -1))
207
+ return points.view(batch_size, h, w, 1)
208
 
209
  class GMM(nn.Module):
210
  def __init__(self, opt=None):