gaur3009 commited on
Commit
7335794
·
verified ·
1 Parent(s): a17268d

Update networks.py

Browse files
Files changed (1) hide show
  1. networks.py +15 -10
networks.py CHANGED
@@ -6,6 +6,15 @@ from torchvision import models
6
  import os
7
  import numpy as np
8
 
 
 
 
 
 
 
 
 
 
9
  def weights_init_normal(m):
10
  classname = m.__class__.__name__
11
  if classname.find('Conv') != -1:
@@ -143,6 +152,7 @@ class TpsGridGen(nn.Module):
143
  super(TpsGridGen, self).__init__()
144
  self.out_h, self.out_w = out_h, out_w
145
  self.reg_factor = reg_factor
 
146
 
147
  # create grid in numpy
148
  self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
@@ -173,7 +183,6 @@ class TpsGridGen(nn.Module):
173
  def forward(self, theta):
174
  warped_grid = self.apply_transformation(
175
  theta, torch.cat((self.grid_X, self.grid_Y), 3))
176
-
177
  return warped_grid
178
 
179
  def compute_L_inverse(self, X, Y):
@@ -285,10 +294,6 @@ class TpsGridGen(nn.Module):
285
 
286
  return torch.cat((points_X_prime, points_Y_prime), 3)
287
 
288
- # Defines the Unet generator.
289
- # |num_downs|: number of downsamplings in UNet. For example,
290
- # if |num_downs| == 7, image of size 128x128 will become of size 1x1
291
- # at the bottleneck
292
  class UnetGenerator(nn.Module):
293
  def __init__(self, input_nc, output_nc, num_downs, ngf=64,
294
  norm_layer=nn.BatchNorm2d, use_dropout=False):
@@ -313,9 +318,6 @@ class UnetGenerator(nn.Module):
313
  def forward(self, input):
314
  return self.model(input)
315
 
316
- # Defines the submodule with skip connection.
317
- # X -------------------identity---------------------- X
318
- # |-- downsampling -- |submodule| -- upsampling --|
319
  class UnetSkipConnectionBlock(nn.Module):
320
  def __init__(self, outer_nc, inner_nc, input_nc=None,
321
  submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
@@ -464,9 +466,12 @@ class GicLoss(nn.Module):
464
  class GMM(nn.Module):
465
  """ Geometric Matching Module
466
  """
467
-
468
- def __init__(self, opt):
469
  super(GMM, self).__init__()
 
 
 
 
470
  self.extractionA = FeatureExtraction(
471
  22, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
472
  self.extractionB = FeatureExtraction(
 
6
  import os
7
  import numpy as np
8
 
9
+ # Configuration class to hold all parameters
10
+ class Options:
11
+ def __init__(self):
12
+ # Default values
13
+ self.fine_height = 256
14
+ self.fine_width = 192
15
+ self.grid_size = 3
16
+ self.use_dropout = False
17
+
18
  def weights_init_normal(m):
19
  classname = m.__class__.__name__
20
  if classname.find('Conv') != -1:
 
152
  super(TpsGridGen, self).__init__()
153
  self.out_h, self.out_w = out_h, out_w
154
  self.reg_factor = reg_factor
155
+ self.grid_size = grid_size
156
 
157
  # create grid in numpy
158
  self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
 
183
  def forward(self, theta):
184
  warped_grid = self.apply_transformation(
185
  theta, torch.cat((self.grid_X, self.grid_Y), 3))
 
186
  return warped_grid
187
 
188
  def compute_L_inverse(self, X, Y):
 
294
 
295
  return torch.cat((points_X_prime, points_Y_prime), 3)
296
 
 
 
 
 
297
  class UnetGenerator(nn.Module):
298
  def __init__(self, input_nc, output_nc, num_downs, ngf=64,
299
  norm_layer=nn.BatchNorm2d, use_dropout=False):
 
318
  def forward(self, input):
319
  return self.model(input)
320
 
 
 
 
321
  class UnetSkipConnectionBlock(nn.Module):
322
  def __init__(self, outer_nc, inner_nc, input_nc=None,
323
  submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
 
466
  class GMM(nn.Module):
467
  """ Geometric Matching Module
468
  """
469
+ def __init__(self, opt=None):
 
470
  super(GMM, self).__init__()
471
+ # Initialize default options if none provided
472
+ if opt is None:
473
+ opt = Options()
474
+
475
  self.extractionA = FeatureExtraction(
476
  22, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
477
  self.extractionB = FeatureExtraction(