"""DCGAN training example using NVIDIA Apex automatic mixed precision (amp)."""
from __future__ import print_function
import argparse
import os
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

try:
    from apex import amp
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10', help='cifar10 | lsun | mnist | imagenet | folder | lfw | fake')
parser.add_argument('--dataroot', default='./', help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')
parser.add_argument('--opt_level', default='O1', help='amp opt_level, default="O1"')

opt = parser.parse_args()
print(opt)

try:
    os.makedirs(opt.outf)
except OSError:
    pass

# Fall back to a fixed seed so runs are reproducible by default.
if opt.manualSeed is None:
    opt.manualSeed = 2809
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

# Let cuDNN pick the fastest convolution algorithms for the fixed input sizes used here.
cudnn.benchmark = True
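
# Each dataset option resizes images to opt.imageSize and (except for 'fake')
# normalizes them to [-1, 1], matching the Tanh output range of the generator.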
if opt.dataset in ['imagenet', 'folder', 'lfw']:
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    nc = 3
elif opt.dataset == 'lsun':
    classes = [c + '_train' for c in opt.classes.split(',')]
    dataset = dset.LSUN(root=opt.dataroot, classes=classes,
                        transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
    nc = 3
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    nc = 3
elif opt.dataset == 'mnist':
    dataset = dset.MNIST(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,)),
                         ]))
    nc = 1
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
    nc = 3

assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

device = torch.device("cuda:0")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)

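# Custom weight initialization, as in the DCGAN paper: conv weights are drawn from
# N(0, 0.02); batch-norm weights from N(1, 0.02) with biases set to zero.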
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

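# DCGAN generator: upsamples a latent vector of shape (nz, 1, 1) to an (nc, 64, 64)
# image with a stack of strided transposed convolutions followed by Tanh.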
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output


netG = Generator(ngpu).to(device)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)

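# DCGAN discriminator: reduces an (nc, 64, 64) image to a single logit with strided
# convolutions. The final Sigmoid is omitted because BCEWithLogitsLoss is used below,
# which is the numerically safer choice under mixed precision.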
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output.view(-1, 1).squeeze(1)


netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

criterion = nn.BCEWithLogitsLoss()

fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
real_label = 1
fake_label = 0

optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

# amp.initialize accepts lists of models and optimizers and returns them patched for
# mixed precision. num_losses=3 gives each of the three backward passes below
# (D on real, D on fake, G) its own loss scaler.
[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)
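
# Training loop: for each batch, the discriminator is updated on real and fake
# samples, then the generator is updated; each loss uses its own amp loss scaler
# via the loss_id argument of amp.scale_loss.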
for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
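        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ############################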
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        # Use a float label tensor; BCEWithLogitsLoss expects targets with the same
        # dtype as the logits.
        label = torch.full((batch_size,), real_label, dtype=real_cpu.dtype, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:
            errD_real_scaled.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:
            errD_fake_scaled.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

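        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ############################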
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for the generator cost
        output = netD(fake)
        errG = criterion(output, label)
        with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:
            errG_scaled.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                              '%s/real_samples.png' % opt.outf,
                              normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                              '%s/amp_fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                              normalize=True)

    # Save checkpoints at the end of each epoch.
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))