gaur3009 committed
Commit 5bff76a · verified · 1 Parent(s): a4a6754

Update networks.py

Files changed (1):
  1. networks.py +157 -65

networks.py CHANGED
@@ -8,46 +8,72 @@ import numpy as np
 
 class Options:
     def __init__(self):
-        # Default values
+        # Image dimensions
         self.fine_height = 256
         self.fine_width = 192
+
+        # GMM parameters
         self.grid_size = 5
+        self.input_nc = 22      # For extractionA
+        self.input_nc_B = 1     # For extractionB
+
+        # TOM parameters
+        self.tom_input_nc = 26  # 3(agnostic) + 3(warped) + 1(mask) + 19(features)
+        self.tom_output_nc = 4  # 3(rendered) + 1(composite mask)
+
+        # Training settings
         self.use_dropout = False
-        self.input_nc = 22
-        self.input_nc_B = 1
-        self.tom_input_nc = 26
-        self.tom_output_nc = 4
+        self.norm_layer = nn.BatchNorm2d
 
 def weights_init_normal(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
         init.normal_(m.weight.data, 0.0, 0.02)
     elif classname.find('Linear') != -1:
-        init.normal(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1:
+        init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
         init.normal_(m.weight.data, 1.0, 0.02)
         init.constant_(m.bias.data, 0.0)
 
 def init_weights(net, init_type='normal'):
-    print('initialization method [%s]' % init_type)
+    print(f'initialization method [{init_type}]')
     net.apply(weights_init_normal)
 
 class FeatureExtraction(nn.Module):
-    def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_dropout=False):
+    def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
         super(FeatureExtraction, self).__init__()
-        downconv = nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1)
-        model = [downconv, nn.ReLU(True), norm_layer(ngf)]
+
+        # Build feature extraction layers
+        layers = [
+            nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1),
+            nn.ReLU(True),
+            norm_layer(ngf)
+        ]
+
         for i in range(n_layers):
-            in_ngf = 2**i * ngf if 2**i * ngf < 512 else 512
-            out_ngf = 2**(i+1) * ngf if 2**i * ngf < 512 else 512
-            downconv = nn.Conv2d(in_ngf, out_ngf, kernel_size=4, stride=2, padding=1)
-            model += [downconv, nn.ReLU(True), norm_layer(out_ngf)]
-        model += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(True)]
-        model += [norm_layer(512)]
-        model += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(True)]
-        self.model = nn.Sequential(*model)
+            in_channels = min(2**i * ngf, 512)
+            out_channels = min(2**(i+1) * ngf, 512)
+            layers += [
+                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
+                nn.ReLU(True),
+                norm_layer(out_channels)
+            ]
+
+        # Final processing blocks
+        layers += [
+            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            norm_layer(512),
+            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True)
+        ]
+
+        self.model = nn.Sequential(*layers)
         init_weights(self.model)
 
+    def forward(self, x):
+        return self.model(x)
+
 class FeatureL2Norm(nn.Module):
     def __init__(self):
         super(FeatureL2Norm, self).__init__()
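
The FeatureExtraction rewrite keeps the original downsampling scheme (one stride-2 convolution plus n_layers more, with channels capped at 512) and adds an explicit forward. A minimal shape check, assuming networks.py is importable as below; the dummy tensor just uses the Options defaults:

    import torch
    from networks import FeatureExtraction, Options

    opt = Options()
    extractionA = FeatureExtraction(opt.input_nc)   # 22-channel person representation
    x = torch.randn(1, opt.input_nc, opt.fine_height, opt.fine_width)
    with torch.no_grad():
        feat = extractionA(x)
    print(feat.shape)  # torch.Size([1, 512, 16, 12]): 256x192 halved by four stride-2 convs
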
@@ -83,7 +109,7 @@ class FeatureRegression(nn.Module):
             nn.ReLU(inplace=True),
             nn.Conv2d(128, 64, kernel_size=3, padding=1),
             nn.BatchNorm2d(64),
-            nn.ReLU(inplace=True),
+            nn.ReLU(inplace=True)
         )
         self.linear = nn.Linear(64 * 4 * 3, output_dim)
         self.tanh = nn.Tanh()
@@ -97,18 +123,18 @@ class FeatureRegression(nn.Module):
 class TpsGridGen(nn.Module):
     def __init__(self, out_h=256, out_w=192, grid_size=5):
         super(TpsGridGen, self).__init__()
-        self.out_h, self.out_w = out_h, out_w
+        self.out_h = out_h
+        self.out_w = out_w
         self.grid_size = grid_size
 
         # Create grid
         axis_coords = np.linspace(-1, 1, grid_size)
         self.N = grid_size * grid_size
         P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
-        P_X = torch.FloatTensor(P_X.reshape(-1, 1))
-        P_Y = torch.FloatTensor(P_Y.reshape(-1, 1))
-        self.P_X_base = P_X.clone()
-        self.P_Y_base = P_Y.clone()
-        self.Li = self.compute_L_inverse(P_X, P_Y).unsqueeze(0)
+
+        self.P_X_base = torch.FloatTensor(P_X.reshape(-1, 1))
+        self.P_Y_base = torch.FloatTensor(P_Y.reshape(-1, 1))
+        self.Li = self.compute_L_inverse(self.P_X_base, self.P_Y_base).unsqueeze(0)
 
         # Grid for interpolation
         grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
@@ -117,10 +143,12 @@ class TpsGridGen(nn.Module):
 
     def compute_L_inverse(self, X, Y):
         N = X.size()[0]
-        Xmat, Ymat = X.expand(N, N), Y.expand(N, N)
-        P_dist_squared = torch.pow(Xmat-Xmat.transpose(0, 1), 2) + torch.pow(Ymat-Ymat.transpose(0, 1), 2)
+        Xmat = X.expand(N, N)
+        Ymat = Y.expand(N, N)
+        P_dist_squared = torch.pow(Xmat - Xmat.transpose(0, 1), 2) + torch.pow(Ymat - Ymat.transpose(0, 1), 2)
         P_dist_squared[P_dist_squared == 0] = 1
         K = torch.mul(P_dist_squared, torch.log(P_dist_squared))
+
         O = torch.FloatTensor(N, 1).fill_(1)
         Z = torch.FloatTensor(3, 3).fill_(0)
         P = torch.cat((O, X, Y), 1)
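
For context on compute_L_inverse: it assembles and inverts the standard thin-plate-spline system matrix for the fixed grid_size x grid_size control points. K holds the radial-basis terms U(d^2) = d^2 * log(d^2) between control points (zero distances are patched to 1 so the log term becomes 0), P stacks [1, x, y] per control point, and the O and Z blocks set up the block matrix

    L = [ K    P ]
        [ P^T  0 ]

whose inverse is computed once in __init__ and reused for every batch.
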
@@ -128,22 +156,44 @@ class TpsGridGen(nn.Module):
         return torch.inverse(L)
 
     def forward(self, theta):
-        theta = theta.contiguous()
         batch_size = theta.size()[0]
+        theta = theta.contiguous()
 
-        # Split theta into point coordinates
         Q_X = theta[:, :self.N].contiguous().view(batch_size, self.N, 1)
         Q_Y = theta[:, self.N:].contiguous().view(batch_size, self.N, 1)
         Q_X = Q_X + self.P_X_base.expand_as(Q_X)
         Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)
 
         # Compute weights
-        W_X, W_Y = self.apply_theta(Q_X, Q_Y)
+        W_X = torch.bmm(self.Li[:, :self.N, :self.N].expand(batch_size, self.N, self.N), Q_X)
+        W_Y = torch.bmm(self.Li[:, :self.N, :self.N].expand(batch_size, self.N, self.N), Q_Y)
+
+        # Transform points
+        points_X = self.apply_transformation(self.grid_X, W_X, Q_X)
+        points_Y = self.apply_transformation(self.grid_Y, W_Y, Q_Y)
 
-        # Calculate transformed grid
-        points_X, points_Y = self.transform_points(W_X, W_Y)
         return torch.cat((points_X, points_Y), 3)
 
+    def apply_transformation(self, grid, W, Q):
+        batch_size = W.size()[0]
+        P = torch.cat([
+            self.P_X_base.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4),
+            self.P_Y_base.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4)
+        ], 1)
+
+        delta = grid.expand(batch_size, 1, self.out_h, self.out_w, 1, self.N) - P.expand(batch_size, 1, self.out_h, self.out_w, 1, self.N)
+        dist_squared = torch.pow(delta[:, 0], 2) + torch.pow(delta[:, 1], 2)
+        dist_squared[dist_squared == 0] = 1
+        U = torch.mul(dist_squared, torch.log(dist_squared))
+
+        points = torch.sum(torch.mul(W.expand(batch_size, 1, self.out_h, self.out_w, 1, self.N), U.unsqueeze(4)), 5)
+        points += torch.sum(Q.expand(batch_size, 1, self.out_h, self.out_w, 1, 3) *
+                            torch.cat([grid.new_ones(batch_size, 1, self.out_h, self.out_w, 1),
+                                       grid.expand(batch_size, 1, self.out_h, self.out_w, 1),
+                                       grid.transpose(3, 4).expand(batch_size, 1, self.out_h, self.out_w, 1)], 4), 5)
+
+        return points.squeeze(4)
+
 class GMM(nn.Module):
     def __init__(self, opt=None):
         super(GMM, self).__init__()
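
Background on the new forward/apply_transformation pair: in the standard TPS warp that CP-VTON-style GMMs use, theta predicts offsets of the control points from the base grid (Q = theta + P_base), the radial-basis weights come from the top-left N x N block of the precomputed L-inverse (W = Li[:, :N, :N] @ Q, which is exactly the torch.bmm above), and each output coordinate is an affine term plus a weighted sum of U(d^2) = d^2 * log(d^2) over the N control points:

    f(p) = a0 + ax * x + ay * y + sum_i( w_i * U(|p - P_i|^2) )

apply_transformation vectorizes this evaluation over the dense out_h x out_w grid, so the result can be consumed as a sampling grid (e.g. by F.grid_sample).
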
@@ -167,57 +217,49 @@ class GMM(nn.Module):
         grid = self.gridGen(theta)
         return grid, theta
 
-class UnetGenerator(nn.Module):
-    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.InstanceNorm2d):
-        super(UnetGenerator, self).__init__()
-        unet_block = UnetSkipConnectionBlock(
-            ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
-
-        for _ in range(num_downs - 5):
-            unet_block = UnetSkipConnectionBlock(
-                ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
-
-        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
-        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
-        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
-
-        self.model = UnetSkipConnectionBlock(
-            output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
-
-    def forward(self, input):
-        return self.model(input)
-
 class UnetSkipConnectionBlock(nn.Module):
-    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
-                 outermost=False, innermost=False, norm_layer=nn.InstanceNorm2d):
+    def __init__(self, outer_nc, inner_nc, input_nc=None,
+                 submodule=None, outermost=False, innermost=False,
+                 norm_layer=nn.InstanceNorm2d, use_dropout=False):
         super(UnetSkipConnectionBlock, self).__init__()
         self.outermost = outermost
         use_bias = norm_layer == nn.InstanceNorm2d
-
+
         if input_nc is None:
             input_nc = outer_nc
 
-        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
+        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+                             stride=2, padding=1, bias=use_bias)
         downrelu = nn.LeakyReLU(0.2, True)
         downnorm = norm_layer(inner_nc)
         uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)
 
         if outermost:
-            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1)
             down = [downconv]
             up = [uprelu, upconv, nn.Tanh()]
             model = down + [submodule] + up
         elif innermost:
-            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
+            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, bias=use_bias)
             down = [downrelu, downconv]
             up = [uprelu, upconv, upnorm]
             model = down + up
         else:
-            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, bias=use_bias)
             down = [downrelu, downconv, downnorm]
             up = [uprelu, upconv, upnorm]
-            model = down + [submodule] + up
+
+            if use_dropout:
+                model = down + [submodule] + up + [nn.Dropout(0.5)]
+            else:
+                model = down + [submodule] + up
 
         self.model = nn.Sequential(*model)
 
@@ -227,17 +269,47 @@ class UnetSkipConnectionBlock(nn.Module):
         else:
             return torch.cat([x, self.model(x)], 1)
 
+class UnetGenerator(nn.Module):
+    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
+                 norm_layer=nn.InstanceNorm2d, use_dropout=False):
+        super(UnetGenerator, self).__init__()
+
+        # Build UNet structure
+        unet_block = UnetSkipConnectionBlock(
+            ngf * 8, ngf * 8, input_nc=None, submodule=None,
+            norm_layer=norm_layer, innermost=True)
+
+        for i in range(num_downs - 5):
+            unet_block = UnetSkipConnectionBlock(
+                ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
+                norm_layer=norm_layer, use_dropout=use_dropout)
+
+        unet_block = UnetSkipConnectionBlock(
+            ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
+            norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(
+            ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
+            norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(
+            ngf, ngf * 2, input_nc=None, submodule=unet_block,
+            norm_layer=norm_layer)
+
+        self.model = UnetSkipConnectionBlock(
+            output_nc, ngf, input_nc=input_nc, submodule=unet_block,
+            outermost=True, norm_layer=norm_layer)
+
+    def forward(self, input):
+        return self.model(input)
+
 class TOM(nn.Module):
-    """ Try-On Module """
     def __init__(self, opt=None):
         super(TOM, self).__init__()
         if opt is None:
             opt = Options()
 
-        # Input: [agnostic(3) + warped_design(3) + warped_mask(1) + features(19)] = 26 channels
         self.unet = UnetGenerator(
             input_nc=opt.tom_input_nc,
-            output_nc=opt.tom_output_nc,  # [rendered(3) + mask(1)]
+            output_nc=opt.tom_output_nc,
             num_downs=6,
             norm_layer=nn.InstanceNorm2d
         )
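
Because every non-outermost UnetSkipConnectionBlock returns torch.cat([x, self.model(x)], 1), channel counts double across each skip connection, which is why the mid-level up-convolutions take inner_nc * 2 input channels. A minimal shape check of the re-added UnetGenerator with the TOM settings, assuming networks.py imports as below (the dummy input and the split are illustrative only):

    import torch
    import torch.nn as nn
    from networks import UnetGenerator, Options

    opt = Options()
    unet = UnetGenerator(input_nc=opt.tom_input_nc, output_nc=opt.tom_output_nc,
                         num_downs=6, norm_layer=nn.InstanceNorm2d)
    x = torch.randn(1, opt.tom_input_nc, opt.fine_height, opt.fine_width)  # 26 x 256 x 192
    with torch.no_grad():
        out = unet(x)
    print(out.shape)  # torch.Size([1, 4, 256, 192])
    p_rendered, m_composite = torch.split(out, [3, 1], dim=1)  # rendered RGB + composite mask
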
@@ -259,4 +331,24 @@ def load_checkpoint(model, checkpoint_path, strict=True):
         raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
 
     state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
-    model.load_state_dict(state_dict, strict=strict)
+
+    # Filter out unexpected keys
+    model_state_dict = model.state_dict()
+    filtered_state_dict = {k: v for k, v in state_dict.items()
+                           if k in model_state_dict and v.size() == model_state_dict[k].size()}
+
+    # Load filtered state dict
+    model.load_state_dict(filtered_state_dict, strict=strict)
+
+    # Print warnings
+    missing = [k for k in model_state_dict if k not in state_dict]
+    unexpected = [k for k in state_dict if k not in model_state_dict]
+    size_mismatch = [k for k in state_dict
+                     if k in model_state_dict and state_dict[k].size() != model_state_dict[k].size()]
+
+    if missing:
+        print(f"Missing keys: {missing}")
+    if unexpected:
+        print(f"Unexpected keys: {unexpected}")
+    if size_mismatch:
+        print(f"Size mismatch: {size_mismatch}")