Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import os | |
| import kornia | |
| import warnings | |
| from modules.layers.faceshifter.layers import AEI_Net | |
| from modules.layers.faceshifter.hear_layers import Hear_Net | |
| from third_party.arcface import iresnet100, MouthNet | |
def make_abs_path(fn: str) -> str:
    """Resolve ``fn`` against the directory that contains this source file.

    The directory is taken from ``os.path.realpath(__file__)`` (symlinks
    resolved), ``fn`` is joined onto it and the result is normalised with
    ``os.path.abspath``. If ``fn`` is already absolute, ``os.path.join``
    discards the directory prefix, so ``fn`` is returned normalised.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    return os.path.abspath(os.path.join(here, fn))
class FSGenerator(nn.Module):
    """Stage-I FaceShifter generator: ``AEI_Net`` conditioned on a frozen face-ID vector.

    The ID condition is the L2-normalised output of an ArcFace ``iresnet100``
    applied to the (112x112-resized) source image, optionally concatenated with
    a ``MouthNet`` feature extracted from a crop of that same 112x112 input.

    Args:
        id_ckpt: path to the ``iresnet100`` state dict. Required; ``None`` raises.
        id_dim: size of the ``iresnet100`` embedding.
        mouth_net_param: optional MouthNet config with keys ``'use'``,
            ``'feature_dim'``, ``'crop_param'`` (``(w1, h1, w2, h2)`` on the
            112x112 ID input) and ``'weight_path'`` (resolved against this
            file's directory). ``None`` or a falsy ``'use'`` disables MouthNet.
        in_size: resolution of the source images. It scales the translation
            part of ``trans_matrix`` and is the output size of the training-time
            warp.
        finetune, downup: forwarded unchanged to ``AEI_Net``.

    Raises:
        FileNotFoundError: if ``id_ckpt`` is ``None``.
    """

    def __init__(self,
                 id_ckpt: str = None,
                 id_dim: int = 512,
                 mouth_net_param: dict = None,
                 in_size: int = 256,
                 finetune: bool = False,
                 downup: bool = False,
                 ):
        super().__init__()

        # --- MouthNet (optional) ---
        # A missing config means "no MouthNet". Without this guard the default
        # value (None) crashed on ``.get`` below.
        if mouth_net_param is None:
            mouth_net_param = {}
        self.use_mouth_net = mouth_net_param.get('use')
        self.mouth_feat_dim = 0
        self.mouth_net = None
        if self.use_mouth_net:
            self.mouth_feat_dim = mouth_net_param.get('feature_dim')
            self.mouth_crop_param = mouth_net_param.get('crop_param')
            mouth_weight_path = make_abs_path(mouth_net_param.get('weight_path'))
            self.mouth_net = MouthNet(
                bisenet=None,
                feature_dim=self.mouth_feat_dim,
                crop_param=self.mouth_crop_param
            )
            self.mouth_net.load_backbone(mouth_weight_path)
            print("[FaceShifter Generator] MouthNet loaded from %s" % mouth_weight_path)
            # MouthNet is a frozen feature extractor.
            self.mouth_net.eval()
            self.mouth_net.requires_grad_(False)

        # --- Generator: its ID condition is the ArcFace embedding (+ mouth feature) ---
        self.G = AEI_Net(c_id=id_dim + self.mouth_feat_dim, finetune=finetune, downup=downup)

        # --- Frozen ArcFace identity backbone ---
        self.iresnet = iresnet100()
        if id_ckpt is None:
            warnings.warn("Face ID backbone [%s] not found!" % id_ckpt)
            raise FileNotFoundError("Face ID backbone [%s] not found!" % id_ckpt)
        self.iresnet.load_state_dict(torch.load(id_ckpt, "cpu"))
        self.iresnet.eval()
        self.iresnet.requires_grad_(False)
        # NOTE(review): calling .train() on this module (or on a parent) puts
        # iresnet / mouth_net back into train mode. requires_grad_(False) does not
        # stop BatchNorm running-stat updates, if these nets contain any.
        # Confirm that the trainer re-applies .eval() to them.

        # Fixed 2x3 affine matrix (batch dim of 1) applied to training-time
        # sources before ID extraction. The translation column is scaled with
        # in_size. Presumably this is a calibrated alignment to the ArcFace crop;
        # confirm against the data pipeline.
        self.register_buffer(
            name="trans_matrix",
            tensor=torch.tensor(
                [
                    [
                        [1.07695457, -0.03625215, -1.56352194 * (in_size / 256)],
                        [0.03625215, 1.07695457, -5.32134629 * (in_size / 256)],
                    ]
                ],
                requires_grad=False,
            ).float(),
        )
        self.in_size = in_size

    def forward(self, source, target, infer=False):
        """Generate ``target`` carrying the identity of ``source``.

        Args:
            source: source image batch (assumed NCHW; ``F.interpolate`` with a
                scalar size requires a 4-D tensor here).
            target: target image batch, passed to ``AEI_Net`` unchanged.
            infer: when False (training), ``source`` is first warped with
                ``trans_matrix`` to ``(in_size, in_size)``. When True, it is used
                as-is.

        Returns:
            ``(x, id_vector, att)``:
                - ``x`` and ``att`` are the two outputs of ``AEI_Net``.
                - ``id_vector`` has shape ``(B, id_dim + mouth_feat_dim)``.
        """
        # The ID condition comes from frozen networks, so no graph is built for it.
        with torch.no_grad():
            # 1. identity embedding
            if not infer:
                M = self.trans_matrix.repeat(source.size(0), 1, 1)
                source = kornia.geometry.transform.warp_affine(source, M, (self.in_size, self.in_size))
            resize_input = F.interpolate(source, size=112, mode="bilinear", align_corners=True)
            id_vector = F.normalize(self.iresnet(resize_input), dim=-1, p=2)

            # 2. optional mouth feature, taken from a crop of the 112x112 ID input
            if self.use_mouth_net:
                w1, h1, w2, h2 = self.mouth_crop_param
                mouth_input = resize_input[:, :, h1:h2, w1:w2]
                mouth_feat = self.mouth_net(mouth_input)
                id_vector = torch.cat([id_vector, mouth_feat], dim=-1)  # (B, dim_id + dim_mouth)

        # The generator itself runs with autograd enabled so it can be trained.
        x, att = self.G(target, id_vector)
        return x, id_vector, att

    def get_recon(self):
        """Return ``self.G.get_recon_tensor()`` (delegates to the generator)."""
        return self.G.get_recon_tensor()

    def get_att(self, x):
        """Return ``self.G.get_attr(x)`` (delegates to the generator)."""
        return self.G.get_attr(x)
class FSHearNet(nn.Module):
    """Two-stage FaceShifter: frozen pre-trained ``FSGenerator`` (stage I) + ``Hear_Net`` (stage II).

    Args:
        aei_path: stage-I weights. This is either a training checkpoint
            (path containing ``'.ckpt'``, converted with ``extract_generator``)
            or a plain state dict (path containing ``'.pth'``).

    Raises:
        FileNotFoundError: if ``aei_path`` contains neither ``'.ckpt'`` nor ``'.pth'``.
    """

    def __init__(self, aei_path: str):
        super().__init__()
        # Stage I. AEI_Net, frozen. MouthNet is disabled explicitly.
        # FSGenerator cannot be built with mouth_net_param=None.
        self.aei = FSGenerator(
            id_ckpt=make_abs_path("../../modules/third_party/arcface/weights/ms1mv3_arcface_r100_fp16/backbone.pth"),
            mouth_net_param={'use': False},
        ).requires_grad_(False)
        print('Loading pre-trained AEI-Net from %s...' % aei_path)
        self._load_pretrained_aei(aei_path)
        print('Loaded.')
        # Stage II. HEAR_Net (the trainable part)
        self.hear = Hear_Net()

    def _load_pretrained_aei(self, path: str):
        """Load stage-I weights from ``path`` into ``self.aei`` and switch it to eval mode.

        The weights are loaded with ``strict=False``. Any missing or unexpected
        keys are reported through ``warnings.warn`` instead of being silently
        ignored.
        """
        if '.ckpt' in path:
            from trainer.faceshifter.extract_ckpt import extract_generator
            pth_folder = make_abs_path('../../trainer/faceshifter/extracted_ckpt')
            # Presumably extract_generator also writes the extracted weights to
            # ``path``; verify in trainer.faceshifter.extract_ckpt.
            state_dict = extract_generator(load_path=path, path=os.path.join(pth_folder, 'hear_tmp.pth'))
        elif '.pth' in path:
            state_dict = torch.load(path, "cpu")
        else:
            raise FileNotFoundError('%s (.ckpt or .pth) not found.' % path)

        incompatible = self.aei.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                "AEI-Net weights %s loaded with strict=False: %d missing key(s), %d unexpected key(s)"
                % (path, len(incompatible.missing_keys), len(incompatible.unexpected_keys))
            )
        self.aei.eval()

    def forward(self, source, target):
        """Run both stages.

        Stage I swaps ``source`` onto ``target`` (``y_hat_st``) and reconstructs
        ``target`` from itself (``y_hat_tt``). The residual ``target - y_hat_tt``
        is concatenated with ``y_hat_st`` on the channel axis and refined by
        HEAR-Net.

        Returns:
            ``(y_st, y_hat_st)``: the HEAR-Net output and the stage-I output.
            Both are (B,3,256,256) for 256x256 inputs.
        """
        # Stage I is frozen, so no graph is needed for it.
        with torch.no_grad():
            y_hat_st, _, _ = self.aei(source, target, infer=True)
            y_hat_tt, _, _ = self.aei(target, target, infer=True)
            delta_y_t = target - y_hat_tt
            y_cat = torch.cat([y_hat_st, delta_y_t], dim=1)  # (B,6,256,256) for 256x256 inputs
        # Stage II runs with autograd enabled so HEAR-Net can be trained.
        y_st = self.hear(y_cat)
        return y_st, y_hat_st
if __name__ == '__main__':
    # Smoke test: build the stage-I generator and run one training-path forward.
    # Usage: python <this file> [path/to/arcface_backbone.pth]
    import sys

    source = torch.randn(8, 3, 512, 512)
    target = torch.randn(8, 3, 512, 512)
    id_ckpt = (
        sys.argv[1] if len(sys.argv) > 1 else
        "/apdcephfs_cq2/share_1290939/gavinyuan/code/FaceShifter/faceswap/faceswap/checkpoints/"
        "face_id/ms1mv3_arcface_r100_fp16_backbone.pth"
    )
    net = FSGenerator(
        id_ckpt=id_ckpt,
        mouth_net_param={
            'use': False
        },
        # The training-path warp matrix and output size must match the input
        # resolution. With the default of 256, 512x512 inputs were warped with
        # a matrix scaled for 256 and cropped to 256x256.
        in_size=source.shape[-1],
    )
    with torch.no_grad():
        result, _, _ = net(source, target)
    print('result:', result.shape)
    # stage2 = FSHearNet(
    #     aei_path=make_abs_path("../../trainer/faceshifter/out/faceshifter_vanilla/epoch=32-step=509999.ckpt")
    # )
    # final_out, _ = stage2(source, target)
    # print('final out:', final_out.shape)