""" | |
Copyright (C) 2019 NVIDIA Corporation. All rights reserved. | |
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). | |
""" | |
import os
import copy  # deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import ResnetBlock as ResnetBlock
from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
# import data  # needed by SubAddGenerator/ContrasiveGenerator; only resolvable when running from the repository root

class SPADEGenerator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
        parser.add_argument('--num_upsampling_layers',
                            choices=('normal', 'more', 'most'), default='normal',
                            help="If 'more', adds an upsampling layer between the two middle resnet blocks. "
                                 "If 'most', also adds one more upsampling + resnet layer at the end of the generator")
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh, self.scale_ratio = self.compute_latent_vector_size(opt)

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # a downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        # 20200211 test 4x with only 3 stages
        self.ups = nn.ModuleList([
            SPADEResnetBlock(16 * nf, 8 * nf, opt),
            SPADEResnetBlock(8 * nf, 4 * nf, opt),
            SPADEResnetBlock(4 * nf, 2 * nf, opt),
            SPADEResnetBlock(2 * nf, 1 * nf, opt)  # here
        ])

        self.to_rgbs = nn.ModuleList([
            nn.Conv2d(8 * nf, 3, 3, padding=1),
            nn.Conv2d(4 * nf, 3, 3, padding=1),
            nn.Conv2d(2 * nf, 3, 3, padding=1),
            nn.Conv2d(1 * nf, 3, 3, padding=1)  # here
        ])

        self.up = nn.Upsample(scale_factor=2)

    # 20200309 interface for flexible encoder design
    # and mid-level loss control!
    # For the basic network, it's just a 16x downsampling
    def encode(self, input):
        h, w = input.size()[-2:]
        sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
        x = F.interpolate(input, size=(sh, sw))
        return self.fc(x)  # 20200310: Merge fc into encoder

    def compute_latent_vector_size(self, opt):
        if opt.num_upsampling_layers == 'normal':
            num_up_layers = 5
        elif opt.num_upsampling_layers == 'more':
            num_up_layers = 6
        elif opt.num_upsampling_layers == 'most':
            num_up_layers = 7
        else:
            raise ValueError('opt.num_upsampling_layers [%s] not recognized' %
                             opt.num_upsampling_layers)

        # 20200211 Yang Lingbo with respect to phase
        scale_ratio = num_up_layers
        # scale_ratio = 4  # here
        sw = opt.crop_size // (2**num_up_layers)
        sh = round(sw / opt.aspect_ratio)
        return sw, sh, scale_ratio
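
    # Worked example (an illustration, not part of the original code): with
    # opt.crop_size = 512, opt.aspect_ratio = 1.0 and num_upsampling_layers = 'normal',
    # num_up_layers = 5, so sw = 512 // 2**5 = 16, sh = round(16 / 1.0) = 16 and
    # scale_ratio = 5, i.e. encode() shrinks the input by 2**5 = 32x before the
    # SPADE blocks upsample it back.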

    def forward(self, input, seg=None):
        '''
        20200307: Dangerous Change
        Add separable forward to allow different
        input and segmentation maps...
        To return to the original behavior, simply add
        seg = input at the beginning, and disable the seg parameter.

        20200308: A more elegant solution:
            @ Allow forward to take default parameters.
            @ Allow customizable input encoding

        20200310: Merge fc into encode, since the encoder directly outputs the processed feature map.

        [TODO] @ Allow customizable segmap encoding?
        '''
        if seg is None:
            seg = input  # Interesting change...

        # For the basic generator, 16x downsampling.
        # 20200310: Merge fc into encoder
        x = self.encode(input)
        # print(x.shape, input.shape, seg.shape)

        x = self.head_0(x, seg)
        x = self.up(x)
        x = self.G_middle_0(x, seg)
        x = self.G_middle_1(x, seg)

        if self.opt.is_test:
            phase = len(self.to_rgbs)
        else:
            phase = self.opt.train_phase + 1

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, seg)

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)
        return x

    def mixed_guidance_forward(self, input, seg=None, n=0, mode='progressive'):
        '''
        mixed_forward: input and seg are different images.
        For the first n levels (including the encoder)
        we use input; for the rest we use seg.

        If mode == 'progressive', the output is like: AAABBB
        If mode == 'one_plug',    the output is like: AAABAA
        If mode == 'one_ablate',  the output is like: BBBABB
        '''
        if seg is None:
            return self.forward(input)

        if self.opt.is_test:
            phase = len(self.to_rgbs)
        else:
            phase = self.opt.train_phase + 1

        if mode == 'progressive':
            n = max(min(n, 4 + phase), 0)
            guide_list = [input] * n + [seg] * (4 + phase - n)
        elif mode == 'one_plug':
            n = max(min(n, 4 + phase - 1), 0)
            guide_list = [seg] * (4 + phase)
            guide_list[n] = input
        elif mode == 'one_ablate':
            if n > 3 + phase:
                return self.forward(input)
            guide_list = [input] * (4 + phase)
            guide_list[n] = seg
        else:
            raise ValueError('mode [%s] not recognized' % mode)

        x = self.encode(guide_list[0])
        x = self.head_0(x, guide_list[1])
        x = self.up(x)
        x = self.G_middle_0(x, guide_list[2])
        x = self.G_middle_1(x, guide_list[3])

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, guide_list[4 + i])

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)
        return x
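

# Usage sketch for SPADEGenerator (an illustration, not part of the original
# code; the option fields below are assumptions standing in for the repo's real
# option parser, which also supplies normalization settings such as norm_G):
#
#     from types import SimpleNamespace
#     opt = SimpleNamespace(ngf=64, use_vae=False, semantic_nc=3,
#                           num_upsampling_layers='normal', crop_size=512,
#                           aspect_ratio=1.0, is_test=True, ...)
#     netG = SPADEGenerator(opt)
#     fake = netG(torch.randn(1, 3, 512, 512))  # -> (1, 3, 512, 512)
#
# With is_test=True all 4 up-blocks run: encode() shrinks by 32x, one upsample
# follows head_0, and the 4 up-blocks recover the remaining 16x, so the output
# matches the input resolution.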


class HiFaceGANGenerator(SPADEGenerator):
    def __init__(self, opt):
        super().__init__(opt)
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh, self.scale_ratio = self.compute_latent_vector_size(opt)

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # a downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)
        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)

        # 20200211 test 4x with only 3 stages
        self.ups = nn.ModuleList([
            SPADEResnetBlock(16 * nf, 8 * nf, opt, 8 * nf),
            SPADEResnetBlock(8 * nf, 4 * nf, opt, 4 * nf),
            SPADEResnetBlock(4 * nf, 2 * nf, opt, 2 * nf),
            SPADEResnetBlock(2 * nf, 1 * nf, opt, 1 * nf)  # here
        ])

        self.to_rgbs = nn.ModuleList([
            nn.Conv2d(8 * nf, 3, 3, padding=1),
            nn.Conv2d(4 * nf, 3, 3, padding=1),
            nn.Conv2d(2 * nf, 3, 3, padding=1),
            nn.Conv2d(1 * nf, 3, 3, padding=1)  # here
        ])

        self.up = nn.Upsample(scale_factor=2)
        self.encoder = ContentAdaptiveSuppresor(opt, self.sw, self.sh, self.scale_ratio)

    def nested_encode(self, x):
        return self.encoder(x)

    def forward(self, input):
        xs = self.nested_encode(input)
        x = self.encode(input)
        # print([_x.shape for _x in xs])
        # print(x.shape)
        # print(self.head_0)
        x = self.head_0(x, xs[0])
        x = self.up(x)
        x = self.G_middle_0(x, xs[1])
        x = self.G_middle_1(x, xs[1])

        if self.opt.is_test:
            phase = len(self.to_rgbs)
        else:
            phase = self.opt.train_phase + 1

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, xs[i + 2])

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)
        return x
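

# Shape walk-through (illustrative, derived from the code above): nested_encode
# returns a small-to-large feature pyramid, so xs[0] matches the bottleneck fed
# to head_0, xs[1] matches the once-upsampled maps seen by both middle blocks
# (no upsampling happens between them), and xs[i+2] matches the i-th up-block.
# Each SPADE block is therefore modulated by an encoder feature at its own
# spatial resolution.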


class ContentAdaptiveSuppresor(BaseNetwork):
    def __init__(self, opt, sw, sh, n_2xdown,
                 norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.sw = sw
        self.sh = sh
        self.max_ratio = 16
        self.n_2xdown = n_2xdown
        # norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)

        # 20200310: Several Convolution (stride 1) + LIP blocks, 4-fold
        ngf = opt.ngf
        kw = 3
        pw = (kw - 1) // 2
        self.head = nn.Sequential(
            nn.Conv2d(opt.semantic_nc, ngf, kw, stride=1, padding=pw, bias=False),
            norm_layer(ngf),
            nn.ReLU(),
        )
        cur_ratio = 1
        for i in range(n_2xdown):
            next_ratio = min(cur_ratio * 2, self.max_ratio)
            model = [
                SimplifiedLIP(ngf * cur_ratio),
                nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
                norm_layer(ngf * next_ratio),
            ]
            cur_ratio = next_ratio
            if i < n_2xdown - 1:
                model += [nn.ReLU(inplace=True)]
            setattr(self, 'encoder_%d' % i, nn.Sequential(*model))

    def forward(self, x):
        # 20200628: Note the features are arranged from small to large
        x = [self.head(x)]
        for i in range(self.n_2xdown):
            net = getattr(self, 'encoder_%d' % i)
            x = [net(x[0])] + x
        return x
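

# Example of the returned pyramid (an illustration, not original code): for a
# (1, semantic_nc, 512, 512) input with ngf=64 and n_2xdown=5, forward() returns
# 6 tensors ordered small to large:
#   [(1, 1024, 16, 16), (1, 1024, 32, 32), (1, 512, 64, 64),
#    (1, 256, 128, 128), (1, 128, 256, 256), (1, 64, 512, 512)]
# since each encoder_i halves the resolution (via SimplifiedLIP) and doubles the
# channels up to the 16x cap.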


#########################################
# Below are deprecated codes
#
# 20200309: LIP for local importance pooling
# 20200311: Self-supervised mask encoder
# Author: lingbo.ylb
# Quick trial, to be reformatted later.
# 20200324: Nah, forget about it...
#########################################

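# In LIP terms (a reading of the function below, not original commentary), the
# logits G(x) act as learned importance scores: with w = exp(G(x)), each output
# is the importance-weighted local average
#     lip2d(x) = avg_pool(x * w) / avg_pool(w),
# which reduces to plain average pooling when the logits are constant.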
def lip2d(x, logit, kernel=3, stride=2, padding=1):
    weight = logit.exp()
    return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)


class SoftGate(nn.Module):
    COEFF = 12.0

    def __init__(self):
        super(SoftGate, self).__init__()

    def forward(self, x):
        return torch.sigmoid(x).mul(self.COEFF)


class SimplifiedLIP(nn.Module):
    def __init__(self, channels):
        super(SimplifiedLIP, self).__init__()
        self.logit = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.InstanceNorm2d(channels, affine=True),
            SoftGate()
        )
        '''
        OrderedDict((
            ('conv', nn.Conv2d(channels, channels, 3, padding=1, bias=False)),
            ('bn', nn.InstanceNorm2d(channels, affine=True)),
            ('gate', SoftGate()),
        ))
        '''

    def init_layer(self):
        self.logit[0].weight.data.fill_(0.0)

    def forward(self, x):
        frac = lip2d(x, self.logit(x))
        return frac
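

# Quick shape check (illustrative): SimplifiedLIP keeps the channel count and
# halves the spatial size through lip2d's default stride of 2, e.g.
#     pool = SimplifiedLIP(64); pool.init_layer()
#     y = pool(torch.randn(1, 64, 32, 32))  # -> (1, 64, 16, 16)
# init_layer() zeroes the logit conv, so the module starts out as plain average
# pooling (constant logits) and learns importance from there.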


class LIPEncoder(BaseNetwork):
    def __init__(self, opt, sw, sh, n_2xdown,
                 norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.sw = sw
        self.sh = sh
        self.max_ratio = 16
        # norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)

        # 20200310: Several Convolution (stride 1) + LIP blocks, 4-fold
        ngf = opt.ngf
        kw = 3
        pw = (kw - 1) // 2
        model = [
            nn.Conv2d(opt.semantic_nc, ngf, kw, stride=1, padding=pw, bias=False),
            norm_layer(ngf),
            nn.ReLU(),
        ]
        cur_ratio = 1
        for i in range(n_2xdown):
            next_ratio = min(cur_ratio * 2, self.max_ratio)
            model += [
                SimplifiedLIP(ngf * cur_ratio),
                nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
                norm_layer(ngf * next_ratio),
            ]
            cur_ratio = next_ratio
            if i < n_2xdown - 1:
                model += [nn.ReLU(inplace=True)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
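

# Shape sketch (illustrative): with ngf=64, n_2xdown=5 and a 512x512 input,
# LIPEncoder maps (1, semantic_nc, 512, 512) -> (1, 16*64, 16, 16), matching the
# bottleneck that SPADEGenerator.encode() would otherwise produce by plain
# interpolation + fc.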


class LIPSPADEGenerator(SPADEGenerator):
    '''
    20200309: SPADEGenerator with a learnable feature encoder.
    Encoder design: Local Importance-based Pooling (Ziteng Gao et al., ICCV 2019)
    '''
    def __init__(self, opt):
        super().__init__(opt)
        self.lip_encoder = LIPEncoder(opt, self.sw, self.sh, self.scale_ratio)

    def encode(self, x):
        return self.lip_encoder(x)


class NoiseClassPredictor(nn.Module):
    '''
    Input: nc x sw x sw tensor, either from clean or corrupted images
    Output: n-dim tensor indicating the loss type (or intensity?)
    '''
    def __init__(self, opt, sw, nc, outdim):
        super().__init__()
        nbottleneck = 256
        middim = 256
        # Compact info
        conv = [
            nn.Conv2d(nc, nbottleneck, 1, stride=1),
            nn.InstanceNorm2d(nbottleneck),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # sw should probably be 16; downsample to 4, then flatten
        while sw > 4:
            sw = sw // 2
            conv += [
                nn.Conv2d(nbottleneck, nbottleneck, 3, stride=2, padding=1),
                nn.InstanceNorm2d(nbottleneck),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        self.conv = nn.Sequential(*conv)
        indim = sw * sw * nbottleneck
        self.fc = nn.Sequential(
            nn.Linear(indim, middim),
            nn.BatchNorm1d(middim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(middim, outdim),
            # nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        return self.fc(x)
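

# Dimension walk-through (illustrative): with sw=16 on entry, the while-loop adds
# two stride-2 convs (16 -> 8 -> 4), so the flattened feature is 4*4*256 = 4096,
# followed by 4096 -> 256 -> outdim in the fc head.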


class NoiseIntensityPredictor(nn.Module):
    '''
    Input: nc x sw x sw tensor, either from clean or corrupted images
    Output: 1-dim tensor indicating the loss intensity
    '''
    def __init__(self, opt, sw, nc, outdim):
        super().__init__()
        nbottleneck = 256
        middim = 256
        # Compact info
        conv = [
            nn.Conv2d(nc, nbottleneck, 1, stride=1),
            nn.BatchNorm2d(nbottleneck),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # sw should probably be 16; downsample to 4, then flatten
        while sw > 4:
            sw = sw // 2
            conv += [
                nn.Conv2d(nbottleneck, nbottleneck, 3, stride=2, padding=1),
                nn.BatchNorm2d(nbottleneck),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        self.conv = nn.Sequential(*conv)
        indim = sw * sw * nbottleneck
        self.fc = nn.Sequential(
            nn.Linear(indim, middim),
            nn.BatchNorm1d(middim),
            # nn.Dropout(0.5),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(middim, outdim),
            # nn.Dropout(0.5),
            # nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x.squeeze()


class SubAddGenerator(SPADEGenerator):
    '''
    20200311:
    This generator contains a complete
    self-supervised training scheme
    that requires a separate dataloader.

    The self-supervised pre-training is
    implemented as a clean interface:
        SubAddGenerator::train_E(self, epochs)

    For the upper-level Pix2Pix_Model,
    two things need to be done:
        A) Run the pretrain script
        B) Save the encoder and adjust the learning rate.

    -----------------------------------------
    20200312:
    Pre-test problem: the discriminator struggles to tell real from fake.
    Also, using the residual of feature maps doesn't work...
    Cause: the feature maps are too close to be properly separated.

    Try to test on one single reduction: an arbitrary downsampling ratio;
    try to estimate the reduction ratio?
    '''
    def __init__(self, opt):
        super().__init__(opt)
        self.encoder = LIPEncoder(opt, self.sw, self.sh, self.scale_ratio)

        self.dis_nc = self.opt.ngf * min(16, 2**self.scale_ratio)
        # intensity is a scalar
        self.discriminator = NoiseIntensityPredictor(opt, self.sw, self.dis_nc, 1)
        if opt.isTrain:
            self.attach_dataloader(opt)
            self.noise_dim = opt.noise_dim

            self.l1_loss = nn.L1Loss()
            self.gan_loss = nn.MSELoss()
            # self.discriminator = NoiseClassPredictor(opt, self.sw, self.dis_nc,
            #                                          self.noise_dim + 1)  # add a new label for clean images
            # self.gan_loss = nn.CrossEntropyLoss()

            beta1, beta2 = opt.beta1, opt.beta2
            if opt.no_TTUR:
                G_lr, D_lr = opt.lr, opt.lr
            else:
                G_lr, D_lr = opt.lr / 2, opt.lr * 2

            self.optimizer_E = torch.optim.Adam(
                self.encoder.parameters(), lr=G_lr, betas=(beta1, beta2)
            )
            self.optimizer_D = torch.optim.Adam(
                self.discriminator.parameters(), lr=D_lr / 2, betas=(beta1, beta2)
            )

    def _create_auxiliary_opt(self, opt):
        '''
        Create auxiliary options:
        change the necessary params
        --- dataroot
        --- dataset_mode
        '''
        aux_opt = copy.deepcopy(opt)  # just for safety
        aux_opt.dataroot = opt.dataroot_assist
        aux_opt.dataset_mode = 'assist'
        aux_opt.batchSize = 4
        aux_opt.nThreads = 4
        return aux_opt

    def attach_dataloader(self, opt):
        aux_opt = self._create_auxiliary_opt(opt)
        self.loader = data.create_dataloader(aux_opt)

    def encode(self, x):
        return self.encoder(x)

    def process_input(self, data):
        self.clean = data['clean'].cuda()
        self.noisy = data['noisy'].cuda()
        # for BCELoss, the class label is just an int.
        self.noise_label = data['label'].cuda()

        # label design...
        # should the clean label be 0 or [1,0,0...0]?
        # for BCELoss, the clean label is just 0.
        # self.clean_label = torch.zeros_like(self.noise_label)
        self.clean_label = torch.ones_like(self.noise_label)

    def update_E(self):
        bundle_in = torch.cat((self.clean, self.noisy), dim=0)
        bundle_out = self.encode(bundle_in)
        nb = bundle_in.shape[0] // 2
        F_real, F_fake = bundle_out[:nb], bundle_out[nb:]
        pred_fake = self.discriminator(F_fake)

        loss_l1 = self.l1_loss(F_fake, F_real)
        loss_gan = self.gan_loss(pred_fake, self.clean_label)
        loss_sum = loss_l1 * 10 + loss_gan

        self.optimizer_E.zero_grad()
        loss_sum.backward()
        self.optimizer_E.step()

        self.loss_l1 = loss_l1.item()
        self.loss_gan_E = loss_gan.item()
        self.loss_sum = loss_sum.item()

    def update_D(self):
        with torch.no_grad():
            bundle_in = torch.cat((self.clean, self.noisy), dim=0)
            bundle_out = self.encode(bundle_in)
            nb = bundle_in.shape[0] // 2
            # F_real, F_fake = bundle_out[:nb], bundle_out[nb:]
            F_fake = bundle_out[nb:] / (bundle_out[:nb] + 1e-6)
            F_real = torch.ones_like(F_fake, requires_grad=False)

        pred_real = self.discriminator(F_real)
        loss_real = self.gan_loss(pred_real, self.clean_label)
        pred_fake = self.discriminator(F_fake.detach())
        loss_fake = self.gan_loss(pred_fake, self.noise_label)
        loss_sum = (loss_real + loss_fake * self.opt.noise_dim) / 2

        self.optimizer_D.zero_grad()
        loss_sum.backward()
        self.optimizer_D.step()

        self.loss_gan_D_real = loss_real.item()
        self.loss_gan_D_fake = loss_fake.item()

    def debug_D(self):
        with torch.no_grad():
            bundle_in = torch.cat((self.clean, self.noisy), dim=0)
            bundle_out = self.encode(bundle_in)
            nb = bundle_in.shape[0] // 2
            F_real, F_fake = bundle_out[:nb], bundle_out[nb:]
            F_res = F_fake - F_real  # try to predict the residual, it's easier
            # F_real = torch.zeros_like(F_real)  # real res == 0

            pred_real = self.discriminator(F_real)  # .argmax(dim=1)
            pred_fake = self.discriminator(F_res.detach())  # .argmax(dim=1)
            print(pred_real, pred_fake)
            real_acc = (pred_real == 0).sum().item() / pred_real.shape[0]
            fake_acc = (pred_fake == self.noise_label).sum().item() / pred_fake.shape[0]
            print(real_acc, fake_acc)

    def log(self, epoch, i):
        logstring = ' Epoch [%d] iter [%d]: ' % (epoch, i)
        logstring += 'l1: %.4f ' % self.loss_l1
        logstring += 'gen: %.4f ' % self.loss_gan_E
        logstring += 'E_sum: %.4f ' % self.loss_sum
        logstring += 'dis_real: %.4f ' % self.loss_gan_D_real
        logstring += 'dis_fake: %.4f' % self.loss_gan_D_fake
        print(logstring)

    def train_E(self, epochs):
        pretrained_ckpt_dir = os.path.join(
            self.opt.checkpoints_dir, self.opt.name,
            'pretrained_net_E_%d.pth' % epochs
        )
        print(pretrained_ckpt_dir)

        print('======= Stage I: Subtraction =======')
        if os.path.isfile(pretrained_ckpt_dir):
            state_dict_E = torch.load(pretrained_ckpt_dir)
            self.encoder.load_state_dict(state_dict_E)
            print('======= Load cached checkpoint %s' % pretrained_ckpt_dir)
        else:
            print('======= total epochs: %d ' % epochs)
            for epoch in range(1, epochs + 1):
                for i, data in enumerate(self.loader):
                    self.process_input(data)
                    self.update_E()
                    self.update_D()
                    if i % 10 == 0:
                        self.log(epoch, i)  # output losses and other stats

                print('Epoch [%d] finished' % epoch)
                # just save the latest.
                torch.save(self.encoder.state_dict(), os.path.join(
                    self.opt.checkpoints_dir, self.opt.name, 'pretrained_net_E_%d.pth' % epoch
                ))
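

# Pretraining usage sketch (illustrative; option names like dataroot_assist and
# noise_dim come from the code above, everything else is assumed):
#
#     netG = SubAddGenerator(opt)   # opt.isTrain=True attaches the 'assist'
#     netG.train_E(epochs=10)       # dataloader; train_E pretrains encoder +
#                                   # noise predictor, caching weights under
#                                   # opt.checkpoints_dir/opt.name/
#
# Afterwards the caller (e.g. the upper-level Pix2Pix model) is expected to
# reload the cached encoder and lower its learning rate, per the class docstring.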


class ContrasiveGenerator(SPADEGenerator):
    def __init__(self, opt):
        super().__init__(opt)
        self.encoder = LIPEncoder(opt, self.sw, self.sh, self.scale_ratio)

        if opt.isTrain:
            self.attach_dataloader(opt)
            self.noise_dim = opt.noise_dim

            self.l1_loss = nn.L1Loss()

            beta1, beta2 = opt.beta1, opt.beta2
            self.optimizer_E = torch.optim.Adam(
                self.encoder.parameters(), lr=opt.lr, betas=(beta1, beta2)
            )

    def _create_auxiliary_opt(self, opt):
        '''
        Create auxiliary options:
        change the necessary params
        --- dataroot
        --- dataset_mode
        '''
        aux_opt = copy.deepcopy(opt)  # just for safety
        aux_opt.dataroot = opt.dataroot_assist
        aux_opt.dataset_mode = 'assist'
        aux_opt.batchSize = 8
        aux_opt.nThreads = 4
        return aux_opt

    def attach_dataloader(self, opt):
        aux_opt = self._create_auxiliary_opt(opt)
        self.loader = data.create_dataloader(aux_opt)

    def encode(self, x):
        return self.encoder(x)

    def process_input(self, data):
        self.clean = data['clean'].cuda()
        self.noisy = data['noisy'].cuda()

    def update_E(self):
        bundle_in = torch.cat((self.clean, self.noisy), dim=0)
        bundle_out = self.encode(bundle_in)
        nb = bundle_in.shape[0] // 2
        F_real, F_fake = bundle_out[:nb], bundle_out[nb:]
        loss_l1 = self.l1_loss(F_fake, F_real)

        self.optimizer_E.zero_grad()
        loss_l1.backward()
        self.optimizer_E.step()

        self.loss_l1 = loss_l1.item()

    def log(self, epoch, i):
        logstring = ' Epoch [%d] iter [%d]: ' % (epoch, i)
        logstring += 'l1: %.4f ' % self.loss_l1
        print(logstring)

    def train_E(self, epochs):
        pretrained_ckpt_dir = os.path.join(
            self.opt.checkpoints_dir, self.opt.name,
            'pretrained_net_E_%d.pth' % epochs
        )
        print(pretrained_ckpt_dir)

        print('======= Stage I: Subtraction =======')
        if os.path.isfile(pretrained_ckpt_dir):
            state_dict_E = torch.load(pretrained_ckpt_dir)
            self.encoder.load_state_dict(state_dict_E)
            print('======= Load cached checkpoint %s' % pretrained_ckpt_dir)
        else:
            print('======= total epochs: %d ' % epochs)
            for epoch in range(1, epochs + 1):
                for i, data in enumerate(self.loader):
                    self.process_input(data)
                    self.update_E()
                    if i % 10 == 0:
                        self.log(epoch, i)  # output losses and other stats

                print('Epoch [%d] finished' % epoch)
                # just save the latest.
                torch.save(self.encoder.state_dict(), os.path.join(
                    self.opt.checkpoints_dir, self.opt.name, 'pretrained_net_E_%d.pth' % epoch
                ))