Update networks.py

networks.py  CHANGED  (+157 -65)
@@ -8,46 +8,72 @@ import numpy as np
 
 class Options:
     def __init__(self):
-        #
+        # Image dimensions
         self.fine_height = 256
         self.fine_width = 192
+
+        # GMM parameters
         self.grid_size = 5
+        self.input_nc = 22 # For extractionA
+        self.input_nc_B = 1 # For extractionB
+
+        # TOM parameters
+        self.tom_input_nc = 26 # 3(agnostic) + 3(warped) + 1(mask) + 19(features)
+        self.tom_output_nc = 4 # 3(rendered) + 1(composite mask)
+
+        # Training settings
         self.use_dropout = False
-        self.
-        self.input_nc_B = 1
-        self.tom_input_nc = 26
-        self.tom_output_nc = 4
+        self.norm_layer = nn.BatchNorm2d
 
 def weights_init_normal(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
         init.normal_(m.weight.data, 0.0, 0.02)
     elif classname.find('Linear') != -1:
-        init.
-    elif classname.find('
+        init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
         init.normal_(m.weight.data, 1.0, 0.02)
         init.constant_(m.bias.data, 0.0)
 
 def init_weights(net, init_type='normal'):
-    print('initialization method [
+    print(f'initialization method [{init_type}]')
     net.apply(weights_init_normal)
 
 class FeatureExtraction(nn.Module):
-    def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d
+    def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
         super(FeatureExtraction, self).__init__()
-
-
+
+        # Build feature extraction layers
+        layers = [
+            nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1),
+            nn.ReLU(True),
+            norm_layer(ngf)
+        ]
+
         for i in range(n_layers):
-
-
-
-
-
-
-
-
+            in_channels = min(2**i * ngf, 512)
+            out_channels = min(2**(i+1) * ngf, 512)
+            layers += [
+                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
+                nn.ReLU(True),
+                norm_layer(out_channels)
+            ]
+
+        # Final processing blocks
+        layers += [
+            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            norm_layer(512),
+            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True)
+        ]
+
+        self.model = nn.Sequential(*layers)
         init_weights(self.model)
 
+    def forward(self, x):
+        return self.model(x)
+
 class FeatureL2Norm(nn.Module):
     def __init__(self):
         super(FeatureL2Norm, self).__init__()
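With the rebuilt FeatureExtraction above, the stride-2 stem plus three stride-2 blocks (n_layers=3) downsample by 16x and top out at 512 channels. A quick shape sanity check, a sketch assuming this file is importable as networks:

    import torch
    from networks import FeatureExtraction, Options

    opt = Options()
    extractor = FeatureExtraction(opt.input_nc)   # 22-channel person representation
    x = torch.randn(1, opt.input_nc, opt.fine_height, opt.fine_width)
    with torch.no_grad():
        feat = extractor(x)
    print(feat.shape)   # torch.Size([1, 512, 16, 12]): 256 / 2**4 by 192 / 2**4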
@@ -83,7 +109,7 @@ class FeatureRegression(nn.Module):
             nn.ReLU(inplace=True),
             nn.Conv2d(128, 64, kernel_size=3, padding=1),
             nn.BatchNorm2d(64),
-            nn.ReLU(inplace=True)
+            nn.ReLU(inplace=True)
         )
         self.linear = nn.Linear(64 * 4 * 3, output_dim)
         self.tanh = nn.Tanh()
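The unchanged Linear layer pins down the geometry here: 64 channels on a 4x3 map feed the regressor, which for 256x192 inputs corresponds to six stride-2 halvings upstream of the flatten (an inference from the visible shapes, not something this hunk changes):

    # 64 channels * 4 * 3 spatial positions = 768 inputs to nn.Linear
    assert 64 * 4 * 3 == 768
    # 4 x 3 is 256 x 192 after six halvings
    assert 256 // 2**6 == 4 and 192 // 2**6 == 3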
@@ -97,18 +123,18 @@ class FeatureRegression(nn.Module):
 class TpsGridGen(nn.Module):
     def __init__(self, out_h=256, out_w=192, grid_size=5):
         super(TpsGridGen, self).__init__()
-        self.out_h
+        self.out_h = out_h
+        self.out_w = out_w
         self.grid_size = grid_size
 
         # Create grid
         axis_coords = np.linspace(-1, 1, grid_size)
         self.N = grid_size * grid_size
         P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
-
-
-        self.
-        self.
-        self.Li = self.compute_L_inverse(P_X, P_Y).unsqueeze(0)
+
+        self.P_X_base = torch.FloatTensor(P_X.reshape(-1, 1))
+        self.P_Y_base = torch.FloatTensor(P_Y.reshape(-1, 1))
+        self.Li = self.compute_L_inverse(self.P_X_base, self.P_Y_base).unsqueeze(0)
 
         # Grid for interpolation
         grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
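TpsGridGen now stores the control points as P_X_base and P_Y_base tensors of shape (N, 1) with N = grid_size**2. Whatever regresses theta must therefore emit 2*N values, x-offsets first (see the split in forward below):

    grid_size = 5
    N = grid_size * grid_size    # 25 control points on the [-1, 1] x [-1, 1] grid
    theta_dim = 2 * N            # 50: theta[:, :N] are x-offsets, theta[:, N:] y-offsets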
@@ -117,10 +143,12 @@ class TpsGridGen(nn.Module):
 
     def compute_L_inverse(self, X, Y):
         N = X.size()[0]
-        Xmat
-
+        Xmat = X.expand(N, N)
+        Ymat = Y.expand(N, N)
+        P_dist_squared = torch.pow(Xmat - Xmat.transpose(0, 1), 2) + torch.pow(Ymat - Ymat.transpose(0, 1), 2)
         P_dist_squared[P_dist_squared == 0] = 1
         K = torch.mul(P_dist_squared, torch.log(P_dist_squared))
+
         O = torch.FloatTensor(N, 1).fill_(1)
         Z = torch.FloatTensor(3, 3).fill_(0)
         P = torch.cat((O, X, Y), 1)
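For reference, compute_L_inverse fills in the standard thin-plate-spline system: K holds the radial kernel between control points and P = [1 | X | Y]; the assembly of L and its inversion happen on the lines closing this method. In the usual notation:

    L = \begin{bmatrix} K & P \\ P^{\top} & 0_{3\times 3} \end{bmatrix},
    \qquad K_{ij} = d_{ij}^{2}\,\log d_{ij}^{2},
    \qquad P = \begin{bmatrix} \mathbf{1} & X & Y \end{bmatrix},

so the TPS coefficients for target offsets q are [w; a] = L^{-1} [q; 0]. Note that d^2 log d^2 = 2 d^2 log d, the classic kernel up to a factor of 2; the factor only rescales the weights and leaves the resulting warp unchanged.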
@@ -128,22 +156,44 @@ class TpsGridGen(nn.Module):
         return torch.inverse(L)
 
     def forward(self, theta):
-        theta = theta.contiguous()
         batch_size = theta.size()[0]
+        theta = theta.contiguous()
 
-        # Split theta into point coordinates
         Q_X = theta[:, :self.N].contiguous().view(batch_size, self.N, 1)
         Q_Y = theta[:, self.N:].contiguous().view(batch_size, self.N, 1)
         Q_X = Q_X + self.P_X_base.expand_as(Q_X)
         Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)
 
         # Compute weights
-        W_X
+        W_X = torch.bmm(self.Li[:, :self.N, :self.N].expand(batch_size, self.N, self.N), Q_X)
+        W_Y = torch.bmm(self.Li[:, :self.N, :self.N].expand(batch_size, self.N, self.N), Q_Y)
+
+        # Transform points
+        points_X = self.apply_transformation(self.grid_X, W_X, Q_X)
+        points_Y = self.apply_transformation(self.grid_Y, W_Y, Q_Y)
 
-        # Calculate transformed grid
-        points_X, points_Y = self.transform_points(W_X, W_Y)
         return torch.cat((points_X, points_Y), 3)
 
+    def apply_transformation(self, grid, W, Q):
+        batch_size = W.size()[0]
+        P = torch.cat([
+            self.P_X_base.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4),
+            self.P_Y_base.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4)
+        ], 1)
+
+        delta = grid.expand(batch_size, 1, self.out_h, self.out_w, 1, self.N) - P.expand(batch_size, 1, self.out_h, self.out_w, 1, self.N)
+        dist_squared = torch.pow(delta[:,0], 2) + torch.pow(delta[:,1], 2)
+        dist_squared[dist_squared == 0] = 1
+        U = torch.mul(dist_squared, torch.log(dist_squared))
+
+        points = torch.sum(torch.mul(W.expand(batch_size, 1, self.out_h, self.out_w, 1, self.N), U.unsqueeze(4)), 5)
+        points += torch.sum(Q.expand(batch_size, 1, self.out_h, self.out_w, 1, 3) *
+                            torch.cat([grid.new_ones(batch_size, 1, self.out_h, self.out_w, 1),
+                                       grid.expand(batch_size, 1, self.out_h, self.out_w, 1),
+                                       grid.transpose(3,4).expand(batch_size, 1, self.out_h, self.out_w, 1)], 4), 5)
+
+        return points.squeeze(4)
+
 class GMM(nn.Module):
     def __init__(self, opt=None):
         super(GMM, self).__init__()
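Downstream consumption of the grid is outside this diff, but the convention matters: forward stacks points_X and points_Y on the last axis so the result can be fed straight to grid_sample. A minimal sketch, assuming the interpolation grids are set up so that forward returns a (B, out_h, out_w, 2) tensor with values in [-1, 1]:

    import torch
    import torch.nn.functional as F

    tps = TpsGridGen(out_h=256, out_w=192, grid_size=5)
    theta = torch.zeros(1, 2 * tps.N)   # zero offsets: control points stay on the base grid
    grid = tps(theta)                   # assumed (1, 256, 192, 2)

    cloth = torch.randn(1, 3, 256, 192)
    warped = F.grid_sample(cloth, grid, padding_mode='border', align_corners=True)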
@@ -167,57 +217,49 @@ class GMM(nn.Module):
         grid = self.gridGen(theta)
         return grid, theta
 
-class UnetGenerator(nn.Module):
-    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.InstanceNorm2d):
-        super(UnetGenerator, self).__init__()
-        unet_block = UnetSkipConnectionBlock(
-            ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
-
-        for _ in range(num_downs - 5):
-            unet_block = UnetSkipConnectionBlock(
-                ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
-
-        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
-        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
-        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
-
-        self.model = UnetSkipConnectionBlock(
-            output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
-
-    def forward(self, input):
-        return self.model(input)
-
 class UnetSkipConnectionBlock(nn.Module):
-    def __init__(self, outer_nc, inner_nc, input_nc=None,
-                 outermost=False, innermost=False,
+    def __init__(self, outer_nc, inner_nc, input_nc=None,
+                 submodule=None, outermost=False, innermost=False,
+                 norm_layer=nn.InstanceNorm2d, use_dropout=False):
         super(UnetSkipConnectionBlock, self).__init__()
         self.outermost = outermost
         use_bias = norm_layer == nn.InstanceNorm2d
-
+
         if input_nc is None:
             input_nc = outer_nc
 
-        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+                             stride=2, padding=1, bias=use_bias)
         downrelu = nn.LeakyReLU(0.2, True)
         downnorm = norm_layer(inner_nc)
         uprelu = nn.ReLU(True)
         upnorm = norm_layer(outer_nc)
 
         if outermost:
-            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1)
             down = [downconv]
             up = [uprelu, upconv, nn.Tanh()]
             model = down + [submodule] + up
         elif innermost:
-            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, bias=use_bias)
             down = [downrelu, downconv]
             up = [uprelu, upconv, upnorm]
             model = down + up
         else:
-            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, bias=use_bias)
             down = [downrelu, downconv, downnorm]
             up = [uprelu, upconv, upnorm]
-
+
+            if use_dropout:
+                model = down + [submodule] + up + [nn.Dropout(0.5)]
+            else:
+                model = down + [submodule] + up
 
         self.model = nn.Sequential(*model)
 
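The block's forward (context at the top of the next hunk) returns torch.cat([x, self.model(x)], 1) for every non-outermost block, so channel counts double across each skip connection. A small check with the new signature:

    import torch
    import torch.nn as nn

    blk = UnetSkipConnectionBlock(512, 512, innermost=True, norm_layer=nn.InstanceNorm2d)
    x = torch.randn(1, 512, 8, 6)
    y = blk(x)
    print(y.shape)   # torch.Size([1, 1024, 8, 6]): input concatenated with the block's output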
@@ -227,17 +269,47 @@ class UnetSkipConnectionBlock(nn.Module):
         else:
             return torch.cat([x, self.model(x)], 1)
 
+class UnetGenerator(nn.Module):
+    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
+                 norm_layer=nn.InstanceNorm2d, use_dropout=False):
+        super(UnetGenerator, self).__init__()
+
+        # Build UNet structure
+        unet_block = UnetSkipConnectionBlock(
+            ngf * 8, ngf * 8, input_nc=None, submodule=None,
+            norm_layer=norm_layer, innermost=True)
+
+        for i in range(num_downs - 5):
+            unet_block = UnetSkipConnectionBlock(
+                ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
+                norm_layer=norm_layer, use_dropout=use_dropout)
+
+        unet_block = UnetSkipConnectionBlock(
+            ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
+            norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(
+            ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
+            norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(
+            ngf, ngf * 2, input_nc=None, submodule=unet_block,
+            norm_layer=norm_layer)
+
+        self.model = UnetSkipConnectionBlock(
+            output_nc, ngf, input_nc=input_nc, submodule=unet_block,
+            outermost=True, norm_layer=norm_layer)
+
+    def forward(self, input):
+        return self.model(input)
+
 class TOM(nn.Module):
-    """ Try-On Module """
     def __init__(self, opt=None):
         super(TOM, self).__init__()
         if opt is None:
             opt = Options()
 
-        # Input: [agnostic(3) + warped_design(3) + warped_mask(1) + features(19)] = 26 channels
         self.unet = UnetGenerator(
             input_nc=opt.tom_input_nc,
-            output_nc=opt.tom_output_nc,
+            output_nc=opt.tom_output_nc,
             num_downs=6,
             norm_layer=nn.InstanceNorm2d
         )
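TOM now wires its 26-channel input through the relocated UnetGenerator; with num_downs=6, a 256x192 input bottoms out at 4x3 before the decoder mirrors back up. A shape sketch (TOM.forward is outside this diff, so the unet is called directly):

    import torch

    tom = TOM()                           # UnetGenerator(26 -> 4, num_downs=6)
    inp = torch.randn(1, 26, 256, 192)    # 3 agnostic + 3 warped + 1 mask + 19 features
    out = tom.unet(inp)                   # (1, 4, 256, 192), already Tanh-squashed
    p_rendered, m_composite = torch.split(out, [3, 1], dim=1)

The usual CP-VTON-style composition (an assumption about how callers use the 3+1 split, not code from this file) would be p_tryon = warped * m + p_rendered * (1 - m), with m = (m_composite + 1) / 2 mapping the Tanh output into [0, 1].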
@@ -259,4 +331,24 @@ def load_checkpoint(model, checkpoint_path, strict=True):
         raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
 
     state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
-
+
+    # Filter out unexpected keys
+    model_state_dict = model.state_dict()
+    filtered_state_dict = {k: v for k, v in state_dict.items()
+                           if k in model_state_dict and v.size() == model_state_dict[k].size()}
+
+    # Load filtered state dict
+    model.load_state_dict(filtered_state_dict, strict=strict)
+
+    # Print warnings
+    missing = [k for k in model_state_dict if k not in state_dict]
+    unexpected = [k for k in state_dict if k not in model_state_dict]
+    size_mismatch = [k for k in state_dict
+                     if k in model_state_dict and state_dict[k].size() != model_state_dict[k].size()]
+
+    if missing:
+        print(f"Missing keys: {missing}")
+    if unexpected:
+        print(f"Unexpected keys: {unexpected}")
+    if size_mismatch:
+        print(f"Size mismatch: {size_mismatch}")
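One caveat on the new filtering: load_state_dict with strict=True still raises if the filtered dict is missing keys the model expects, so dropping mismatched tensors only really pays off together with strict=False. A usage sketch (the checkpoint path is hypothetical):

    model = TOM()
    # strict=False tolerates the keys removed by the size filter above
    load_checkpoint(model, 'checkpoints/tom_final.pth', strict=False)  # hypothetical path
    # anything dropped or absent is reported via the printed warnings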