Spaces:
Sleeping
Sleeping
Update networks.py
Browse files- networks.py +15 -10
networks.py
CHANGED
@@ -6,6 +6,15 @@ from torchvision import models
|
|
6 |
import os
|
7 |
import numpy as np
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
def weights_init_normal(m):
|
10 |
classname = m.__class__.__name__
|
11 |
if classname.find('Conv') != -1:
|
@@ -143,6 +152,7 @@ class TpsGridGen(nn.Module):
|
|
143 |
super(TpsGridGen, self).__init__()
|
144 |
self.out_h, self.out_w = out_h, out_w
|
145 |
self.reg_factor = reg_factor
|
|
|
146 |
|
147 |
# create grid in numpy
|
148 |
self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
|
@@ -173,7 +183,6 @@ class TpsGridGen(nn.Module):
|
|
173 |
def forward(self, theta):
|
174 |
warped_grid = self.apply_transformation(
|
175 |
theta, torch.cat((self.grid_X, self.grid_Y), 3))
|
176 |
-
|
177 |
return warped_grid
|
178 |
|
179 |
def compute_L_inverse(self, X, Y):
|
@@ -285,10 +294,6 @@ class TpsGridGen(nn.Module):
|
|
285 |
|
286 |
return torch.cat((points_X_prime, points_Y_prime), 3)
|
287 |
|
288 |
-
# Defines the Unet generator.
|
289 |
-
# |num_downs|: number of downsamplings in UNet. For example,
|
290 |
-
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
|
291 |
-
# at the bottleneck
|
292 |
class UnetGenerator(nn.Module):
|
293 |
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
294 |
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
@@ -313,9 +318,6 @@ class UnetGenerator(nn.Module):
|
|
313 |
def forward(self, input):
|
314 |
return self.model(input)
|
315 |
|
316 |
-
# Defines the submodule with skip connection.
|
317 |
-
# X -------------------identity---------------------- X
|
318 |
-
# |-- downsampling -- |submodule| -- upsampling --|
|
319 |
class UnetSkipConnectionBlock(nn.Module):
|
320 |
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
321 |
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
@@ -464,9 +466,12 @@ class GicLoss(nn.Module):
|
|
464 |
class GMM(nn.Module):
|
465 |
""" Geometric Matching Module
|
466 |
"""
|
467 |
-
|
468 |
-
def __init__(self, opt):
|
469 |
super(GMM, self).__init__()
|
|
|
|
|
|
|
|
|
470 |
self.extractionA = FeatureExtraction(
|
471 |
22, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
|
472 |
self.extractionB = FeatureExtraction(
|
|
|
6 |
import os
|
7 |
import numpy as np
|
8 |
|
9 |
+
# Configuration class to hold all parameters
|
10 |
+
class Options:
|
11 |
+
def __init__(self):
|
12 |
+
# Default values
|
13 |
+
self.fine_height = 256
|
14 |
+
self.fine_width = 192
|
15 |
+
self.grid_size = 3
|
16 |
+
self.use_dropout = False
|
17 |
+
|
18 |
def weights_init_normal(m):
|
19 |
classname = m.__class__.__name__
|
20 |
if classname.find('Conv') != -1:
|
|
|
152 |
super(TpsGridGen, self).__init__()
|
153 |
self.out_h, self.out_w = out_h, out_w
|
154 |
self.reg_factor = reg_factor
|
155 |
+
self.grid_size = grid_size
|
156 |
|
157 |
# create grid in numpy
|
158 |
self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
|
|
|
183 |
def forward(self, theta):
|
184 |
warped_grid = self.apply_transformation(
|
185 |
theta, torch.cat((self.grid_X, self.grid_Y), 3))
|
|
|
186 |
return warped_grid
|
187 |
|
188 |
def compute_L_inverse(self, X, Y):
|
|
|
294 |
|
295 |
return torch.cat((points_X_prime, points_Y_prime), 3)
|
296 |
|
|
|
|
|
|
|
|
|
297 |
class UnetGenerator(nn.Module):
|
298 |
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
299 |
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
|
|
318 |
def forward(self, input):
|
319 |
return self.model(input)
|
320 |
|
|
|
|
|
|
|
321 |
class UnetSkipConnectionBlock(nn.Module):
|
322 |
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
323 |
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
|
|
466 |
class GMM(nn.Module):
|
467 |
""" Geometric Matching Module
|
468 |
"""
|
469 |
+
def __init__(self, opt=None):
|
|
|
470 |
super(GMM, self).__init__()
|
471 |
+
# Initialize default options if none provided
|
472 |
+
if opt is None:
|
473 |
+
opt = Options()
|
474 |
+
|
475 |
self.extractionA = FeatureExtraction(
|
476 |
22, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
|
477 |
self.extractionB = FeatureExtraction(
|