Spaces:
Running
Running
| import argparse | |
| import pytorch_lightning as pl | |
| import numpy as np | |
| import torch | |
| from third_party.arcface.mouth_net_pl import MouthNetPL | |
| from third_party.arcface.mouth_net import MouthNet | |
class MouthTest(object):
    """Load a pretrained ``MouthNet`` and allocate per-sample embedding buffers.

    The constructor builds the network on the GPU (``.cuda()`` is called
    unconditionally, so a CUDA device is required), loads backbone weights
    from ``fixer_path``, switches the model to eval mode and allocates three
    zero-filled ``float32`` arrays of shape ``(dataset_len, feature_dim)``.

    Args:
        dataset_len: Number of rows allocated in each embedding buffer.
        fixer_path: Path to the backbone weights passed to
            ``MouthNet.load_backbone``. Defaults to ``DEFAULT_FIXER_PATH``.
        crop_param: Crop window forwarded to ``MouthNet`` as ``crop_param``.
        feature_dim: Embedding width; used both for the network's
            ``feature_dim`` and for the second axis of the buffers so the
            two can never disagree.
    """

    # Weights location used when the caller does not supply ``fixer_path``.
    DEFAULT_FIXER_PATH = "/gavin/code/FaceSwapping/modules/third_party/arcface/weights/fixer_net_casia_28_56_84_112.pth"

    def __init__(self, dataset_len=400, fixer_path=None,
                 crop_param=(28, 56, 84, 112), feature_dim=128):
        if fixer_path is None:
            fixer_path = self.DEFAULT_FIXER_PATH

        self.dataset_len = dataset_len
        self.fixer_crop_param = crop_param

        self.fixer_casia_model = MouthNet(
            bisenet=None,
            feature_dim=feature_dim,
            crop_param=self.fixer_crop_param
        ).cuda()
        self.fixer_casia_model.load_backbone(fixer_path)
        self.fixer_casia_model.eval()

        # Embedding buffers, one row per dataset sample.
        # NOTE(review): the t/s/r suffixes presumably stand for
        # target/source/result -- confirm against the code that fills them.
        self.fixer_t = np.zeros((self.dataset_len, feature_dim), dtype=np.float32)
        self.fixer_s = np.zeros_like(self.fixer_t, dtype=np.float32)  # each embedding repeats 10 times in ffplus
        self.fixer_r = np.zeros_like(self.fixer_t, dtype=np.float32)

        print('Fixer model loaded.')
if __name__ == '__main__':
    # Script entry point: restore one MouthNetPL checkpoint and run
    # ``trainer.test`` on it.
    #
    # Only the checkpoint selected with ``--net`` is loaded, instead of
    # loading every checkpoint and discarding all but one. The default
    # ("lower_net_1") is the model this script tested before the option
    # existed, so running without arguments tests the same network.

    # name -> (checkpoint path, num_classes, header_type)
    checkpoints = {
        "fixer_net": (
            "/apdcephfs/share_1290939/gavinyuan/out/fixernet_casia/epoch=22-step=10999-v1.ckpt",
            10572,
            "AMCosFace",
        ),
        "lower_net_1": (
            "/apdcephfs/share_1290939/gavinyuan/out/mouth_net_1/epoch=24-step=242999.ckpt",
            93431,
            "AMArcFace",
        ),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--net",
        choices=sorted(checkpoints),
        default="lower_net_1",
        help="which pretrained checkpoint to evaluate",
    )
    args = parser.parse_args()
    args.val_targets = []
    args.rec_folder = "/gavin/datasets/msml/ms1m-retinaface"

    ckpt_path, num_classes, header_type = checkpoints[args.net]

    # Hyper-parameters are passed explicitly and loading is non-strict.
    test_net = MouthNetPL.load_from_checkpoint(
        ckpt_path,
        map_location='cpu', strict=False,
        num_classes=num_classes,
        batch_size=128,
        dim_feature=128,
        rec_folder=args.rec_folder,
        header_type=header_type,
        crop=(28, 56, 84, 112),
    )

    # NOTE(review): ``gpus`` and ``distributed_backend`` appear to be legacy
    # Trainer arguments in newer pytorch_lightning releases -- confirm
    # against the version this project pins before upgrading.
    trainer = pl.Trainer(
        logger=False,
        gpus=1,
        distributed_backend='dp',
        benchmark=True,
    )
    trainer.test(test_net)

    # print('Fixer model loading...')
    # m_test = MouthTest()