gaur3009 commited on
Commit
3ecccc5
·
verified ·
1 Parent(s): 0a3903a

Update networks.py

Browse files
Files changed (1) hide show
  1. networks.py +109 -376
networks.py CHANGED
@@ -6,7 +6,6 @@ from torchvision import models
6
  import os
7
  import numpy as np
8
 
9
- # Configuration class to hold all parameters
10
  class Options:
11
  def __init__(self):
12
  # Default values
@@ -14,6 +13,10 @@ class Options:
14
  self.fine_width = 192
15
  self.grid_size = 5
16
  self.use_dropout = False
 
 
 
 
17
 
18
  def weights_init_normal(m):
19
  classname = m.__class__.__name__
@@ -25,37 +28,9 @@ def weights_init_normal(m):
25
  init.normal_(m.weight.data, 1.0, 0.02)
26
  init.constant_(m.bias.data, 0.0)
27
 
28
- def weights_init_xavier(m):
29
- classname = m.__class__.__name__
30
- if classname.find('Conv') != -1:
31
- init.xavier_normal_(m.weight.data, gain=0.02)
32
- elif classname.find('Linear') != -1:
33
- init.xavier_normal_(m.weight.data, gain=0.02)
34
- elif classname.find('BatchNorm2d') != -1:
35
- init.normal_(m.weight.data, 1.0, 0.02)
36
- init.constant_(m.bias.data, 0.0)
37
-
38
- def weights_init_kaiming(m):
39
- classname = m.__class__.__name__
40
- if classname.find('Conv') != -1:
41
- init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
42
- elif classname.find('Linear') != -1:
43
- init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
44
- elif classname.find('BatchNorm2d') != -1:
45
- init.normal_(m.weight.data, 1.0, 0.02)
46
- init.constant_(m.bias.data, 0.0)
47
-
48
  def init_weights(net, init_type='normal'):
49
  print('initialization method [%s]' % init_type)
50
- if init_type == 'normal':
51
- net.apply(weights_init_normal)
52
- elif init_type == 'xavier':
53
- net.apply(weights_init_xavier)
54
- elif init_type == 'kaiming':
55
- net.apply(weights_init_kaiming)
56
- else:
57
- raise NotImplementedError(
58
- 'initialization method [%s] is not implemented' % init_type)
59
 
60
  class FeatureExtraction(nn.Module):
61
  def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_dropout=False):
@@ -65,30 +40,21 @@ class FeatureExtraction(nn.Module):
65
  for i in range(n_layers):
66
  in_ngf = 2**i * ngf if 2**i * ngf < 512 else 512
67
  out_ngf = 2**(i+1) * ngf if 2**i * ngf < 512 else 512
68
- downconv = nn.Conv2d(
69
- in_ngf, out_ngf, kernel_size=4, stride=2, padding=1)
70
- model += [downconv, nn.ReLU(True)]
71
- model += [norm_layer(out_ngf)]
72
- model += [nn.Conv2d(512, 512, kernel_size=3,
73
- stride=1, padding=1), nn.ReLU(True)]
74
  model += [norm_layer(512)]
75
- model += [nn.Conv2d(512, 512, kernel_size=3,
76
- stride=1, padding=1), nn.ReLU(True)]
77
-
78
  self.model = nn.Sequential(*model)
79
- init_weights(self.model, init_type='normal')
80
-
81
- def forward(self, x):
82
- return self.model(x)
83
 
84
- class FeatureL2Norm(torch.nn.Module):
85
  def __init__(self):
86
  super(FeatureL2Norm, self).__init__()
87
 
88
  def forward(self, feature):
89
  epsilon = 1e-6
90
- norm = torch.pow(torch.sum(torch.pow(feature, 2), 1) +
91
- epsilon, 0.5).unsqueeze(1).expand_as(feature)
92
  return torch.div(feature, norm)
93
 
94
  class FeatureCorrelation(nn.Module):
@@ -97,14 +63,10 @@ class FeatureCorrelation(nn.Module):
97
 
98
  def forward(self, feature_A, feature_B):
99
  b, c, h, w = feature_A.size()
100
- # reshape features for matrix multiplication
101
  feature_A = feature_A.transpose(2, 3).contiguous().view(b, c, h*w)
102
  feature_B = feature_B.view(b, c, h*w).transpose(1, 2)
103
- # perform matrix mult.
104
  feature_mul = torch.bmm(feature_B, feature_A)
105
- correlation_tensor = feature_mul.view(
106
- b, h, w, h*w).transpose(2, 3).transpose(1, 2)
107
- return correlation_tensor
108
 
109
  class FeatureRegression(nn.Module):
110
  def __init__(self, input_nc=512, output_dim=6):
@@ -128,238 +90,134 @@ class FeatureRegression(nn.Module):
128
 
129
  def forward(self, x):
130
  x = self.conv(x)
131
- # Change view() to reshape() and make contiguous
132
  x = x.contiguous().view(x.size(0), -1)
133
  x = self.linear(x)
134
- x = self.tanh(x)
135
- return x
136
-
137
- class AffineGridGen(nn.Module):
138
- def __init__(self, out_h=256, out_w=192, out_ch=3):
139
- super(AffineGridGen, self).__init__()
140
- self.out_h = out_h
141
- self.out_w = out_w
142
- self.out_ch = out_ch
143
-
144
- def forward(self, theta):
145
- theta = theta.contiguous()
146
- batch_size = theta.size()[0]
147
- out_size = torch.Size(
148
- (batch_size, self.out_ch, self.out_h, self.out_w))
149
- return F.affine_grid(theta, out_size)
150
 
151
  class TpsGridGen(nn.Module):
152
- def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0):
153
  super(TpsGridGen, self).__init__()
154
  self.out_h, self.out_w = out_h, out_w
155
- self.reg_factor = reg_factor
156
  self.grid_size = grid_size
157
-
158
- # create grid in numpy
159
- self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
160
- # sampling grid with dim-0 coords (Y)
161
- self.grid_X, self.grid_Y = np.meshgrid(
162
- np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
163
- # grid_X,grid_Y: size [1,H,W,1,1]
164
- self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3)
165
- self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3)
166
-
167
- # initialize regular grid for control points P_i
168
- if use_regular_grid:
169
- axis_coords = np.linspace(-1, 1, grid_size)
170
- self.N = grid_size*grid_size
171
- P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
172
- P_X = np.reshape(P_X, (-1, 1)) # size (N,1)
173
- P_Y = np.reshape(P_Y, (-1, 1)) # size (N,1)
174
- P_X = torch.FloatTensor(P_X)
175
- P_Y = torch.FloatTensor(P_Y)
176
- self.P_X_base = P_X.clone()
177
- self.P_Y_base = P_Y.clone()
178
- self.Li = self.compute_L_inverse(P_X, P_Y).unsqueeze(0)
179
- self.P_X = P_X.unsqueeze(2).unsqueeze(
180
- 3).unsqueeze(4).transpose(0, 4)
181
- self.P_Y = P_Y.unsqueeze(2).unsqueeze(
182
- 3).unsqueeze(4).transpose(0, 4)
183
-
184
- def forward(self, theta):
185
- warped_grid = self.apply_transformation(
186
- theta, torch.cat((self.grid_X, self.grid_Y), 3))
187
- return warped_grid
188
 
189
  def compute_L_inverse(self, X, Y):
190
- N = X.size()[0] # num of points (along dim 0)
191
- # construct matrix K
192
- Xmat = X.expand(N, N)
193
- Ymat = Y.expand(N, N)
194
- P_dist_squared = torch.pow(
195
- Xmat-Xmat.transpose(0, 1), 2)+torch.pow(Ymat-Ymat.transpose(0, 1), 2)
196
- # make diagonal 1 to avoid NaN in log computation
197
  P_dist_squared[P_dist_squared == 0] = 1
198
  K = torch.mul(P_dist_squared, torch.log(P_dist_squared))
199
- # construct matrix L
200
  O = torch.FloatTensor(N, 1).fill_(1)
201
  Z = torch.FloatTensor(3, 3).fill_(0)
202
  P = torch.cat((O, X, Y), 1)
203
- L = torch.cat((torch.cat((K, P), 1), torch.cat(
204
- (P.transpose(0, 1), Z), 1)), 0)
205
- Li = torch.inverse(L)
206
- return Li
207
-
208
- def apply_transformation(self, theta, points):
209
- if theta.dim() == 2:
210
- theta = theta.unsqueeze(2).unsqueeze(3)
211
- # points should be in the [B,H,W,2] format,
212
- # where points[:,:,:,0] are the X coords
213
- # and points[:,:,:,1] are the Y coords
214
 
215
- # input are the corresponding control points P_i
 
216
  batch_size = theta.size()[0]
217
- # split theta into point coordinates
218
- Q_X = theta[:, :self.N, :, :].squeeze(3)
219
- Q_Y = theta[:, self.N:, :, :].squeeze(3)
 
220
  Q_X = Q_X + self.P_X_base.expand_as(Q_X)
221
  Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)
 
 
 
 
 
 
 
222
 
223
- # get spatial dimensions of points
224
- points_b = points.size()[0]
225
- points_h = points.size()[1]
226
- points_w = points.size()[2]
227
-
228
- # repeat pre-defined control points along spatial dimensions of points to be transformed
229
- P_X = self.P_X.expand((1, points_h, points_w, 1, self.N))
230
- P_Y = self.P_Y.expand((1, points_h, points_w, 1, self.N))
231
-
232
- # compute weigths for non-linear part
233
- W_X = torch.bmm(self.Li[:, :self.N, :self.N].expand(
234
- (batch_size, self.N, self.N)), Q_X)
235
- W_Y = torch.bmm(self.Li[:, :self.N, :self.N].expand(
236
- (batch_size, self.N, self.N)), Q_Y)
237
- # reshape
238
- # W_X,W,Y: size [B,H,W,1,N]
239
- W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(
240
- 1, 4).repeat(1, points_h, points_w, 1, 1)
241
- W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(
242
- 1, 4).repeat(1, points_h, points_w, 1, 1)
243
- # compute weights for affine part
244
- A_X = torch.bmm(self.Li[:, self.N:, :self.N].expand(
245
- (batch_size, 3, self.N)), Q_X)
246
- A_Y = torch.bmm(self.Li[:, self.N:, :self.N].expand(
247
- (batch_size, 3, self.N)), Q_Y)
248
- # reshape
249
- # A_X,A,Y: size [B,H,W,1,3]
250
- A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(
251
- 1, 4).repeat(1, points_h, points_w, 1, 1)
252
- A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(
253
- 1, 4).repeat(1, points_h, points_w, 1, 1)
254
-
255
- # compute distance P_i - (grid_X,grid_Y)
256
- # grid is expanded in point dim 4, but not in batch dim 0, as points P_X,P_Y are fixed for all batch
257
- points_X_for_summation = points[:, :, :, 0].unsqueeze(
258
- 3).unsqueeze(4).expand(points[:, :, :, 0].size()+(1, self.N))
259
- points_Y_for_summation = points[:, :, :, 1].unsqueeze(
260
- 3).unsqueeze(4).expand(points[:, :, :, 1].size()+(1, self.N))
261
-
262
- if points_b == 1:
263
- delta_X = points_X_for_summation-P_X
264
- delta_Y = points_Y_for_summation-P_Y
265
- else:
266
- # use expanded P_X,P_Y in batch dimension
267
- delta_X = points_X_for_summation - \
268
- P_X.expand_as(points_X_for_summation)
269
- delta_Y = points_Y_for_summation - \
270
- P_Y.expand_as(points_Y_for_summation)
271
-
272
- dist_squared = torch.pow(delta_X, 2)+torch.pow(delta_Y, 2)
273
- # U: size [1,H,W,1,N]
274
- dist_squared[dist_squared == 0] = 1 # avoid NaN in log computation
275
- U = torch.mul(dist_squared, torch.log(dist_squared))
276
-
277
- # expand grid in batch dimension if necessary
278
- points_X_batch = points[:, :, :, 0].unsqueeze(3)
279
- points_Y_batch = points[:, :, :, 1].unsqueeze(3)
280
- if points_b == 1:
281
- points_X_batch = points_X_batch.expand(
282
- (batch_size,)+points_X_batch.size()[1:])
283
- points_Y_batch = points_Y_batch.expand(
284
- (batch_size,)+points_Y_batch.size()[1:])
285
-
286
- points_X_prime = A_X[:, :, :, :, 0] + \
287
- torch.mul(A_X[:, :, :, :, 1], points_X_batch) + \
288
- torch.mul(A_X[:, :, :, :, 2], points_Y_batch) + \
289
- torch.sum(torch.mul(W_X, U.expand_as(W_X)), 4)
290
-
291
- points_Y_prime = A_Y[:, :, :, :, 0] + \
292
- torch.mul(A_Y[:, :, :, :, 1], points_X_batch) + \
293
- torch.mul(A_Y[:, :, :, :, 2], points_Y_batch) + \
294
- torch.sum(torch.mul(W_Y, U.expand_as(W_Y)), 4)
295
 
296
- return torch.cat((points_X_prime, points_Y_prime), 3)
 
 
 
 
 
 
 
 
297
 
298
  class UnetGenerator(nn.Module):
299
- def __init__(self, input_nc, output_nc, num_downs, ngf=64,
300
- norm_layer=nn.BatchNorm2d, use_dropout=False):
301
  super(UnetGenerator, self).__init__()
302
- # construct unet structure
303
  unet_block = UnetSkipConnectionBlock(
304
  ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
305
- for i in range(num_downs - 5):
 
306
  unet_block = UnetSkipConnectionBlock(
307
- ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
308
- unet_block = UnetSkipConnectionBlock(
309
- ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
310
- unet_block = UnetSkipConnectionBlock(
311
- ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
312
- unet_block = UnetSkipConnectionBlock(
313
- ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
314
- unet_block = UnetSkipConnectionBlock(
315
  output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
316
 
317
- self.model = unet_block
318
-
319
  def forward(self, input):
320
  return self.model(input)
321
 
322
  class UnetSkipConnectionBlock(nn.Module):
323
- def __init__(self, outer_nc, inner_nc, input_nc=None,
324
- submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
325
  super(UnetSkipConnectionBlock, self).__init__()
326
  self.outermost = outermost
327
  use_bias = norm_layer == nn.InstanceNorm2d
328
-
329
  if input_nc is None:
330
  input_nc = outer_nc
331
- downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
332
- stride=2, padding=1, bias=use_bias)
333
  downrelu = nn.LeakyReLU(0.2, True)
334
  downnorm = norm_layer(inner_nc)
335
  uprelu = nn.ReLU(True)
336
  upnorm = norm_layer(outer_nc)
337
 
338
  if outermost:
339
- upsample = nn.Upsample(scale_factor=2, mode='bilinear')
340
- upconv = nn.Conv2d(inner_nc * 2, outer_nc,
341
- kernel_size=3, stride=1, padding=1, bias=use_bias)
342
  down = [downconv]
343
- up = [uprelu, upsample, upconv, upnorm]
344
  model = down + [submodule] + up
345
  elif innermost:
346
- upsample = nn.Upsample(scale_factor=2, mode='bilinear')
347
- upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3,
348
- stride=1, padding=1, bias=use_bias)
349
  down = [downrelu, downconv]
350
- up = [uprelu, upsample, upconv, upnorm]
351
  model = down + up
352
  else:
353
- upsample = nn.Upsample(scale_factor=2, mode='bilinear')
354
- upconv = nn.Conv2d(inner_nc*2, outer_nc, kernel_size=3,
355
- stride=1, padding=1, bias=use_bias)
356
  down = [downrelu, downconv, downnorm]
357
- up = [uprelu, upsample, upconv, upnorm]
358
-
359
- if use_dropout:
360
- model = down + [submodule] + up + [nn.Dropout(0.5)]
361
- else:
362
- model = down + [submodule] + up
363
 
364
  self.model = nn.Sequential(*model)
365
 
@@ -369,131 +227,27 @@ class UnetSkipConnectionBlock(nn.Module):
369
  else:
370
  return torch.cat([x, self.model(x)], 1)
371
 
372
- class Vgg19(nn.Module):
373
- def __init__(self, requires_grad=False):
374
- super(Vgg19, self).__init__()
375
- vgg_pretrained_features = models.vgg19(pretrained=True).features
376
- self.slice1 = torch.nn.Sequential()
377
- self.slice2 = torch.nn.Sequential()
378
- self.slice3 = torch.nn.Sequential()
379
- self.slice4 = torch.nn.Sequential()
380
- self.slice5 = torch.nn.Sequential()
381
- for x in range(2):
382
- self.slice1.add_module(str(x), vgg_pretrained_features[x])
383
- for x in range(2, 7):
384
- self.slice2.add_module(str(x), vgg_pretrained_features[x])
385
- for x in range(7, 12):
386
- self.slice3.add_module(str(x), vgg_pretrained_features[x])
387
- for x in range(12, 21):
388
- self.slice4.add_module(str(x), vgg_pretrained_features[x])
389
- for x in range(21, 30):
390
- self.slice5.add_module(str(x), vgg_pretrained_features[x])
391
- if not requires_grad:
392
- for param in self.parameters():
393
- param.requires_grad = False
394
-
395
- def forward(self, X):
396
- h_relu1 = self.slice1(X)
397
- h_relu2 = self.slice2(h_relu1)
398
- h_relu3 = self.slice3(h_relu2)
399
- h_relu4 = self.slice4(h_relu3)
400
- h_relu5 = self.slice5(h_relu4)
401
- out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
402
- return out
403
-
404
- class VGGLoss(nn.Module):
405
- def __init__(self, layids=None):
406
- super(VGGLoss, self).__init__()
407
- self.vgg = Vgg19()
408
- self.criterion = nn.L1Loss()
409
- self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
410
- self.layids = layids
411
-
412
- def forward(self, x, y):
413
- x_vgg, y_vgg = self.vgg(x), self.vgg(y)
414
- loss = 0
415
- if self.layids is None:
416
- self.layids = list(range(len(x_vgg)))
417
- for i in self.layids:
418
- loss += self.weights[i] * \
419
- self.criterion(x_vgg[i], y_vgg[i].detach())
420
- return loss
421
-
422
- class DT(nn.Module):
423
- def __init__(self):
424
- super(DT, self).__init__()
425
-
426
- def forward(self, x1, x2):
427
- dt = torch.abs(x1 - x2)
428
- return dt
429
-
430
- class DT2(nn.Module):
431
- def __init__(self):
432
- super(DT2, self).__init__()
433
-
434
- def forward(self, x1, y1, x2, y2):
435
- dt = torch.sqrt(torch.mul(x1 - x2, x1 - x2) +
436
- torch.mul(y1 - y2, y1 - y2))
437
- return dt
438
-
439
- class GicLoss(nn.Module):
440
- def __init__(self, opt):
441
- super(GicLoss, self).__init__()
442
- self.dT = DT()
443
- self.opt = opt
444
-
445
- def forward(self, grid):
446
- Gx = grid[:, :, :, 0]
447
- Gy = grid[:, :, :, 1]
448
- Gxcenter = Gx[:, 1:self.opt.fine_height - 1, 1:self.opt.fine_width - 1]
449
- Gxup = Gx[:, 0:self.opt.fine_height - 2, 1:self.opt.fine_width - 1]
450
- Gxdown = Gx[:, 2:self.opt.fine_height, 1:self.opt.fine_width - 1]
451
- Gxleft = Gx[:, 1:self.opt.fine_height - 1, 0:self.opt.fine_width - 2]
452
- Gxright = Gx[:, 1:self.opt.fine_height - 1, 2:self.opt.fine_width]
453
-
454
- Gycenter = Gy[:, 1:self.opt.fine_height - 1, 1:self.opt.fine_width - 1]
455
- Gyup = Gy[:, 0:self.opt.fine_height - 2, 1:self.opt.fine_width - 1]
456
- Gydown = Gy[:, 2:self.opt.fine_height, 1:self.opt.fine_width - 1]
457
- Gyleft = Gy[:, 1:self.opt.fine_height - 1, 0:self.opt.fine_width - 2]
458
- Gyright = Gy[:, 1:self.opt.fine_height - 1, 2:self.opt.fine_width]
459
-
460
- dtleft = self.dT(Gxleft, Gxcenter)
461
- dtright = self.dT(Gxright, Gxcenter)
462
- dtup = self.dT(Gyup, Gycenter)
463
- dtdown = self.dT(Gydown, Gycenter)
464
-
465
- return torch.sum(torch.abs(dtleft - dtright) + torch.abs(dtup - dtdown))
466
-
467
- class GMM(nn.Module):
468
- """ Geometric Matching Module
469
- """
470
  def __init__(self, opt=None):
471
- super(GMM, self).__init__()
472
- # Initialize default options if none provided
473
  if opt is None:
474
  opt = Options()
475
-
476
- self.extractionA = FeatureExtraction(
477
- 22, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
478
- self.extractionB = FeatureExtraction(
479
- 1, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
480
- self.l2norm = FeatureL2Norm()
481
- self.correlation = FeatureCorrelation()
482
- self.regression = FeatureRegression(
483
- input_nc=192, output_dim=2*opt.grid_size**2)
484
- self.gridGen = TpsGridGen(
485
- opt.fine_height, opt.fine_width, grid_size=opt.grid_size)
486
-
487
- def forward(self, inputA, inputB):
488
- featureA = self.extractionA(inputA)
489
- featureB = self.extractionB(inputB)
490
- featureA = self.l2norm(featureA)
491
- featureB = self.l2norm(featureB)
492
- correlation = self.correlation(featureA, featureB)
493
 
494
- theta = self.regression(correlation)
495
- grid = self.gridGen(theta)
496
- return grid, theta
 
 
 
497
 
498
  def save_checkpoint(model, save_path):
499
  if not os.path.exists(os.path.dirname(save_path)):
@@ -504,26 +258,5 @@ def load_checkpoint(model, checkpoint_path, strict=True):
504
  if not os.path.exists(checkpoint_path):
505
  raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
506
 
507
- # Load checkpoint with strict=False to ignore size mismatches
508
  state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
509
-
510
- # Filter out size-mismatched keys
511
- model_state_dict = model.state_dict()
512
- filtered_state_dict = {k: v for k, v in state_dict.items()
513
- if k in model_state_dict and v.size() == model_state_dict[k].size()}
514
-
515
- # Load the filtered state dict
516
- model.load_state_dict(filtered_state_dict, strict=strict)
517
-
518
- # Print warnings for mismatched keys
519
- missing_keys = [k for k in model_state_dict.keys() if k not in state_dict]
520
- unexpected_keys = [k for k in state_dict.keys() if k not in model_state_dict]
521
- size_mismatch_keys = [k for k in state_dict.keys()
522
- if k in model_state_dict and state_dict[k].size() != model_state_dict[k].size()]
523
-
524
- if missing_keys:
525
- print(f"Missing keys in checkpoint: {missing_keys}")
526
- if unexpected_keys:
527
- print(f"Unexpected keys in checkpoint: {unexpected_keys}")
528
- if size_mismatch_keys:
529
- print(f"Size mismatch for keys: {size_mismatch_keys}")
 
6
  import os
7
  import numpy as np
8
 
 
9
  class Options:
10
  def __init__(self):
11
  # Default values
 
13
  self.fine_width = 192
14
  self.grid_size = 5
15
  self.use_dropout = False
16
+ self.input_nc = 22
17
+ self.input_nc_B = 1
18
+ self.tom_input_nc = 26
19
+ self.tom_output_nc = 4
20
 
21
  def weights_init_normal(m):
22
  classname = m.__class__.__name__
 
28
  init.normal_(m.weight.data, 1.0, 0.02)
29
  init.constant_(m.bias.data, 0.0)
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def init_weights(net, init_type='normal'):
32
  print('initialization method [%s]' % init_type)
33
+ net.apply(weights_init_normal)
 
 
 
 
 
 
 
 
34
 
35
  class FeatureExtraction(nn.Module):
36
  def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_dropout=False):
 
40
  for i in range(n_layers):
41
  in_ngf = 2**i * ngf if 2**i * ngf < 512 else 512
42
  out_ngf = 2**(i+1) * ngf if 2**i * ngf < 512 else 512
43
+ downconv = nn.Conv2d(in_ngf, out_ngf, kernel_size=4, stride=2, padding=1)
44
+ model += [downconv, nn.ReLU(True), norm_layer(out_ngf)]
45
+ model += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(True)]
 
 
 
46
  model += [norm_layer(512)]
47
+ model += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(True)]
 
 
48
  self.model = nn.Sequential(*model)
49
+ init_weights(self.model)
 
 
 
50
 
51
+ class FeatureL2Norm(nn.Module):
52
  def __init__(self):
53
  super(FeatureL2Norm, self).__init__()
54
 
55
  def forward(self, feature):
56
  epsilon = 1e-6
57
+ norm = torch.pow(torch.sum(torch.pow(feature, 2), 1) + epsilon, 0.5).unsqueeze(1).expand_as(feature)
 
58
  return torch.div(feature, norm)
59
 
60
  class FeatureCorrelation(nn.Module):
 
63
 
64
  def forward(self, feature_A, feature_B):
65
  b, c, h, w = feature_A.size()
 
66
  feature_A = feature_A.transpose(2, 3).contiguous().view(b, c, h*w)
67
  feature_B = feature_B.view(b, c, h*w).transpose(1, 2)
 
68
  feature_mul = torch.bmm(feature_B, feature_A)
69
+ return feature_mul.view(b, h, w, h*w).transpose(2, 3).transpose(1, 2)
 
 
70
 
71
  class FeatureRegression(nn.Module):
72
  def __init__(self, input_nc=512, output_dim=6):
 
90
 
91
  def forward(self, x):
92
  x = self.conv(x)
 
93
  x = x.contiguous().view(x.size(0), -1)
94
  x = self.linear(x)
95
+ return self.tanh(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  class TpsGridGen(nn.Module):
98
+ def __init__(self, out_h=256, out_w=192, grid_size=5):
99
  super(TpsGridGen, self).__init__()
100
  self.out_h, self.out_w = out_h, out_w
 
101
  self.grid_size = grid_size
102
+
103
+ # Create grid
104
+ axis_coords = np.linspace(-1, 1, grid_size)
105
+ self.N = grid_size * grid_size
106
+ P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
107
+ P_X = torch.FloatTensor(P_X.reshape(-1, 1))
108
+ P_Y = torch.FloatTensor(P_Y.reshape(-1, 1))
109
+ self.P_X_base = P_X.clone()
110
+ self.P_Y_base = P_Y.clone()
111
+ self.Li = self.compute_L_inverse(P_X, P_Y).unsqueeze(0)
112
+
113
+ # Grid for interpolation
114
+ grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
115
+ self.grid_X = torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3)
116
+ self.grid_Y = torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  def compute_L_inverse(self, X, Y):
119
+ N = X.size()[0]
120
+ Xmat, Ymat = X.expand(N, N), Y.expand(N, N)
121
+ P_dist_squared = torch.pow(Xmat-Xmat.transpose(0, 1), 2) + torch.pow(Ymat-Ymat.transpose(0, 1), 2)
 
 
 
 
122
  P_dist_squared[P_dist_squared == 0] = 1
123
  K = torch.mul(P_dist_squared, torch.log(P_dist_squared))
 
124
  O = torch.FloatTensor(N, 1).fill_(1)
125
  Z = torch.FloatTensor(3, 3).fill_(0)
126
  P = torch.cat((O, X, Y), 1)
127
+ L = torch.cat((torch.cat((K, P), 1), torch.cat((P.transpose(0, 1), Z), 1)), 0)
128
+ return torch.inverse(L)
 
 
 
 
 
 
 
 
 
129
 
130
+ def forward(self, theta):
131
+ theta = theta.contiguous()
132
  batch_size = theta.size()[0]
133
+
134
+ # Split theta into point coordinates
135
+ Q_X = theta[:, :self.N].contiguous().view(batch_size, self.N, 1)
136
+ Q_Y = theta[:, self.N:].contiguous().view(batch_size, self.N, 1)
137
  Q_X = Q_X + self.P_X_base.expand_as(Q_X)
138
  Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)
139
+
140
+ # Compute weights
141
+ W_X, W_Y = self.apply_theta(Q_X, Q_Y)
142
+
143
+ # Calculate transformed grid
144
+ points_X, points_Y = self.transform_points(W_X, W_Y)
145
+ return torch.cat((points_X, points_Y), 3)
146
 
147
+ class GMM(nn.Module):
148
+ def __init__(self, opt=None):
149
+ super(GMM, self).__init__()
150
+ if opt is None:
151
+ opt = Options()
152
+
153
+ self.extractionA = FeatureExtraction(opt.input_nc)
154
+ self.extractionB = FeatureExtraction(opt.input_nc_B)
155
+ self.l2norm = FeatureL2Norm()
156
+ self.correlation = FeatureCorrelation()
157
+ self.regression = FeatureRegression(input_nc=192, output_dim=2*opt.grid_size**2)
158
+ self.gridGen = TpsGridGen(opt.fine_height, opt.fine_width, opt.grid_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
+ def forward(self, inputA, inputB):
161
+ featureA = self.extractionA(inputA)
162
+ featureB = self.extractionB(inputB)
163
+ featureA = self.l2norm(featureA)
164
+ featureB = self.l2norm(featureB)
165
+ correlation = self.correlation(featureA, featureB)
166
+ theta = self.regression(correlation)
167
+ grid = self.gridGen(theta)
168
+ return grid, theta
169
 
170
  class UnetGenerator(nn.Module):
171
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.InstanceNorm2d):
 
172
  super(UnetGenerator, self).__init__()
 
173
  unet_block = UnetSkipConnectionBlock(
174
  ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
175
+
176
+ for _ in range(num_downs - 5):
177
  unet_block = UnetSkipConnectionBlock(
178
+ ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
179
+
180
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
181
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
182
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
183
+
184
+ self.model = UnetSkipConnectionBlock(
 
185
  output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
186
 
 
 
187
  def forward(self, input):
188
  return self.model(input)
189
 
190
  class UnetSkipConnectionBlock(nn.Module):
191
+ def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
192
+ outermost=False, innermost=False, norm_layer=nn.InstanceNorm2d):
193
  super(UnetSkipConnectionBlock, self).__init__()
194
  self.outermost = outermost
195
  use_bias = norm_layer == nn.InstanceNorm2d
196
+
197
  if input_nc is None:
198
  input_nc = outer_nc
199
+
200
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
201
  downrelu = nn.LeakyReLU(0.2, True)
202
  downnorm = norm_layer(inner_nc)
203
  uprelu = nn.ReLU(True)
204
  upnorm = norm_layer(outer_nc)
205
 
206
  if outermost:
207
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
 
 
208
  down = [downconv]
209
+ up = [uprelu, upconv, nn.Tanh()]
210
  model = down + [submodule] + up
211
  elif innermost:
212
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
 
 
213
  down = [downrelu, downconv]
214
+ up = [uprelu, upconv, upnorm]
215
  model = down + up
216
  else:
217
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
 
 
218
  down = [downrelu, downconv, downnorm]
219
+ up = [uprelu, upconv, upnorm]
220
+ model = down + [submodule] + up
 
 
 
 
221
 
222
  self.model = nn.Sequential(*model)
223
 
 
227
  else:
228
  return torch.cat([x, self.model(x)], 1)
229
 
230
+ class TOM(nn.Module):
231
+ """ Try-On Module """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  def __init__(self, opt=None):
233
+ super(TOM, self).__init__()
 
234
  if opt is None:
235
  opt = Options()
236
+
237
+ # Input: [agnostic(3) + warped_design(3) + warped_mask(1) + features(19)] = 26 channels
238
+ self.unet = UnetGenerator(
239
+ input_nc=opt.tom_input_nc,
240
+ output_nc=opt.tom_output_nc, # [rendered(3) + mask(1)]
241
+ num_downs=6,
242
+ norm_layer=nn.InstanceNorm2d
243
+ )
 
 
 
 
 
 
 
 
 
 
244
 
245
+ def forward(self, x):
246
+ output = self.unet(x)
247
+ p_rendered, m_composite = torch.split(output, [3, 1], dim=1)
248
+ p_rendered = torch.tanh(p_rendered)
249
+ m_composite = torch.sigmoid(m_composite)
250
+ return p_rendered, m_composite
251
 
252
  def save_checkpoint(model, save_path):
253
  if not os.path.exists(os.path.dirname(save_path)):
 
258
  if not os.path.exists(checkpoint_path):
259
  raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
260
 
 
261
  state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
262
+ model.load_state_dict(state_dict, strict=strict)