from argparse import Namespace
import os
from os.path import join as pjoin
import random
import sys
from typing import (
    Iterable,
    Optional,
)

import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import (
    Compose,
    Grayscale,
    Resize,
    ToTensor,
    Normalize,
)

from losses.joint_loss import JointLoss
from model import Generator
from tools.initialize import Initializer
from tools.match_skin_histogram import match_skin_histogram
from utils.projector_arguments import ProjectorArguments
from utils import torch_helpers as th
from utils.torch_helpers import make_image
from utils.misc import stem
from utils.optimize import Optimizer
from models.degrade import (
    Degrade,
    Downsample,
)
from huggingface_hub import hf_hub_download
| TOKEN = "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" | |


def set_random_seed(seed: int):
    # FIXME (xuanluo): this setup still allows randomness somehow
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
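    # NOTE: torch.manual_seed alone does not make CUDA/cuDNN fully deterministic.
    # If exact reproducibility is needed, it may also help to set:
    #   torch.cuda.manual_seed_all(seed)
    #   torch.backends.cudnn.deterministic = True
    #   torch.backends.cudnn.benchmark = False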


def read_images(paths: Iterable[str], max_size: Optional[int] = None):
    transform = Compose(
        [
            Grayscale(),
            ToTensor(),
        ]
    )
    imgs = []
    for path in paths:
        img = Image.open(path)
        if max_size is not None and img.width > max_size:
            img = img.resize((max_size, max_size))
        img = transform(img)
        imgs.append(img)
    imgs = torch.stack(imgs, 0)
    return imgs


def normalize(img: torch.Tensor, mean=0.5, std=0.5):
    """[0, 1] -> [-1, 1]"""
    return (img - mean) / std


def create_generator(file_name: str, path: str, args: Namespace, device: torch.device):
    # Download the checkpoint from the Hugging Face Hub and load the EMA generator weights.
    path = hf_hub_download(path, file_name, use_auth_token=TOKEN)
    with open(path, 'rb') as f:
        generator = Generator(args.generator_size, 512, 8)
        generator.load_state_dict(torch.load(f)['g_ema'], strict=False)
    generator.eval()
    generator.to(device)
    return generator
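# Example call (the repository id and file name below are placeholders, not real values):
#   generator = create_generator(
#       "generator-checkpoint.pt",   # file_name: checkpoint file inside the Hub repository
#       "some-user/some-repo",       # path: Hub repository id
#       args, device,
#   )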


def save(
    path_prefixes: Iterable[str],
    imgs: torch.Tensor,  # BCHW
    latents: torch.Tensor,
    noises: torch.Tensor,
    imgs_rand: Optional[torch.Tensor] = None,
):
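    """Write projection results to disk.

    For each path prefix this saves "<prefix>.png" (the reconstructed image) and
    "<prefix>.pt" (a dict holding the latent and noise tensors); if `imgs_rand` is
    given, "<prefix>-rand.png" (the image rendered with random noise) is saved as well.
    """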
    assert len(path_prefixes) == len(imgs) and len(latents) == len(path_prefixes)
    if imgs_rand is not None:
        assert len(imgs) == len(imgs_rand)

    imgs_arr = make_image(imgs)
    for path_prefix, img, latent, noise in zip(path_prefixes, imgs_arr, latents, noises):
        os.makedirs(os.path.dirname(path_prefix), exist_ok=True)
        cv2.imwrite(path_prefix + ".png", img[..., ::-1])
        torch.save({"latent": latent.detach().cpu(), "noise": noise.detach().cpu()},
                   path_prefix + ".pt")

    if imgs_rand is not None:
        imgs_arr = make_image(imgs_rand)
        for path_prefix, img in zip(path_prefixes, imgs_arr):
            cv2.imwrite(path_prefix + "-rand.png", img[..., ::-1])


def main(args):
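    """Project the input photo into the generator's latent space.

    Pipeline: read and normalize the input, predict an initial latent code, build the
    generator, match the input's skin histogram to a generated sibling image, jointly
    optimize the latent and noise maps, and save the reconstructions.
    """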
    opt_str = ProjectorArguments.to_string(args)
    print(opt_str)

    if args.rand_seed is not None:
        set_random_seed(args.rand_seed)
    device = th.device()

    # Read the input. TODO: imgs_orig has a single (grayscale) channel.
    imgs_orig = read_images([args.input], max_size=args.generator_size).to(device)
    imgs = normalize(imgs_orig)  # this will be overwritten below by the histogram-matching result

    # Predict an initial latent code for the input.
    with torch.no_grad():
        init = Initializer(args).to(device)
        latent_init = init(imgs_orig)

    # Create the generator.
    # NOTE: create_generator() takes the checkpoint file name and the Hub repository id
    # as its first two arguments. The attribute names used here (args.ckpt and
    # args.ckpt_repo) are assumptions about ProjectorArguments and may need adjusting.
    generator = create_generator(args.ckpt, args.ckpt_repo, args, device)
    # Initialize the noise maps.
    with torch.no_grad():
        noises_init = generator.make_noise()

    # Create a new input by matching the input's histogram to the sibling image.
    with torch.no_grad():
        sibling, _, sibling_rgbs = generator([latent_init], input_is_latent=True, noise=noises_init)
    mh_dir = pjoin(args.results_dir, stem(args.input))
    imgs = match_skin_histogram(
        imgs, sibling,
        args.spectral_sensitivity,
        pjoin(mh_dir, "input_sibling"),
        pjoin(mh_dir, "skin_mask"),
        matched_hist_fn=mh_dir.rstrip(os.sep) + f"_{args.spectral_sensitivity}.png",
        normalize=normalize,
    ).to(device)
    torch.cuda.empty_cache()
    # TODO: imgs has 3 channels here.

    degrade = Degrade(args).to(device)

    rgb_levels = generator.get_latent_size(args.coarse_min) // 2 + len(args.wplus_step) - 1
    criterion = JointLoss(
        args, imgs,
        sibling=sibling.detach(), sibling_rgbs=sibling_rgbs[:rgb_levels]).to(device)

    # Save the initialization.
    save(
        [pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}-init")],
        sibling, latent_init, noises_init,
    )

    writer = SummaryWriter(pjoin(args.log_dir, f"{stem(args.input)}/{opt_str}"))

    # Start the optimization.
    latent, noises = Optimizer.optimize(generator, criterion, degrade, imgs, latent_init, noises_init, args, writer=writer)

    # Generate the output.
    img_out, _, _ = generator([latent], input_is_latent=True, noise=noises)
    img_out_rand_noise, _, _ = generator([latent], input_is_latent=True)

    # Save the output.
    save(
        [pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}")],
        img_out, latent, noises,
        imgs_rand=img_out_rand_noise,
    )


def parse_args():
    return ProjectorArguments().parse()


if __name__ == "__main__":
    sys.exit(main(parse_args()))
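# Example invocation (flag names are illustrative; see ProjectorArguments for the
# options this script actually accepts):
#   python projector.py --input path/to/old_photo.png --results_dir results --log_dir logs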