| """ | |
| Approach: "StyleMC: Multi-Channel Based Fast Text-Guided Image Generation and Manipulation" | |
| Original source code: | |
| https://github.com/autonomousvision/stylegan_xl/blob/f9be58e98110bd946fcdadef2aac8345466faaf3/run_stylemc.py# | |
| Modified by Håkon Hukkelås | |
| """ | |
import os
from pathlib import Path
import tqdm
import re
import click
from dp2 import utils
import tops
from typing import List, Optional
import PIL.Image
import imageio
from timeit import default_timer as timer
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import resize, normalize
from dp2.infer import build_trained_generator
import clip


#----------------------------------------------------------------------------
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

def save_image(img, path):
    """Convert an NCHW tensor in [-1, 1] to uint8 and save the first image in the batch."""
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(path)


def unravel_index(index, shape):
    """Convert a flat index into a tuple of per-dimension coordinates for the given shape."""
    out = []
    for dim in reversed(shape):
        out.append(index % dim)
        index = index // dim
    return tuple(reversed(out))

def num_range(s: str) -> List[int]:
    '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
    range_re = re.compile(r'^(\d+)-(\d+)$')
    m = range_re.match(s)
    if m:
        return list(range(int(m.group(1)), int(m.group(2)) + 1))
    vals = s.split(',')
    return [int(x) for x in vals]
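
# For example, num_range("1,3,8") returns [1, 3, 8] and num_range("2-5")
# returns [2, 3, 4, 5].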

#----------------------------------------------------------------------------
def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
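
# For unit vectors u, v separated by angle theta, ||u - v|| = 2*sin(theta / 2),
# so the expression above evaluates to 2 * (theta / 2)**2 = theta**2 / 2: half
# the squared great-circle distance between the normalized embeddings.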

def prompts_dist_loss(x, targets, loss):
    if len(targets) == 1:  # Keeps consistent results vs previous method for single objective guidance
        return loss(x, targets[0])
    distances = [loss(x, target) for target in targets]
    return torch.stack(distances, dim=-1).sum(dim=-1)

def embed_text(model, prompt, device='cuda'):
    """Embed a text prompt with CLIP.

    NOTE: the original body was empty; this reconstruction mirrors the inline
    prompt encoding used in find_direction below.
    """
    return model.encode_text(clip.tokenize(prompt).to(device)).float()


#----------------------------------------------------------------------------
def generate_edit(
    G,
    dl,
    direction,
    edit_strength,
    path,
):
    for it, batch in enumerate(dl):
        batch["embedding"] = None
        styles = get_styles(None, G, batch, truncation_value=0)
        imgs = []
        grad_changes = [x * edit_strength for x in [0, 0.25, 0.5, 0.75, 1]]
        grad_changes = [*[-x for x in grad_changes][::-1], *grad_changes]
        batch = {k: tops.to_cuda(v) if v is not None else v for k, v in batch.items()}
        for i, grad_change in enumerate(grad_changes):
            s = styles + direction * grad_change
            img = G(**batch, s=iter(s))["img"]
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
            imgs.append(img[0].to(torch.uint8).cpu().numpy())
        PIL.Image.fromarray(np.concatenate(imgs, axis=1), 'RGB').save(path + f'{it}.png')
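
# Each saved file is a horizontal strip of ten renders sweeping the edit from
# -edit_strength to +edit_strength; the two center panels both use strength 0,
# which doubles as a sanity check that they match the unedited sample.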

def get_styles(seed, G: torch.nn.Module, batch, truncation_value=1):
    all_styles = []
    if seed is None:
        # scale=0 yields an all-zero latent; with truncation_value=0 (as used in
        # generate_edit) the style collapses to the sampled w center anyway.
        z = np.random.normal(0, 0, size=(1, G.z_channels))
    else:
        z = np.random.RandomState(seed=seed).normal(0, 1, size=(1, G.z_channels))
    z_idx = np.random.RandomState(seed=seed).randint(0, len(G.style_net.w_centers))
    w_c = G.style_net.w_centers[z_idx].to(tops.get_device()).view(1, -1)
    w = G.style_net(torch.from_numpy(z).to(tops.get_device()))
    w = w_c.to(w.dtype).lerp(w, truncation_value)
    if hasattr(G, "get_comod_y"):
        w = G.get_comod_y(batch, w)
    for block in G.modules():
        if not hasattr(block, "affine") or not hasattr(block.affine, "weight"):
            continue
        gamma0 = block.affine(w)
        if hasattr(block, "affine_beta"):
            beta0 = block.affine_beta(w)
            gamma0 = torch.cat((gamma0, beta0), dim=1)
        all_styles.append(gamma0)
    max_ch = max(s.shape[-1] for s in all_styles)
    all_styles = [F.pad(s, (0, max_ch - s.shape[-1]), "constant", 0) for s in all_styles]
    all_styles = torch.cat(all_styles)
    return all_styles
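
# The affine layers emit per-block style vectors of different widths; padding each
# to the widest channel count lets them be stacked into a single (num_blocks,
# max_ch) tensor, so one `direction` tensor of the same shape can be optimized
# and consumed block-by-block via `s=iter(styles)` in the generator call.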

def get_and_cache_direction(output_dir: Path, dl_val, G, text_prompt):
    cache_path = output_dir.joinpath(
        "stylemc_cache", text_prompt.replace(" ", "_") + ".torch")
    if cache_path.is_file():
        print("Loaded cache from:", cache_path)
        return torch.load(cache_path)
    direction = find_direction(G, text_prompt, None, dl_val=iter(dl_val))
    cache_path.parent.mkdir(exist_ok=True, parents=True)
    torch.save(direction, cache_path)
    return direction

def find_direction(
    G,
    text_prompt,
    batches,
    #layers,
    n_iterations=128*8,
    batch_size=8,
    dl_val=None
):
    time_start = timer()
    clip_model = clip.load("ViT-B/16", device=tops.get_device())[0]
    target = [clip_model.encode_text(clip.tokenize(text_prompt).to(tops.get_device())).float()]
    if dl_val is not None:
        first_batch = next(dl_val)
    else:
        first_batch = batches[0]
    first_batch["embedding"] = None if "embedding" not in first_batch else first_batch["embedding"]
    s = get_styles(0, G, first_batch)
    # stats trackers
    cos_sim_track = AverageMeter('cos_sim', ':.4f')
    norm_track = AverageMeter('norm', ':.4f')
    n_iterations = n_iterations // batch_size
    progress = ProgressMeter(n_iterations, [cos_sim_track, norm_track])
    # initialize the style-space direction
    direction = torch.zeros(s.shape, device=tops.get_device())
    direction.requires_grad_()
    utils.set_requires_grad(G, False)
    direction_tracker = torch.zeros_like(direction)
    opt = torch.optim.AdamW([direction], lr=0.05, betas=(0., 0.999), weight_decay=0.25)
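    # Interpretation (not from the original comments): the relatively strong
    # weight decay (0.25) keeps shrinking `direction`, so channels that do not
    # consistently receive gradient from the CLIP loss decay toward zero,
    # biasing the search toward a sparse, stable edit direction.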
    grads = []
    for seed_idx in tqdm.trange(n_iterations):
        # forward pass through synthesis network with new styles
        if seed_idx == 0:
            batch = first_batch
        elif dl_val is not None:
            batch = next(dl_val)
            batch["embedding"] = None if "embedding" not in batch else batch["embedding"]
        else:
            batch = {k: tops.to_cuda(v) if v is not None else v for k, v in batches[seed_idx].items()}
        styles = get_styles(seed_idx, G, batch) + direction
        img = G(**batch, s=iter(styles))["img"]
        batch = {k: v.cpu() if v is not None else v for k, v in batch.items()}
        # CLIP loss: spherical distance between image and text embeddings
        img = (img + 1) / 2
        img = normalize(img, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        img = resize(img, (224, 224))
        embeds = clip_model.encode_image(img)
        cos_sim = prompts_dist_loss(embeds, target, spherical_dist_loss)
        cos_sim.backward(retain_graph=True)
        # track stats
        cos_sim_track.update(cos_sim.item())
        norm_track.update(torch.norm(direction).item())
        if not (seed_idx % batch_size):
            # zeroing out gradients for non-optimized layers
            #layers_zeroed = torch.tensor([x for x in range(G.num_ws) if not x in layers])
            #direction.grad[:, layers_zeroed] = 0
            opt.step()
            grads.append(direction.grad.clone())
            direction.grad.data.zero_()
            # keep track of gradient sign flips over time
            if seed_idx > 3:
                direction_tracker[grads[-2] * grads[-1] < 0] += 1
            # plot stats
            progress.display(seed_idx)
    # throw out fluctuating channels
    direction = direction.detach()
    direction[direction_tracker > n_iterations / 4] = 0
    print(direction)
    print(f"Time for direction search: {timer() - time_start:.2f} s")
    return direction

# NOTE: the click decorators below are reconstructed so the command is runnable;
# only the commented-out --layers option survived in the original. Flag names
# and defaults are assumptions derived from the function signature.
@click.command()
@click.argument('config_path')
#@click.option('--layers', type=num_range, help='Restrict the style space to a range of layers. We recommend not to optimize the critically sampled layers (last 3).', required=True)
@click.option('--text-prompt', type=str, required=True)
@click.option('--edit-strength', type=float, default=1.0)
@click.option('--outdir', type=str, default='stylemc_output')
def stylemc(
    config_path,
    #layers: List[int],
    text_prompt: str,
    edit_strength: float,
    outdir: str,
):
    cfg = utils.load_config(config_path)
    G = build_trained_generator(cfg)
    cfg.train.batch_size = 1
    n_iterations = 256
    dl_val = tops.config.instantiate(cfg.data.val.loader)
    direction = find_direction(G, text_prompt, None, n_iterations=n_iterations, dl_val=iter(dl_val))
    text_prompt = text_prompt.replace(" ", "_")
    # The original call referenced undefined `input_path`/`output_path`; assuming
    # the validation loader as input and an outdir-based prefix for the outputs.
    os.makedirs(outdir, exist_ok=True)
    output_path = os.path.join(outdir, text_prompt + "_")
    generate_edit(G, dl_val, direction, edit_strength, output_path)


if __name__ == "__main__":
    stylemc()
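
# Example invocation (hypothetical paths; flag names follow the reconstructed
# decorators above):
#   python run_stylemc.py path/to/config.py --text-prompt "a face with a beard" \
#       --edit-strength 1.0 --outdir stylemc_output
# To reuse a computed direction across runs, call get_and_cache_direction
# instead of find_direction.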