Spaces:
Build error
Build error
Commit
·
7f6643a
1
Parent(s):
01ad5b5
Upload latent_optimization.py
Browse files- latent_optimization.py +107 -0
latent_optimization.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import models.stylegan2.lpips as lpips
|
| 2 |
+
from torch import autograd, optim
|
| 3 |
+
from torchvision import transforms, utils
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import torch
|
| 6 |
+
from scripts.align_all_parallel import align_face
|
| 7 |
+
from utils.inference_utils import noise_regularize, noise_normalize_, get_lr, latent_noise, visualize
|
| 8 |
+
|
| 9 |
+
def latent_optimization(frame, pspex, landmarkpredictor, step=500, device='cuda'):
|
| 10 |
+
percept = lpips.PerceptualLoss(
|
| 11 |
+
model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
transform = transforms.Compose([
|
| 15 |
+
transforms.ToTensor(),
|
| 16 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5,0.5,0.5]),
|
| 17 |
+
])
|
| 18 |
+
|
| 19 |
+
with torch.no_grad():
|
| 20 |
+
|
| 21 |
+
noise_sample = torch.randn(1000, 512, device=device)
|
| 22 |
+
latent_out = pspex.decoder.style(noise_sample)
|
| 23 |
+
latent_mean = latent_out.mean(0)
|
| 24 |
+
latent_std = ((latent_out - latent_mean).pow(2).sum() / 1000) ** 0.5
|
| 25 |
+
|
| 26 |
+
y = transform(frame).unsqueeze(dim=0).to(device)
|
| 27 |
+
I_ = align_face(frame, landmarkpredictor)
|
| 28 |
+
I_ = transform(I_).unsqueeze(dim=0).to(device)
|
| 29 |
+
wplus = pspex.encoder(I_) + pspex.latent_avg.unsqueeze(0)
|
| 30 |
+
_, f = pspex.encoder(y, return_feat=True)
|
| 31 |
+
latent_in = wplus.detach().clone()
|
| 32 |
+
feat = [f[0].detach().clone(), f[1].detach().clone()]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# wplus and f to optimize
|
| 37 |
+
latent_in.requires_grad = True
|
| 38 |
+
feat[0].requires_grad = True
|
| 39 |
+
feat[1].requires_grad = True
|
| 40 |
+
|
| 41 |
+
noises_single = pspex.decoder.make_noise()
|
| 42 |
+
basic_height, basic_width = int(y.shape[2]*32/256), int(y.shape[3]*32/256)
|
| 43 |
+
noises = []
|
| 44 |
+
for noise in noises_single:
|
| 45 |
+
noises.append(noise.new_empty(y.shape[0], 1, max(basic_height, int(y.shape[2]*noise.shape[2]/256)),
|
| 46 |
+
max(basic_width, int(y.shape[3]*noise.shape[2]/256))).normal_())
|
| 47 |
+
for noise in noises:
|
| 48 |
+
noise.requires_grad = True
|
| 49 |
+
|
| 50 |
+
init_lr=0.05
|
| 51 |
+
optimizer = optim.Adam(feat + noises, lr=init_lr)
|
| 52 |
+
optimizer2 = optim.Adam([latent_in], lr=init_lr)
|
| 53 |
+
noise_weight = 0.05 * 0.2
|
| 54 |
+
|
| 55 |
+
pbar = tqdm(range(step))
|
| 56 |
+
latent_path = []
|
| 57 |
+
|
| 58 |
+
for i in pbar:
|
| 59 |
+
t = i / step
|
| 60 |
+
lr = get_lr(t, init_lr)
|
| 61 |
+
optimizer.param_groups[0]["lr"] = lr
|
| 62 |
+
optimizer2.param_groups[0]["lr"] = get_lr(t, init_lr)
|
| 63 |
+
|
| 64 |
+
noise_strength = latent_std * noise_weight * max(0, 1 - t / 0.75) ** 2
|
| 65 |
+
latent_n = latent_noise(latent_in, noise_strength.item())
|
| 66 |
+
|
| 67 |
+
y_hat, _ = pspex.decoder([latent_n], input_is_latent=True, randomize_noise=False,
|
| 68 |
+
first_layer_feature=feat, noise=noises)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
batch, channel, height, width = y_hat.shape
|
| 72 |
+
|
| 73 |
+
if height > y.shape[2]:
|
| 74 |
+
factor = height // y.shape[2]
|
| 75 |
+
|
| 76 |
+
y_hat = y_hat.reshape(
|
| 77 |
+
batch, channel, height // factor, factor, width // factor, factor
|
| 78 |
+
)
|
| 79 |
+
y_hat = y_hat.mean([3, 5])
|
| 80 |
+
|
| 81 |
+
p_loss = percept(y_hat, y).sum()
|
| 82 |
+
n_loss = noise_regularize(noises) * 1e3
|
| 83 |
+
|
| 84 |
+
loss = p_loss + n_loss
|
| 85 |
+
|
| 86 |
+
optimizer.zero_grad()
|
| 87 |
+
optimizer2.zero_grad()
|
| 88 |
+
loss.backward()
|
| 89 |
+
optimizer.step()
|
| 90 |
+
optimizer2.step()
|
| 91 |
+
|
| 92 |
+
noise_normalize_(noises)
|
| 93 |
+
|
| 94 |
+
''' for visualization
|
| 95 |
+
if (i + 1) % 100 == 0 or i == 0:
|
| 96 |
+
viz = torch.cat((y_hat,y,y_hat-y), dim=3)
|
| 97 |
+
visualize(torch.clamp(viz[0].cpu(),-1,1), 60)
|
| 98 |
+
'''
|
| 99 |
+
|
| 100 |
+
pbar.set_description(
|
| 101 |
+
(
|
| 102 |
+
f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
|
| 103 |
+
f" lr: {lr:.4f}"
|
| 104 |
+
)
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
return latent_n, feat, noises, wplus, f
|