Spaces · Runtime error
Mehdi Cherti committed · Commit 1a02524 · Parent(s): 572f947

cond grad penalty: use only cond embedding to compute grad

Files changed:
- score_sde/models/ncsnpp_generator_adagn.py (+3 -0)
- train_ddgan.py (+75 -26)
score_sde/models/ncsnpp_generator_adagn.py
CHANGED

@@ -325,9 +325,12 @@ class NCSNpp(nn.Module):
 
         hs = [modules[m_idx](x)]
         m_idx += 1
+        #print(self.attn_resolutions)
+        #self.attn_resolutions = (32,)
         for i_level in range(self.num_resolutions):
             # Residual blocks for this resolution
             for i_block in range(self.num_res_blocks):
+                #print(hs[-1].shape, temb.shape, zemb.shape, type(modules[m_idx]))
                 h = modules[m_idx](hs[-1], temb, zemb)
                 m_idx += 1
                 if h.shape[-1] in self.attn_resolutions:
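The hunk above touches NCSNpp's forward pass, which walks a flat module list with a running index m_idx and applies attention only where the feature map's width matches an entry in attn_resolutions. A minimal sketch of that pattern, using a toy trunk rather than the repo's model:

import torch
import torch.nn as nn

class TinyUNetTrunk(nn.Module):
    """Toy illustration of the flat-module-list + running-index pattern."""
    def __init__(self, channels=64, num_resolutions=3, attn_resolutions=(16,)):
        super().__init__()
        self.num_resolutions = num_resolutions
        self.attn_resolutions = attn_resolutions
        mods = [nn.Conv2d(3, channels, 3, padding=1)]
        for _ in range(num_resolutions):
            mods.append(nn.Conv2d(channels, channels, 3, padding=1))            # "res block"
            mods.append(nn.Conv2d(channels, channels, 3, stride=2, padding=1))  # downsample
        self.modules_list = nn.ModuleList(mods)

    def forward(self, x):
        m_idx = 0
        h = self.modules_list[m_idx](x); m_idx += 1
        for _ in range(self.num_resolutions):
            h = self.modules_list[m_idx](h); m_idx += 1
            if h.shape[-1] in self.attn_resolutions:
                pass  # an attention block would be applied here
            h = self.modules_list[m_idx](h); m_idx += 1
        return h

out = TinyUNetTrunk()(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 64, 4, 4])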
train_ddgan.py
CHANGED
@@ -28,7 +28,10 @@ from torch.multiprocessing import Process
 import torch.distributed as dist
 import shutil
 import logging
-import
+from encoder import build_encoder
+from utils import ResampledShards2
+
+
 def log_and_continue(exn):
     logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
     return True
@@ -192,7 +195,11 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None
     return x
 
 
-
+
+def filter_no_caption(sample):
+    return 'txt' in sample
+
+
 
 def train(rank, gpu, args):
     from score_sde.models.discriminator import Discriminator_small, Discriminator_large, CondAttnDiscriminator, SmallCondAttnDiscriminator
@@ -278,6 +285,7 @@ def train(rank, gpu, args):
             ),
         ])
         pipeline.extend([
+            wds.select(filter_no_caption),
             wds.decode("pilrgb", handler=log_and_continue),
             wds.rename(image="jpg;png"),
             wds.map_dict(image=train_transform),
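For context, wds.select(filter_no_caption) drops uncaptioned samples before any decoding work is done. A hedged sketch of how the stage slots into a WebDataset pipeline; the shard pattern below is a placeholder, the training script builds its pipeline from args as the hunk shows:

import logging
import webdataset as wds

def filter_no_caption(sample):
    # keep only samples that include a 'txt' (caption) entry
    return 'txt' in sample

def log_and_continue(exn):
    logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True

dataset = wds.DataPipeline(
    wds.SimpleShardList("shards/train-{000000..000009}.tar"),  # placeholder shards
    wds.tarfile_to_samples(handler=log_and_continue),
    wds.select(filter_no_caption),                   # drop uncaptioned samples early
    wds.decode("pilrgb", handler=log_and_continue),
    wds.rename(image="jpg;png"),
)

for sample in dataset:
    image, caption = sample["image"], sample["txt"]
    break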
@@ -307,7 +315,7 @@
         pin_memory=True,
         sampler=train_sampler,
     )
-    text_encoder =
+    text_encoder = build_encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
     args.cond_size = text_encoder.output_size
     netG = NCSNpp(args).to(device)
     nb_params = 0
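build_encoder lives in the repo's own encoder module, so its body is not part of this diff. As a rough illustration only, a hypothetical T5-based encoder could produce the (cond_pooled, cond, cond_mask) triple and the output_size attribute that the training code relies on, with masked_mean averaging only over non-padding tokens:

# Hypothetical stand-in for encoder.build_encoder; not the repo's implementation.
import torch
from transformers import AutoTokenizer, T5EncoderModel

class TextEncoder(torch.nn.Module):
    def __init__(self, name="google/t5-v1_1-base", masked_mean=True):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(name)
        self.model = T5EncoderModel.from_pretrained(name)
        self.masked_mean = masked_mean
        self.output_size = self.model.config.d_model  # becomes args.cond_size

    @torch.no_grad()  # the training loop sets requires_grad on the outputs when needed
    def forward(self, texts):
        batch = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        mask = batch.attention_mask                    # (B, L), 1 for real tokens
        cond = self.model(**batch).last_hidden_state   # (B, L, d_model)
        if self.masked_mean:
            # average only over non-padding positions
            denom = mask.sum(dim=1, keepdim=True).clamp(min=1)
            cond_pooled = (cond * mask.unsqueeze(-1)).sum(dim=1) / denom
        else:
            cond_pooled = cond.mean(dim=1)
        return cond_pooled, cond, mask.bool()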
@@ -387,7 +395,7 @@
               .format(checkpoint['epoch']))
     else:
         global_step, epoch, init_epoch = 0, 0, 0
-
+    use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn")
     for epoch in range(init_epoch, args.num_epoch+1):
         if args.dataset == "wds":
             os.environ["WDS_EPOCH"] = str(epoch)
@@ -419,45 +427,71 @@
             x_t, x_tp1 = q_sample_pairs(coeff, real_data, t)
             x_t.requires_grad = True
 
-            cond_for_discr = (cond_pooled, cond, cond_mask) if
+            cond_for_discr = (cond_pooled, cond, cond_mask) if use_cond_attn_discr else cond_pooled
+            if args.grad_penalty_cond:
+                if use_cond_attn_discr:
+                    #cond_pooled.requires_grad = True
+                    cond.requires_grad = True
+                    #cond_mask.requires_grad = True
+                else:
+                    cond_for_discr.requires_grad = True
 
             # train with real
             D_real = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
 
             errD_real = F.softplus(-D_real)
             errD_real = errD_real.mean()
+
 
             errD_real.backward(retain_graph=True)
 
 
             if args.lazy_reg is None:
-                grad_real = torch.autograd.grad(
-                    outputs=D_real.sum(), inputs=x_t, create_graph=True
-                )[0]
-                grad_penalty = (
-                    grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
-                ).mean()
-                grad_penalty = args.r1_gamma / 2 * grad_penalty
-                grad_penalty.backward()
-            else:
-                if global_step % args.lazy_reg == 0:
-                    grad_real = torch.autograd.grad(
-                        outputs=D_real.sum(), inputs=x_t, create_graph=True
-                    )[0]
-                    grad_penalty = (
-                        grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
-                    ).mean()
-                    grad_penalty = args.r1_gamma / 2 * grad_penalty
-                    grad_penalty.backward()
+                if args.grad_penalty_cond:
+                    inputs = (x_t,) + (cond,) if use_cond_attn_discr else (cond_for_discr,)
+                    grad_real = torch.autograd.grad(
+                        outputs=D_real.sum(), inputs=inputs, create_graph=True
+                    )[0]
+                    grad_real = torch.cat([g.view(g.size(0), -1) for g in grad_real])
+                    grad_penalty = (grad_real.norm(2, dim=1) ** 2).mean()
+                    grad_penalty = args.r1_gamma / 2 * grad_penalty
+                    grad_penalty.backward()
+                else:
+                    grad_real = torch.autograd.grad(
+                        outputs=D_real.sum(), inputs=x_t, create_graph=True
+                    )[0]
+                    grad_penalty = (
+                        grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
+                    ).mean()
+
+
+                    grad_penalty = args.r1_gamma / 2 * grad_penalty
+                    grad_penalty.backward()
+            else:
+                if global_step % args.lazy_reg == 0:
+                    if args.grad_penalty_cond:
+                        inputs = (x_t,) + (cond,) if use_cond_attn_discr else (cond_for_discr,)
+                        grad_real = torch.autograd.grad(
+                            outputs=D_real.sum(), inputs=inputs, create_graph=True
+                        )[0]
+                        grad_real = torch.cat([g.view(g.size(0), -1) for g in grad_real])
+                        grad_penalty = (grad_real.norm(2, dim=1) ** 2).mean()
+                        grad_penalty = args.r1_gamma / 2 * grad_penalty
+                        grad_penalty.backward()
+                    else:
+                        grad_real = torch.autograd.grad(
+                            outputs=D_real.sum(), inputs=x_t, create_graph=True
+                        )[0]
+                        grad_penalty = (
+                            grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
+                        ).mean()
+
+                        grad_penalty = args.r1_gamma / 2 * grad_penalty
+                        grad_penalty.backward()
 
             # train with fake
             latent_z = torch.randn(batch_size, nz, device=device)
 
-
             x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
             x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
 
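Per the commit message, the new --grad_penalty_cond path extends the R1 penalty so the discriminator's gradient is taken with respect to the conditioning embedding as well, not only the noisy image. A self-contained sketch of that computation with a toy discriminator (names here are illustrative, not the repo's netD):

import torch

def toy_discriminator(x, cond):
    # score each sample from its image and its text embedding
    return x.flatten(1).sum(dim=1) + cond.flatten(1).sum(dim=1)

B = 4
x_t = torch.randn(B, 3, 32, 32, requires_grad=True)
cond = torch.randn(B, 77, 768, requires_grad=True)  # token embeddings

D_real = toy_discriminator(x_t, cond)

# one gradient tensor per input; flatten each per sample, then concatenate
grads = torch.autograd.grad(outputs=D_real.sum(), inputs=(x_t, cond), create_graph=True)
flat = torch.cat([g.reshape(g.size(0), -1) for g in grads], dim=1)  # (B, D_x + D_c)
grad_penalty = (flat.norm(2, dim=1) ** 2).mean()

r1_gamma = 0.05
(r1_gamma / 2 * grad_penalty).backward()
print(grad_penalty.item())

Note that torch.autograd.grad returns one gradient per input, so the per-sample squared norm is built by flattening each gradient and concatenating along the feature dimension.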
@@ -466,6 +500,18 @@
 
             errD_fake = F.softplus(output)
             errD_fake = errD_fake.mean()
+
+            if args.mismatch_loss:
+                # following https://github.com/tobran/DF-GAN/blob/bc38a4f795c294b09b4ef5579cd4ff78807e5b96/code/lib/modules.py,
+                # we add a discr loss for (real image, non matching text)
+                #inds = torch.flip(torch.arange(len(x_t)), dims=(0,))
+                inds = torch.cat([torch.arange(1,len(x_t)),torch.arange(1)])
+                cond_for_discr_mis = (cond_pooled[inds], cond[inds], cond_mask[inds]) if use_cond_attn_discr else cond_pooled[inds]
+                D_real_mis = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr_mis).view(-1)
+                errD_real_mis = F.softplus(D_real_mis)
+                errD_real_mis = errD_real_mis.mean()
+                errD_fake = errD_fake * 0.5 + errD_real_mis * 0.5
+
             errD_fake.backward()
 
 
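The mismatch term pairs each real image with its neighbor's caption via a cyclic shift of the batch indices, so every image is scored against a non-matching text. The index trick in isolation:

import torch

n = 4
inds = torch.cat([torch.arange(1, n), torch.arange(1)])
print(inds)  # tensor([1, 2, 3, 0]): image i gets caption (i+1) mod n, never its own

captions = ["a dog", "a cat", "a car", "a tree"]
print([captions[i] for i in inds])  # ['a cat', 'a car', 'a tree', 'a dog']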
@@ -592,6 +638,7 @@ if __name__ == '__main__':
 
     parser.add_argument('--resume', action='store_true',default=False)
     parser.add_argument('--masked_mean', action='store_true',default=False)
+    parser.add_argument('--mismatch_loss', action='store_true',default=False)
    parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
     parser.add_argument('--cross_attention', action='store_true',default=False)
 
@@ -616,7 +663,7 @@
                         help='channel multiplier')
     parser.add_argument('--num_res_blocks', type=int, default=2,
                         help='number of resnet blocks per scale')
-    parser.add_argument('--attn_resolutions', default=(16,),
+    parser.add_argument('--attn_resolutions', default=(16,), nargs='+', type=int,
                         help='resolution of applying attention')
     parser.add_argument('--dropout', type=float, default=0.,
                         help='drop-out rate')
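The attn_resolutions fix matters because the model tests h.shape[-1] in self.attn_resolutions: without nargs='+' and type=int, a command-line override would arrive as a string and never match an integer width. A minimal repro of the parsing (the invocation is hypothetical):

# e.g. python train_ddgan.py --attn_resolutions 16 32
import argparse

p = argparse.ArgumentParser()
p.add_argument('--attn_resolutions', default=(16,), nargs='+', type=int)
print(p.parse_args([]).attn_resolutions)                                   # (16,)
print(p.parse_args(['--attn_resolutions', '16', '32']).attn_resolutions)   # [16, 32]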
@@ -665,12 +712,14 @@
     parser.add_argument('--beta2', type=float, default=0.9,
                         help='beta2 for adam')
     parser.add_argument('--no_lr_decay',action='store_true', default=False)
-
+    parser.add_argument('--grad_penalty_cond', action='store_true',default=False)
+
     parser.add_argument('--use_ema', action='store_true', default=False,
                         help='use EMA or not')
     parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')
 
     parser.add_argument('--r1_gamma', type=float, default=0.05, help='coef for r1 reg')
+
     parser.add_argument('--lazy_reg', type=int, default=None,
                         help='lazy regulariation.')
 