Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	V0.1.0 release.
Browse files- README.md +2 -1
- facelib/detection/__init__.py +34 -8
- facelib/parsing/__init__.py +2 -2
- inference_codeformer.py +9 -2
- requirements.txt +1 -1
- scripts/download_pretrained_models.py +12 -33
- scripts/download_pretrained_models_from_gdrive.py +60 -0
    	
        README.md
    CHANGED
    
    | @@ -6,7 +6,8 @@ | |
| 6 |  | 
| 7 | 
             
            [Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
         | 
| 8 |  | 
| 9 | 
            -
            <a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
         | 
|  | |
| 10 |  | 
| 11 | 
             
            [Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/) 
         | 
| 12 |  | 
|  | |
| 6 |  | 
| 7 | 
             
            [Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
         | 
| 8 |  | 
| 9 | 
            +
            <a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> 
         | 
| 10 | 
            +
             | 
| 11 |  | 
| 12 | 
             
            [Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/) 
         | 
| 13 |  | 
    	
        facelib/detection/__init__.py
    CHANGED
    
    | @@ -49,17 +49,14 @@ def init_retinaface_model(model_name, half=False, device='cuda'): | |
| 49 | 
             
            def init_yolov5face_model(model_name, device='cuda'):
         | 
| 50 | 
             
                if model_name == 'YOLOv5l':
         | 
| 51 | 
             
                    model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
         | 
| 52 | 
            -
                     | 
| 53 | 
             
                elif model_name == 'YOLOv5n':
         | 
| 54 | 
             
                    model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
         | 
| 55 | 
            -
                     | 
| 56 | 
             
                else:
         | 
| 57 | 
             
                    raise NotImplementedError(f'{model_name} is not implemented.')
         | 
| 58 | 
            -
             | 
| 59 | 
            -
                model_path =  | 
| 60 | 
            -
                if not os.path.exists(model_path):
         | 
| 61 | 
            -
                    download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib')
         | 
| 62 | 
            -
             | 
| 63 | 
             
                load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
         | 
| 64 | 
             
                model.detector.load_state_dict(load_net, strict=True)
         | 
| 65 | 
             
                model.detector.eval()
         | 
| @@ -71,4 +68,33 @@ def init_yolov5face_model(model_name, device='cuda'): | |
| 71 | 
             
                    elif isinstance(m, Conv):
         | 
| 72 | 
             
                        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
         | 
| 73 |  | 
| 74 | 
            -
                return model
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 49 | 
             
            def init_yolov5face_model(model_name, device='cuda'):
         | 
| 50 | 
             
                if model_name == 'YOLOv5l':
         | 
| 51 | 
             
                    model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
         | 
| 52 | 
            +
                    model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth'
         | 
| 53 | 
             
                elif model_name == 'YOLOv5n':
         | 
| 54 | 
             
                    model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
         | 
| 55 | 
            +
                    model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth'
         | 
| 56 | 
             
                else:
         | 
| 57 | 
             
                    raise NotImplementedError(f'{model_name} is not implemented.')
         | 
| 58 | 
            +
                
         | 
| 59 | 
            +
                model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
         | 
|  | |
|  | |
|  | |
| 60 | 
             
                load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
         | 
| 61 | 
             
                model.detector.load_state_dict(load_net, strict=True)
         | 
| 62 | 
             
                model.detector.eval()
         | 
|  | |
| 68 | 
             
                    elif isinstance(m, Conv):
         | 
| 69 | 
             
                        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
         | 
| 70 |  | 
| 71 | 
            +
                return model
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            # Download from Google Drive
         | 
| 75 | 
            +
            # def init_yolov5face_model(model_name, device='cuda'):
         | 
| 76 | 
            +
            #     if model_name == 'YOLOv5l':
         | 
| 77 | 
            +
            #         model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
         | 
| 78 | 
            +
            #         f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'}
         | 
| 79 | 
            +
            #     elif model_name == 'YOLOv5n':
         | 
| 80 | 
            +
            #         model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
         | 
| 81 | 
            +
            #         f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'}
         | 
| 82 | 
            +
            #     else:
         | 
| 83 | 
            +
            #         raise NotImplementedError(f'{model_name} is not implemented.')
         | 
| 84 | 
            +
             | 
| 85 | 
            +
            #     model_path = os.path.join('weights/facelib', list(f_id.keys())[0])
         | 
| 86 | 
            +
            #     if not os.path.exists(model_path):
         | 
| 87 | 
            +
            #         download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib')
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            #     load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
         | 
| 90 | 
            +
            #     model.detector.load_state_dict(load_net, strict=True)
         | 
| 91 | 
            +
            #     model.detector.eval()
         | 
| 92 | 
            +
            #     model.detector = model.detector.to(device).float()
         | 
| 93 | 
            +
             | 
| 94 | 
            +
            #     for m in model.detector.modules():
         | 
| 95 | 
            +
            #         if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
         | 
| 96 | 
            +
            #             m.inplace = True  # pytorch 1.7.0 compatibility
         | 
| 97 | 
            +
            #         elif isinstance(m, Conv):
         | 
| 98 | 
            +
            #             m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            #     return model
         | 
    	
        facelib/parsing/__init__.py
    CHANGED
    
    | @@ -8,10 +8,10 @@ from .parsenet import ParseNet | |
| 8 | 
             
            def init_parsing_model(model_name='bisenet', half=False, device='cuda'):
         | 
| 9 | 
             
                if model_name == 'bisenet':
         | 
| 10 | 
             
                    model = BiSeNet(num_class=19)
         | 
| 11 | 
            -
                    model_url = 'https://github.com/ | 
| 12 | 
             
                elif model_name == 'parsenet':
         | 
| 13 | 
             
                    model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
         | 
| 14 | 
            -
                    model_url = 'https://github.com/ | 
| 15 | 
             
                else:
         | 
| 16 | 
             
                    raise NotImplementedError(f'{model_name} is not implemented.')
         | 
| 17 |  | 
|  | |
| 8 | 
             
            def init_parsing_model(model_name='bisenet', half=False, device='cuda'):
         | 
| 9 | 
             
                if model_name == 'bisenet':
         | 
| 10 | 
             
                    model = BiSeNet(num_class=19)
         | 
| 11 | 
            +
                    model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth'
         | 
| 12 | 
             
                elif model_name == 'parsenet':
         | 
| 13 | 
             
                    model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
         | 
| 14 | 
            +
                    model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
         | 
| 15 | 
             
                else:
         | 
| 16 | 
             
                    raise NotImplementedError(f'{model_name} is not implemented.')
         | 
| 17 |  | 
    	
        inference_codeformer.py
    CHANGED
    
    | @@ -6,11 +6,16 @@ import glob | |
| 6 | 
             
            import torch
         | 
| 7 | 
             
            from torchvision.transforms.functional import normalize
         | 
| 8 | 
             
            from basicsr.utils import imwrite, img2tensor, tensor2img
         | 
|  | |
| 9 | 
             
            from facelib.utils.face_restoration_helper import FaceRestoreHelper
         | 
| 10 | 
             
            import torch.nn.functional as F
         | 
| 11 |  | 
| 12 | 
             
            from basicsr.utils.registry import ARCH_REGISTRY
         | 
| 13 |  | 
|  | |
|  | |
|  | |
|  | |
| 14 | 
             
            if __name__ == '__main__':
         | 
| 15 | 
             
                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         | 
| 16 | 
             
                parser = argparse.ArgumentParser()
         | 
| @@ -59,8 +64,10 @@ if __name__ == '__main__': | |
| 59 | 
             
                # ------------------ set up CodeFormer restorer -------------------
         | 
| 60 | 
             
                net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, 
         | 
| 61 | 
             
                                                        connect_list=['32', '64', '128', '256']).to(device)
         | 
| 62 | 
            -
             | 
| 63 | 
            -
                ckpt_path = 'weights/CodeFormer/codeformer.pth'
         | 
|  | |
|  | |
| 64 | 
             
                checkpoint = torch.load(ckpt_path)['params_ema']
         | 
| 65 | 
             
                net.load_state_dict(checkpoint)
         | 
| 66 | 
             
                net.eval()
         | 
|  | |
| 6 | 
             
            import torch
         | 
| 7 | 
             
            from torchvision.transforms.functional import normalize
         | 
| 8 | 
             
            from basicsr.utils import imwrite, img2tensor, tensor2img
         | 
| 9 | 
            +
            from basicsr.utils.download_util import load_file_from_url
         | 
| 10 | 
             
            from facelib.utils.face_restoration_helper import FaceRestoreHelper
         | 
| 11 | 
             
            import torch.nn.functional as F
         | 
| 12 |  | 
| 13 | 
             
            from basicsr.utils.registry import ARCH_REGISTRY
         | 
| 14 |  | 
| 15 | 
            +
            pretrain_model_url = {
         | 
| 16 | 
            +
                'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
         | 
| 17 | 
            +
            }
         | 
| 18 | 
            +
             | 
| 19 | 
             
            if __name__ == '__main__':
         | 
| 20 | 
             
                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         | 
| 21 | 
             
                parser = argparse.ArgumentParser()
         | 
|  | |
| 64 | 
             
                # ------------------ set up CodeFormer restorer -------------------
         | 
| 65 | 
             
                net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, 
         | 
| 66 | 
             
                                                        connect_list=['32', '64', '128', '256']).to(device)
         | 
| 67 | 
            +
                
         | 
| 68 | 
            +
                # ckpt_path = 'weights/CodeFormer/codeformer.pth'
         | 
| 69 | 
            +
                ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'], 
         | 
| 70 | 
            +
                                                model_dir='weights/CodeFormer', progress=True, file_name=None)
         | 
| 71 | 
             
                checkpoint = torch.load(ckpt_path)['params_ema']
         | 
| 72 | 
             
                net.load_state_dict(checkpoint)
         | 
| 73 | 
             
                net.eval()
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -14,7 +14,7 @@ torchvision | |
| 14 | 
             
            tqdm
         | 
| 15 | 
             
            yapf
         | 
| 16 | 
             
            lpips
         | 
| 17 | 
            -
            gdown # supports downloading the large file from Google Drive
         | 
| 18 | 
             
            # cmake
         | 
| 19 | 
             
            # dlib
         | 
| 20 | 
             
            # conda install -c conda-forge dlib
         | 
|  | |
| 14 | 
             
            tqdm
         | 
| 15 | 
             
            yapf
         | 
| 16 | 
             
            lpips
         | 
| 17 | 
            +
            # gdown # supports downloading the large file from Google Drive
         | 
| 18 | 
             
            # cmake
         | 
| 19 | 
             
            # dlib
         | 
| 20 | 
             
            # conda install -c conda-forge dlib
         | 
    	
        scripts/download_pretrained_models.py
    CHANGED
    
    | @@ -2,31 +2,16 @@ import argparse | |
| 2 | 
             
            import os
         | 
| 3 | 
             
            from os import path as osp
         | 
| 4 |  | 
| 5 | 
            -
             | 
| 6 | 
            -
            import gdown
         | 
| 7 |  | 
| 8 |  | 
| 9 | 
            -
            def download_pretrained_models(method,  | 
| 10 | 
             
                save_path_root = f'./weights/{method}'
         | 
| 11 | 
             
                os.makedirs(save_path_root, exist_ok=True)
         | 
| 12 |  | 
| 13 | 
            -
                for file_name,  | 
| 14 | 
            -
                    file_url =  | 
| 15 | 
            -
             | 
| 16 | 
            -
                    if osp.exists(save_path):
         | 
| 17 | 
            -
                        user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
         | 
| 18 | 
            -
                        if user_response.lower() == 'y':
         | 
| 19 | 
            -
                            print(f'Covering {file_name} to {save_path}')
         | 
| 20 | 
            -
                            gdown.download(file_url, save_path, quiet=False)
         | 
| 21 | 
            -
                            # download_file_from_google_drive(file_id, save_path)
         | 
| 22 | 
            -
                        elif user_response.lower() == 'n':
         | 
| 23 | 
            -
                            print(f'Skipping {file_name}')
         | 
| 24 | 
            -
                        else:
         | 
| 25 | 
            -
                            raise ValueError('Wrong input. Only accepts Y/N.')
         | 
| 26 | 
            -
                    else:
         | 
| 27 | 
            -
                        print(f'Downloading {file_name} to {save_path}')
         | 
| 28 | 
            -
                        gdown.download(file_url, save_path, quiet=False)
         | 
| 29 | 
            -
                        # download_file_from_google_drive(file_id, save_path)
         | 
| 30 |  | 
| 31 | 
             
            if __name__ == '__main__':
         | 
| 32 | 
             
                parser = argparse.ArgumentParser()
         | 
| @@ -37,24 +22,18 @@ if __name__ == '__main__': | |
| 37 | 
             
                    help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
         | 
| 38 | 
             
                args = parser.parse_args()
         | 
| 39 |  | 
| 40 | 
            -
                 | 
| 41 | 
            -
                # 'dlib': {
         | 
| 42 | 
            -
                #     'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
         | 
| 43 | 
            -
                #     'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
         | 
| 44 | 
            -
                #     'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
         | 
| 45 | 
            -
                # }
         | 
| 46 | 
            -
                file_ids = {
         | 
| 47 | 
             
                    'CodeFormer': {
         | 
| 48 | 
            -
                        'codeformer.pth': ' | 
| 49 | 
             
                    },
         | 
| 50 | 
             
                    'facelib': {
         | 
| 51 | 
            -
                        'yolov5l-face.pth': ' | 
| 52 | 
            -
                        'parsing_parsenet.pth': ' | 
| 53 | 
             
                    }
         | 
| 54 | 
             
                }
         | 
| 55 |  | 
| 56 | 
             
                if args.method == 'all':
         | 
| 57 | 
            -
                    for method in  | 
| 58 | 
            -
                        download_pretrained_models(method,  | 
| 59 | 
             
                else:
         | 
| 60 | 
            -
                    download_pretrained_models(args.method,  | 
|  | |
| 2 | 
             
            import os
         | 
| 3 | 
             
            from os import path as osp
         | 
| 4 |  | 
| 5 | 
            +
            from basicsr.utils.download_util import load_file_from_url
         | 
|  | |
| 6 |  | 
| 7 |  | 
| 8 | 
            +
            def download_pretrained_models(method, file_urls):
         | 
| 9 | 
             
                save_path_root = f'./weights/{method}'
         | 
| 10 | 
             
                os.makedirs(save_path_root, exist_ok=True)
         | 
| 11 |  | 
| 12 | 
            +
                for file_name, file_url in file_urls.items():
         | 
| 13 | 
            +
                    save_path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name)
         | 
| 14 | 
            +
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 15 |  | 
| 16 | 
             
            if __name__ == '__main__':
         | 
| 17 | 
             
                parser = argparse.ArgumentParser()
         | 
|  | |
| 22 | 
             
                    help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
         | 
| 23 | 
             
                args = parser.parse_args()
         | 
| 24 |  | 
| 25 | 
            +
                file_urls = {
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 26 | 
             
                    'CodeFormer': {
         | 
| 27 | 
            +
                        'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
         | 
| 28 | 
             
                    },
         | 
| 29 | 
             
                    'facelib': {
         | 
| 30 | 
            +
                        'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth',
         | 
| 31 | 
            +
                        'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
         | 
| 32 | 
             
                    }
         | 
| 33 | 
             
                }
         | 
| 34 |  | 
| 35 | 
             
                if args.method == 'all':
         | 
| 36 | 
            +
                    for method in file_urls.keys():
         | 
| 37 | 
            +
                        download_pretrained_models(method, file_urls[method])
         | 
| 38 | 
             
                else:
         | 
| 39 | 
            +
                    download_pretrained_models(args.method, file_urls[args.method])
         | 
    	
        scripts/download_pretrained_models_from_gdrive.py
    ADDED
    
    | @@ -0,0 +1,60 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            from os import path as osp
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            # from basicsr.utils.download_util import download_file_from_google_drive
         | 
| 6 | 
            +
            import gdown
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            def download_pretrained_models(method, file_ids):
         | 
| 10 | 
            +
                save_path_root = f'./weights/{method}'
         | 
| 11 | 
            +
                os.makedirs(save_path_root, exist_ok=True)
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                for file_name, file_id in file_ids.items():
         | 
| 14 | 
            +
                    file_url = 'https://drive.google.com/uc?id='+file_id
         | 
| 15 | 
            +
                    save_path = osp.abspath(osp.join(save_path_root, file_name))
         | 
| 16 | 
            +
                    if osp.exists(save_path):
         | 
| 17 | 
            +
                        user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
         | 
| 18 | 
            +
                        if user_response.lower() == 'y':
         | 
| 19 | 
            +
                            print(f'Covering {file_name} to {save_path}')
         | 
| 20 | 
            +
                            gdown.download(file_url, save_path, quiet=False)
         | 
| 21 | 
            +
                            # download_file_from_google_drive(file_id, save_path)
         | 
| 22 | 
            +
                        elif user_response.lower() == 'n':
         | 
| 23 | 
            +
                            print(f'Skipping {file_name}')
         | 
| 24 | 
            +
                        else:
         | 
| 25 | 
            +
                            raise ValueError('Wrong input. Only accepts Y/N.')
         | 
| 26 | 
            +
                    else:
         | 
| 27 | 
            +
                        print(f'Downloading {file_name} to {save_path}')
         | 
| 28 | 
            +
                        gdown.download(file_url, save_path, quiet=False)
         | 
| 29 | 
            +
                        # download_file_from_google_drive(file_id, save_path)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            if __name__ == '__main__':
         | 
| 32 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                parser.add_argument(
         | 
| 35 | 
            +
                    'method',
         | 
| 36 | 
            +
                    type=str,
         | 
| 37 | 
            +
                    help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
         | 
| 38 | 
            +
                args = parser.parse_args()
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                # file name: file id
         | 
| 41 | 
            +
                # 'dlib': {
         | 
| 42 | 
            +
                #     'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
         | 
| 43 | 
            +
                #     'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
         | 
| 44 | 
            +
                #     'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
         | 
| 45 | 
            +
                # }
         | 
| 46 | 
            +
                file_ids = {
         | 
| 47 | 
            +
                    'CodeFormer': {
         | 
| 48 | 
            +
                        'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB'
         | 
| 49 | 
            +
                    },
         | 
| 50 | 
            +
                    'facelib': {
         | 
| 51 | 
            +
                        'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV',
         | 
| 52 | 
            +
                        'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK'
         | 
| 53 | 
            +
                    }
         | 
| 54 | 
            +
                }
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                if args.method == 'all':
         | 
| 57 | 
            +
                    for method in file_ids.keys():
         | 
| 58 | 
            +
                        download_pretrained_models(method, file_ids[method])
         | 
| 59 | 
            +
                else:
         | 
| 60 | 
            +
                    download_pretrained_models(args.method, file_ids[args.method])
         | 
