from __future__ import annotations

import os
import pathlib
import gc
from argparse import Namespace

import numpy as np
import gradio as gr
import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms, utils
import dlib
import cv2
import PIL
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from datasets import augmentations
from scripts.align_all_parallel import align_face
from latent_optimization import latent_optimization
from utils.inference_utils import save_image, load_image, visualize, get_video_crop_parameter, tensor2cv2, tensor2label, labelcolormap
from models.psp import pSp
from models.bisenet.model import BiSeNet
from models.stylegan2.model import Generator
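# Note: datasets, scripts, latent_optimization, utils.inference_utils and
# models are local modules shipped with the StyleGANEX repository, not PyPI
# packages; this file assumes it is run from the repository root.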

class Model():
    def __init__(self, device):
        super().__init__()
        self.device = device
        self.task_name = None
        self.editing_w = None
        self.pspex = None
        # dlib 68-point landmark predictor, used for face cropping and alignment
        self.landmarkpredictor = dlib.shape_predictor(hf_hub_download('PKUWilliamYang/StyleGANEX', 'pretrained_models/shape_predictor_68_face_landmarks.dat'))
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        # BiSeNet face-parsing model (19 classes), used by mask2face
        self.maskpredictor = BiSeNet(n_classes=19)
        self.maskpredictor.load_state_dict(torch.load(hf_hub_download('PKUWilliamYang/VToonify', 'models/faceparsing.pth'), map_location='cpu'))
        self.maskpredictor.to(self.device).eval()
        # task name -> checkpoint path on the Hub and bundled example input
        self.parameters = {}
        self.parameters['inversion'] = {'path': 'pretrained_models/styleganex_inversion.pt', 'image_path': './data/ILip77SbmOE.png'}
        self.parameters['sr-32'] = {'path': 'pretrained_models/styleganex_sr32.pt', 'image_path': './data/pexels-daniel-xavier-1239291.jpg'}
        self.parameters['sr'] = {'path': 'pretrained_models/styleganex_sr.pt', 'image_path': './data/pexels-daniel-xavier-1239291.jpg'}
        self.parameters['sketch2face'] = {'path': 'pretrained_models/styleganex_sketch2face.pt', 'image_path': './data/234_sketch.jpg'}
        self.parameters['mask2face'] = {'path': 'pretrained_models/styleganex_mask2face.pt', 'image_path': './data/540.jpg'}
        self.parameters['edit_age'] = {'path': 'pretrained_models/styleganex_edit_age.pt', 'image_path': './data/390.mp4'}
        self.parameters['edit_hair'] = {'path': 'pretrained_models/styleganex_edit_hair.pt', 'image_path': './data/390.mp4'}
        self.parameters['toonify_pixar'] = {'path': 'pretrained_models/styleganex_toonify_pixar.pt', 'image_path': './data/pexels-anthony-shkraba-production-8136210.mp4'}
        self.parameters['toonify_cartoon'] = {'path': 'pretrained_models/styleganex_toonify_cartoon.pt', 'image_path': './data/pexels-anthony-shkraba-production-8136210.mp4'}
        self.parameters['toonify_arcane'] = {'path': 'pretrained_models/styleganex_toonify_arcane.pt', 'image_path': './data/pexels-anthony-shkraba-production-8136210.mp4'}
        self.print_log = True
        # latent editing directions for 'Attribute Editing' in process_inversion
        self.editing_dicts = torch.load(hf_hub_download('PKUWilliamYang/StyleGANEX', 'direction_dics.pt'))
        # StyleGAN2 generator, reloaded with StyleGAN-NADA weights for 'Domain Transfer'
        self.generator = Generator(1024, 512, 8)
        self.model_type = None
        self.error_info = ('Error: no face detected! '
                           'StyleGANEX uses dlib.get_frontal_face_detector but sometimes it fails to detect a face. '
                           'You can try several times or use other images until a face is detected, '
                           'then switch back to the original image.')
    def load_model(self, task_name: str) -> None:
        # Lazily load the task-specific pSp/StyleGANEX checkpoint, freeing
        # the previously loaded one first.
        if task_name == self.task_name:
            return
        if self.pspex is not None:
            del self.pspex
            torch.cuda.empty_cache()
            gc.collect()
        path = self.parameters[task_name]['path']
        local_path = hf_hub_download('PKUWilliamYang/StyleGANEX', path)
        ckpt = torch.load(local_path, map_location='cpu')
        opts = ckpt['opts']
        opts['checkpoint_path'] = local_path
        opts['device'] = self.device
        opts = Namespace(**opts)
        self.pspex = pSp(opts, ckpt).to(self.device).eval()
        self.pspex.latent_avg = self.pspex.latent_avg.to(self.device)
        if 'editing_w' in ckpt.keys():
            self.editing_w = ckpt['editing_w'].clone().to(self.device)
        self.task_name = task_name
        torch.cuda.empty_cache()
        gc.collect()
    def load_G_model(self, model_type: str) -> None:
        if model_type == self.model_type:
            return
        torch.cuda.empty_cache()
        gc.collect()
        local_path = hf_hub_download('rinong/stylegan-nada-models', model_type + '.pt')
        self.generator.load_state_dict(torch.load(local_path, map_location='cpu')['g_ema'], strict=False)
        self.generator.to(self.device).eval()
        self.model_type = model_type
        torch.cuda.empty_cache()
        gc.collect()
    def tensor2np(self, img):
        # Map a CHW float tensor in [-1, 1] (the generator's output range)
        # to an HWC uint8 RGB array.
        tmp = ((img.cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
        return tmp
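    # Face super-resolution. The photo is first downsampled (by a factor of
    # resize_scale // 4) to simulate a low-resolution input, then restored
    # with the task-specific model; input and output are returned as a pair.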
    def process_sr(self, input_image: str, resize_scale: int, model: str) -> list[np.ndarray]:
        if input_image is None:
            raise gr.Error('Error: fail to load empty file.')
        frame = cv2.imread(input_image)
        if frame is None:
            raise gr.Error('Error: fail to load the image.')
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if model is None or model == 'SR for 32x':
            task_name = 'sr-32'
            resize_scale = 32
        else:
            task_name = 'sr'

        with torch.no_grad():
            paras = get_video_crop_parameter(frame, self.landmarkpredictor)
            if paras is None:
                raise gr.Error(self.error_info)
            h, w, top, bottom, left, right, scale = paras
            H, W = int(bottom - top), int(right - left)
            frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
            x1 = PIL.Image.fromarray(np.uint8(frame))
            x1 = augmentations.BilinearResize(factors=[resize_scale // 4])(x1)
            x1_up = x1.resize((W, H))
            x2_up = align_face(np.array(x1_up), self.landmarkpredictor)
            if x2_up is None:
                raise gr.Error(self.error_info)
            x1_up = transforms.ToTensor()(x1_up).unsqueeze(dim=0).to(self.device) * 2 - 1
            x2_up = self.transform(x2_up).unsqueeze(dim=0).to(self.device)
            if self.print_log: print('image loaded')
            self.load_model(task_name)
            if self.print_log: print('model %s loaded' % task_name)
            y_hat = torch.clamp(self.pspex(x1=x1_up, x2=x2_up, use_skip=self.pspex.opts.use_skip, resize=False), -1, 1)

        return [self.tensor2np(x1_up[0]), self.tensor2np(y_hat[0])]
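    # Sketch-to-face synthesis. Structure comes from the sketch; layers 8-17
    # of the W+ code are overwritten with a randomly drawn style (latent_mask
    # + inject_latent), so the seed controls texture and color.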
    def process_s2f(self, input_image: str, seed: int) -> np.ndarray:
        if input_image is None:
            raise gr.Error('Error: fail to load empty file.')
        task_name = 'sketch2face'
        with torch.no_grad():
            x1 = transforms.ToTensor()(PIL.Image.open(input_image)).unsqueeze(0).to(self.device)
            # center-crop overly large inputs to a ~512-pixel window,
            # aligned to a multiple of 8
            if x1.shape[2] > 513:
                x1 = x1[:, :, (x1.shape[2] // 2 - 256) // 8 * 8:(x1.shape[2] // 2 + 256) // 8 * 8]
            if x1.shape[3] > 513:
                x1 = x1[:, :, :, (x1.shape[3] // 2 - 256) // 8 * 8:(x1.shape[3] // 2 + 256) // 8 * 8]
            x1 = x1[:, 0:1]  # uploaded files are converted to 3-channel RGB, so keep one channel
            if self.print_log: print('image loaded')
            self.load_model(task_name)
            if self.print_log: print('model %s loaded' % task_name)
            self.pspex.train()
            torch.manual_seed(seed)
            y_hat = self.pspex(x1=x1, resize=False, latent_mask=[8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
                               use_skip=self.pspex.opts.use_skip,
                               inject_latent=self.pspex.decoder.style(torch.randn(1, 512).to(self.device)).unsqueeze(1).repeat(1, 18, 1) * 0.7)
            y_hat = torch.clamp(y_hat, -1, 1)
            self.pspex.eval()
        return self.tensor2np(y_hat[0])
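    # Parsing-mask-to-face synthesis. input_type selects between a ready-made
    # 19-class parsing mask and a photo that is first segmented with BiSeNet;
    # as in process_s2f, the seed randomizes texture via style injection.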
    def process_m2f(self, input_image: str, input_type: str, seed: int) -> list[np.ndarray]:
        if input_image is None:
            raise gr.Error('Error: fail to load empty file.')
        task_name = 'mask2face'
        with torch.no_grad():
            if input_type == 'parsing mask':
                x1 = PIL.Image.open(input_image).getchannel(0)  # uploaded files are converted to 3-channel RGB
                x1 = augmentations.ToOneHot(19)(x1)
                x1 = transforms.ToTensor()(x1).unsqueeze(dim=0).float().to(self.device)
            else:
                frame = cv2.imread(input_image)
                if frame is None:
                    raise gr.Error('Error: fail to load the image.')
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                paras = get_video_crop_parameter(frame, self.landmarkpredictor)
                if paras is None:
                    raise gr.Error(self.error_info)
                h, w, top, bottom, left, right, scale = paras
                H, W = int(bottom - top), int(right - left)
                frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
                # convert the face image to a segmentation mask
                x1 = self.to_tensor(frame).unsqueeze(0).to(self.device)
                # upsample the image for more precise segmentation
                x1 = F.interpolate(x1, scale_factor=2, mode='bilinear')
                x1 = self.maskpredictor(x1)[0]
                x1 = F.interpolate(x1, scale_factor=0.5).argmax(dim=1)
                x1 = F.one_hot(x1, num_classes=19).permute(0, 3, 1, 2).float().to(self.device)

            if x1.shape[2] > 513:
                x1 = x1[:, :, (x1.shape[2] // 2 - 256) // 8 * 8:(x1.shape[2] // 2 + 256) // 8 * 8]
            if x1.shape[3] > 513:
                x1 = x1[:, :, :, (x1.shape[3] // 2 - 256) // 8 * 8:(x1.shape[3] // 2 + 256) // 8 * 8]
            x1_viz = (tensor2label(x1[0], 19) / 192 * 256).astype(np.uint8)
            if self.print_log: print('image loaded')
            self.load_model(task_name)
            if self.print_log: print('model %s loaded' % task_name)
            self.pspex.train()
            torch.manual_seed(seed)
            y_hat = self.pspex(x1=x1, resize=False, latent_mask=[8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
                               use_skip=self.pspex.opts.use_skip,
                               inject_latent=self.pspex.decoder.style(torch.randn(1, 512).to(self.device)).unsqueeze(1).repeat(1, 18, 1) * 0.7)
            y_hat = torch.clamp(y_hat, -1, 1)
            self.pspex.eval()
        return [x1_viz, self.tensor2np(y_hat[0])]
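    # Single-image attribute editing: the task checkpoint carries a learned
    # editing direction (editing_w), and scale_factor controls how far the
    # latent code is moved along it.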
    def process_editing(self, input_image: str, scale_factor: float, model_type: str) -> np.ndarray:
        if input_image is None:
            raise gr.Error('Error: fail to load empty file.')
        frame = cv2.imread(input_image)
        if frame is None:
            raise gr.Error('Error: fail to load the image.')
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if model_type is None or model_type == 'reduce age':
            task_name = 'edit_age'
        else:
            task_name = 'edit_hair'

        with torch.no_grad():
            paras = get_video_crop_parameter(frame, self.landmarkpredictor)
            if paras is None:
                raise gr.Error(self.error_info)
            h, w, top, bottom, left, right, scale = paras
            H, W = int(bottom - top), int(right - left)
            frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
            x1 = self.transform(frame).unsqueeze(0).to(self.device)
            x2 = align_face(frame, self.landmarkpredictor)
            if x2 is None:
                raise gr.Error(self.error_info)
            x2 = self.transform(x2).unsqueeze(dim=0).to(self.device)
            if self.print_log: print('image loaded')
            self.load_model(task_name)
            if self.print_log: print('model %s loaded' % task_name)
            y_hat = self.pspex(x1=x1, x2=x2, use_skip=self.pspex.opts.use_skip, zero_noise=True,
                               resize=False, editing_w=-scale_factor * self.editing_w[0:1])
            y_hat = torch.clamp(y_hat, -1, 1)
        return self.tensor2np(y_hat[0])
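    # Video attribute editing. Crop parameters and the aligned face (x2) are
    # computed once from the first frame and reused for the remaining frames;
    # the result is written to output.mp4 at 4x the crop resolution.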
    def process_vediting(self, input_video: str, scale_factor: float, model_type: str, frame_num: int) -> tuple[list[np.ndarray], str]:
        if input_video is None:
            raise gr.Error('Error: fail to load empty file.')
        video_cap = cv2.VideoCapture(input_video)
        success, frame = video_cap.read()
        if not success:
            raise gr.Error('Error: fail to load the video.')
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if model_type is None or model_type == 'reduce age':
            task_name = 'edit_age'
        else:
            task_name = 'edit_hair'

        with torch.no_grad():
            paras = get_video_crop_parameter(frame, self.landmarkpredictor)
            if paras is None:
                raise gr.Error(self.error_info)
            h, w, top, bottom, left, right, scale = paras
            H, W = int(bottom - top), int(right - left)
            frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
            x1 = self.transform(frame).unsqueeze(0).to(self.device)
            x2 = align_face(frame, self.landmarkpredictor)
            if x2 is None:
                raise gr.Error(self.error_info)
            x2 = self.transform(x2).unsqueeze(dim=0).to(self.device)
            if self.print_log: print('first frame loaded')
            self.load_model(task_name)
            if self.print_log: print('model %s loaded' % task_name)

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            videoWriter = cv2.VideoWriter('output.mp4', fourcc, video_cap.get(cv2.CAP_PROP_FPS), (4 * W, 4 * H))

            viz_frames = []
            for i in range(frame_num):
                if i > 0:
                    success, frame = video_cap.read()
                    if not success:  # stop early if the video has fewer frames
                        break
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
                    x1 = self.transform(frame).unsqueeze(0).to(self.device)
                y_hat = self.pspex(x1=x1, x2=x2, use_skip=self.pspex.opts.use_skip, zero_noise=True,
                                   resize=False, editing_w=-scale_factor * self.editing_w[0:1])
                y_hat = torch.clamp(y_hat, -1, 1)
                videoWriter.write(tensor2cv2(y_hat[0].cpu()))
                if i < min(frame_num, 4):
                    viz_frames += [self.tensor2np(y_hat[0])]
            videoWriter.release()

        return viz_frames, 'output.mp4'
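    # Single-image toonification with one of three pretrained styles
    # (Pixar / Cartoon / Arcane).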
    def process_toonify(self, input_image: str, style_type: str) -> np.ndarray:
        if input_image is None:
            raise gr.Error('Error: fail to load empty file.')
        frame = cv2.imread(input_image)
        if frame is None:
            raise gr.Error('Error: fail to load the image.')
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if style_type is None or style_type == 'Pixar':
            task_name = 'toonify_pixar'
        elif style_type == 'Cartoon':
            task_name = 'toonify_cartoon'
        else:
            task_name = 'toonify_arcane'

        with torch.no_grad():
            paras = get_video_crop_parameter(frame, self.landmarkpredictor)
            if paras is None:
                raise gr.Error(self.error_info)
            h, w, top, bottom, left, right, scale = paras
            H, W = int(bottom - top), int(right - left)
            frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
            x1 = self.transform(frame).unsqueeze(0).to(self.device)
            x2 = align_face(frame, self.landmarkpredictor)
            if x2 is None:
                raise gr.Error(self.error_info)
            x2 = self.transform(x2).unsqueeze(dim=0).to(self.device)
            if self.print_log: print('image loaded')
            self.load_model(task_name)
            if self.print_log: print('model %s loaded' % task_name)
            y_hat = self.pspex(x1=x1, x2=x2, use_skip=self.pspex.opts.use_skip, zero_noise=True, resize=False)
            y_hat = torch.clamp(y_hat, -1, 1)
        return self.tensor2np(y_hat[0])
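    # Video toonification: same first-frame cropping and per-frame loop as
    # process_vediting, but with no editing direction applied.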
    def process_vtoonify(self, input_video: str, style_type: str, frame_num: int) -> tuple[list[np.ndarray], str]:
        if input_video is None:
            raise gr.Error('Error: fail to load empty file.')
        video_cap = cv2.VideoCapture(input_video)
        success, frame = video_cap.read()
        if not success:
            raise gr.Error('Error: fail to load the video.')
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if style_type is None or style_type == 'Pixar':
            task_name = 'toonify_pixar'
        elif style_type == 'Cartoon':
            task_name = 'toonify_cartoon'
        else:
            task_name = 'toonify_arcane'

        with torch.no_grad():
            paras = get_video_crop_parameter(frame, self.landmarkpredictor)
            if paras is None:
                raise gr.Error(self.error_info)
            h, w, top, bottom, left, right, scale = paras
            H, W = int(bottom - top), int(right - left)
            frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
            x1 = self.transform(frame).unsqueeze(0).to(self.device)
            x2 = align_face(frame, self.landmarkpredictor)
            if x2 is None:
                raise gr.Error(self.error_info)
            x2 = self.transform(x2).unsqueeze(dim=0).to(self.device)
            if self.print_log: print('first frame loaded')
            self.load_model(task_name)
            if self.print_log: print('model %s loaded' % task_name)

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            videoWriter = cv2.VideoWriter('output.mp4', fourcc, video_cap.get(cv2.CAP_PROP_FPS), (4 * W, 4 * H))

            viz_frames = []
            for i in range(frame_num):
                if i > 0:
                    success, frame = video_cap.read()
                    if not success:  # stop early if the video has fewer frames
                        break
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
                    x1 = self.transform(frame).unsqueeze(0).to(self.device)
                y_hat = self.pspex(x1=x1, x2=x2, use_skip=self.pspex.opts.use_skip, zero_noise=True, resize=False)
                y_hat = torch.clamp(y_hat, -1, 1)
                videoWriter.write(tensor2cv2(y_hat[0].cpu()))
                if i < min(frame_num, 4):
                    viz_frames += [self.tensor2np(y_hat[0])]
            videoWriter.release()

        return viz_frames, 'output.mp4'
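    # StyleGANEX inversion plus optional editing. An image can be inverted by
    # the encoder alone, refined by latent optimization, or loaded from a
    # precomputed .pt file holding the W+ code ('wplus') and the first-layer
    # feature ('f'); the edited result is decoded next to the plain inversion.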
    def process_inversion(self, input_image: str, optimize: str, input_latent, editing_options: str,
                          scale_factor: float, seed: int) -> tuple[np.ndarray, np.ndarray]:
        # input_latent is a Gradio file object (exposing a .name path) or None
        if input_image is None:
            raise gr.Error('Error: fail to load empty file.')
        frame = cv2.imread(input_image)
        if frame is None:
            raise gr.Error('Error: fail to load the image.')
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        task_name = 'inversion'
        self.load_model(task_name)
        if self.print_log: print('model %s loaded' % task_name)

        if input_latent is not None:
            if '.pt' not in input_latent.name:
                raise gr.Error('Error: the latent format is wrong')
            latents = torch.load(input_latent.name)
            if 'wplus' not in latents.keys() or 'f' not in latents.keys():
                raise gr.Error('Error: the latent format is wrong')
            wplus = latents['wplus'].to(self.device)  # W+ latent code
            f = [latents['f'][0].to(self.device)]     # first-layer feature
        elif optimize == 'Latent optimization':
            wplus, f, _, _, _ = latent_optimization(frame, self.pspex, self.landmarkpredictor,
                                                    step=500, device=self.device)
        else:
            with torch.no_grad():
                paras = get_video_crop_parameter(frame, self.landmarkpredictor)
                if paras is None:
                    raise gr.Error(self.error_info)
                h, w, top, bottom, left, right, scale = paras
                H, W = int(bottom - top), int(right - left)
                frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
                x1 = self.transform(frame).unsqueeze(0).to(self.device)
                x2 = align_face(frame, self.landmarkpredictor)
                if x2 is None:
                    raise gr.Error(self.error_info)
                x2 = self.transform(x2).unsqueeze(dim=0).to(self.device)
                if self.print_log: print('image loaded')
                wplus = self.pspex.encoder(x2) + self.pspex.latent_avg.unsqueeze(0)
                _, f = self.pspex.encoder(x1, return_feat=True)

        with torch.no_grad():
            y_hat, _ = self.pspex.decoder([wplus], input_is_latent=True, first_layer_feature=f)
            y_hat = torch.clamp(y_hat, -1, 1)
            if 'Style Mixing' in editing_options:
                torch.manual_seed(seed)
                wplus[:, 8:] = self.pspex.decoder.style(torch.randn(1, 512).to(self.device)).unsqueeze(1).repeat(1, 10, 1) * 0.7
                y_hat_edit, _ = self.pspex.decoder([wplus], input_is_latent=True, first_layer_feature=f)
            elif 'Attribute Editing' in editing_options:
                # drop the 'Attribute Editing: ' prefix (19 characters)
                editing_w = self.editing_dicts[editing_options[19:]].to(self.device)
                y_hat_edit, _ = self.pspex.decoder([wplus + scale_factor * editing_w], input_is_latent=True, first_layer_feature=f)
            elif 'Domain Transfer' in editing_options:
                # drop the 'Domain Transfer: ' prefix (17 characters)
                self.load_G_model(editing_options[17:])
                if self.print_log: print('model %s loaded' % editing_options[17:])
                y_hat_edit, _ = self.generator([wplus], input_is_latent=True, first_layer_feature=f)
            else:
                y_hat_edit = y_hat
            y_hat_edit = torch.clamp(y_hat_edit, -1, 1)
        return self.tensor2np(y_hat[0]), self.tensor2np(y_hat_edit[0])
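
# A minimal usage sketch (an addition, not part of the original Space, which
# drives Model through a Gradio UI). It assumes the pretrained checkpoints can
# be downloaded from the Hub and that the bundled ./data examples are present.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Model(device)
    # 32x face super-resolution on the bundled example photo
    lowres, restored = model.process_sr('./data/pexels-daniel-xavier-1239291.jpg', 32, 'SR for 32x')
    PIL.Image.fromarray(restored).save('sr_result.png')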