|
""" |
|
@paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) |
|
@author: yangxy ([email protected]) |
|
""" |
|
import torch |
|
import os |
|
import cv2 |
|
import glob |
|
import numpy as np |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from torchvision import transforms, utils |
|
from gpen_model import FullGenerator, FullGenerator_SR |
|
|
|
|
|
class FaceGAN(object): |
|
def __init__( |
|
self, |
|
base_dir="./", |
|
in_size=512, |
|
out_size=512, |
|
model=None, |
|
channel_multiplier=2, |
|
narrow=1, |
|
key=None, |
|
is_norm=True, |
|
device="cuda", |
|
): |
|
self.mfile = os.path.join(base_dir, "weights", model + ".pth") |
|
self.n_mlp = 8 |
|
self.device = device |
|
self.is_norm = is_norm |
|
self.in_resolution = in_size |
|
self.out_resolution = out_size |
|
self.key = key |
|
self.load_model(channel_multiplier, narrow) |
|
|
|
def load_model(self, channel_multiplier=2, narrow=1): |
|
if self.in_resolution == self.out_resolution: |
|
self.model = FullGenerator( |
|
self.in_resolution, |
|
512, |
|
self.n_mlp, |
|
channel_multiplier, |
|
narrow=narrow, |
|
device=self.device, |
|
) |
|
else: |
|
self.model = FullGenerator_SR( |
|
self.in_resolution, |
|
self.out_resolution, |
|
512, |
|
self.n_mlp, |
|
channel_multiplier, |
|
narrow=narrow, |
|
device=self.device, |
|
) |
|
pretrained_dict = torch.load(self.mfile, map_location=torch.device("cpu")) |
|
if self.key is not None: |
|
pretrained_dict = pretrained_dict[self.key] |
|
self.model.load_state_dict(pretrained_dict) |
|
self.model.to(self.device) |
|
self.model.eval() |
|
|
|
def process(self, img): |
|
img = cv2.resize(img, (self.in_resolution, self.in_resolution)) |
|
img_t = self.img2tensor(img) |
|
|
|
with torch.no_grad(): |
|
out, __ = self.model(img_t) |
|
|
|
out = self.tensor2img(out) |
|
|
|
return out |
|
|
|
def img2tensor(self, img): |
|
img_t = torch.from_numpy(img).to(self.device) / 255.0 |
|
if self.is_norm: |
|
img_t = (img_t - 0.5) / 0.5 |
|
img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1) |
|
return img_t |
|
|
|
def tensor2img(self, img_t, pmax=255.0, imtype=np.uint8): |
|
if self.is_norm: |
|
img_t = img_t * 0.5 + 0.5 |
|
img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2) |
|
img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax |
|
|
|
return img_np.astype(imtype) |
|
|