kevinwang676 committed on
Commit 3e2da18 · 1 Parent(s): 57ce766

Delete SadTalker

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. SadTalker/.gitignore +0 -174
  2. SadTalker/.ipynb_checkpoints/requirements-checkpoint.txt +0 -37
  3. SadTalker/checkpoints/SadTalker_V0.0.2_256.safetensors +0 -3
  4. SadTalker/checkpoints/SadTalker_V0.0.2_512.safetensors +0 -3
  5. SadTalker/checkpoints/mapping_00109-model.pth.tar +0 -3
  6. SadTalker/checkpoints/mapping_00229-model.pth.tar +0 -3
  7. SadTalker/cog.yaml +0 -35
  8. SadTalker/gfpgan/weights/GFPGANv1.4.pth +0 -3
  9. SadTalker/gfpgan/weights/alignment_WFLW_4HG.pth +0 -3
  10. SadTalker/gfpgan/weights/detection_Resnet50_Final.pth +0 -3
  11. SadTalker/gfpgan/weights/parsing_parsenet.pth +0 -3
  12. SadTalker/inference.py +0 -145
  13. SadTalker/launcher.py +0 -204
  14. SadTalker/predict.py +0 -192
  15. SadTalker/req.txt +0 -22
  16. SadTalker/requirements.txt +0 -37
  17. SadTalker/scripts/download_models.sh +0 -32
  18. SadTalker/scripts/extension.py +0 -189
  19. SadTalker/scripts/test.sh +0 -21
  20. SadTalker/src/audio2exp_models/audio2exp.py +0 -41
  21. SadTalker/src/audio2exp_models/networks.py +0 -74
  22. SadTalker/src/audio2pose_models/audio2pose.py +0 -94
  23. SadTalker/src/audio2pose_models/audio_encoder.py +0 -64
  24. SadTalker/src/audio2pose_models/cvae.py +0 -149
  25. SadTalker/src/audio2pose_models/discriminator.py +0 -76
  26. SadTalker/src/audio2pose_models/networks.py +0 -140
  27. SadTalker/src/audio2pose_models/res_unet.py +0 -65
  28. SadTalker/src/config/auido2exp.yaml +0 -58
  29. SadTalker/src/config/auido2pose.yaml +0 -49
  30. SadTalker/src/config/facerender.yaml +0 -45
  31. SadTalker/src/config/facerender_still.yaml +0 -45
  32. SadTalker/src/config/similarity_Lm3D_all.mat +0 -0
  33. SadTalker/src/face3d/data/__init__.py +0 -116
  34. SadTalker/src/face3d/data/base_dataset.py +0 -125
  35. SadTalker/src/face3d/data/flist_dataset.py +0 -125
  36. SadTalker/src/face3d/data/image_folder.py +0 -66
  37. SadTalker/src/face3d/data/template_dataset.py +0 -75
  38. SadTalker/src/face3d/extract_kp_videos.py +0 -108
  39. SadTalker/src/face3d/extract_kp_videos_safe.py +0 -151
  40. SadTalker/src/face3d/models/__init__.py +0 -67
  41. SadTalker/src/face3d/models/arcface_torch/README.md +0 -164
  42. SadTalker/src/face3d/models/arcface_torch/backbones/__init__.py +0 -25
  43. SadTalker/src/face3d/models/arcface_torch/backbones/iresnet.py +0 -187
  44. SadTalker/src/face3d/models/arcface_torch/backbones/iresnet2060.py +0 -176
  45. SadTalker/src/face3d/models/arcface_torch/backbones/mobilefacenet.py +0 -130
  46. SadTalker/src/face3d/models/arcface_torch/configs/3millions.py +0 -23
  47. SadTalker/src/face3d/models/arcface_torch/configs/3millions_pfc.py +0 -23
  48. SadTalker/src/face3d/models/arcface_torch/configs/__init__.py +0 -0
  49. SadTalker/src/face3d/models/arcface_torch/configs/base.py +0 -56
  50. SadTalker/src/face3d/models/arcface_torch/configs/glint360k_mbf.py +0 -26
SadTalker/.gitignore DELETED
@@ -1,174 +0,0 @@
1
- # Byte-compiled / optimized / DLL files
2
- __pycache__/
3
- *.py[cod]
4
- *$py.class
5
-
6
- # C extensions
7
- *.so
8
-
9
- # Distribution / packaging
10
- .Python
11
- build/
12
- develop-eggs/
13
- dist/
14
- downloads/
15
- eggs/
16
- .eggs/
17
- lib/
18
- lib64/
19
- parts/
20
- sdist/
21
- var/
22
- wheels/
23
- share/python-wheels/
24
- *.egg-info/
25
- .installed.cfg
26
- *.egg
27
- MANIFEST
28
-
29
- # PyInstaller
30
- # Usually these files are written by a python script from a template
31
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
- *.manifest
33
- *.spec
34
-
35
- # Installer logs
36
- pip-log.txt
37
- pip-delete-this-directory.txt
38
-
39
- # Unit test / coverage reports
40
- htmlcov/
41
- .tox/
42
- .nox/
43
- .coverage
44
- .coverage.*
45
- .cache
46
- nosetests.xml
47
- coverage.xml
48
- *.cover
49
- *.py,cover
50
- .hypothesis/
51
- .pytest_cache/
52
- cover/
53
-
54
- # Translations
55
- *.mo
56
- *.pot
57
-
58
- # Django stuff:
59
- *.log
60
- local_settings.py
61
- db.sqlite3
62
- db.sqlite3-journal
63
-
64
- # Flask stuff:
65
- instance/
66
- .webassets-cache
67
-
68
- # Scrapy stuff:
69
- .scrapy
70
-
71
- # Sphinx documentation
72
- docs/_build/
73
-
74
- # PyBuilder
75
- .pybuilder/
76
- target/
77
-
78
- # Jupyter Notebook
79
- .ipynb_checkpoints
80
-
81
- # IPython
82
- profile_default/
83
- ipython_config.py
84
-
85
- # pyenv
86
- # For a library or package, you might want to ignore these files since the code is
87
- # intended to run in multiple environments; otherwise, check them in:
88
- # .python-version
89
-
90
- # pipenv
91
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
- # install all needed dependencies.
95
- #Pipfile.lock
96
-
97
- # poetry
98
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
- # This is especially recommended for binary packages to ensure reproducibility, and is more
100
- # commonly ignored for libraries.
101
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
- #poetry.lock
103
-
104
- # pdm
105
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
- #pdm.lock
107
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
- # in version control.
109
- # https://pdm.fming.dev/#use-with-ide
110
- .pdm.toml
111
-
112
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
- __pypackages__/
114
-
115
- # Celery stuff
116
- celerybeat-schedule
117
- celerybeat.pid
118
-
119
- # SageMath parsed files
120
- *.sage.py
121
-
122
- # Environments
123
- .env
124
- .venv
125
- env/
126
- venv/
127
- ENV/
128
- env.bak/
129
- venv.bak/
130
-
131
- # Spyder project settings
132
- .spyderproject
133
- .spyproject
134
-
135
- # Rope project settings
136
- .ropeproject
137
-
138
- # mkdocs documentation
139
- /site
140
-
141
- # mypy
142
- .mypy_cache/
143
- .dmypy.json
144
- dmypy.json
145
-
146
- # Pyre type checker
147
- .pyre/
148
-
149
- # pytype static type analyzer
150
- .pytype/
151
-
152
- # Cython debug symbols
153
- cython_debug/
154
-
155
- # PyCharm
156
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
- # and can be added to the global gitignore or merged into this file. For a more nuclear
159
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
- .idea/
161
-
162
- examples/results/*
163
- gfpgan/*
164
- checkpoints/*
165
- assets/*
166
- results/*
167
- Dockerfile
168
- start_docker.sh
169
- start.sh
170
-
171
- checkpoints
172
-
173
- # Mac
174
- .DS_Store
 
 
SadTalker/.ipynb_checkpoints/requirements-checkpoint.txt DELETED
@@ -1,37 +0,0 @@
- numpy==1.23.5
- matplotlib
- moviepy
- yt-dlp
- pydub
- demucs
- gradio
- torch
- flask
- flask-cors
- torchaudio
- fairseq==0.12.2
- scipy==1.10.1
- pyworld>=0.3.2
- faiss-cpu==1.7.3
- praat-parselmouth>=0.4.2
- librosa==0.9.2
- edge-tts
- torchcrepe
- Pillow==9.5.0
-
- face_alignment==1.3.5
- imageio==2.19.3
- imageio-ffmpeg==0.4.7
- numba
- resampy==0.3.1
- kornia==0.6.8
- tqdm
- yacs==0.1.8
- pyyaml
- joblib==1.1.0
- scikit-image==0.19.3
- basicsr==1.4.2
- facexlib==0.3.0
- gfpgan
- av
- safetensors
 
 
SadTalker/checkpoints/SadTalker_V0.0.2_256.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:c211f5d6de003516bf1bbda9f47049a4c9c99133b1ab565c6961e5af16477bff
- size 725066984
 
SadTalker/checkpoints/SadTalker_V0.0.2_512.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:0e063f7ff5258240bdb0f7690783a7b1374e6a4a81ce8fa33456f4cd49694340
- size 725066984
 
SadTalker/checkpoints/mapping_00109-model.pth.tar DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:84a8642468a3fcfdd9ab6be955267043116c2bec2284686a5262f1eaf017f64c
- size 155779231
 
SadTalker/checkpoints/mapping_00229-model.pth.tar DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:62a1e06006cc963220f6477438518ed86e9788226c62ae382ddc42fbcefb83f1
- size 155521183
 
SadTalker/cog.yaml DELETED
@@ -1,35 +0,0 @@
- build:
-   gpu: true
-   cuda: "11.3"
-   python_version: "3.8"
-   system_packages:
-     - "ffmpeg"
-     - "libgl1-mesa-glx"
-     - "libglib2.0-0"
-   python_packages:
-     - "torch==1.12.1"
-     - "torchvision==0.13.1"
-     - "torchaudio==0.12.1"
-     - "joblib==1.1.0"
-     - "scikit-image==0.19.3"
-     - "basicsr==1.4.2"
-     - "facexlib==0.3.0"
-     - "resampy==0.3.1"
-     - "pydub==0.25.1"
-     - "scipy==1.10.1"
-     - "kornia==0.6.8"
-     - "face_alignment==1.3.5"
-     - "imageio==2.19.3"
-     - "imageio-ffmpeg==0.4.7"
-     - "librosa==0.9.2"
-     - "tqdm==4.65.0"
-     - "yacs==0.1.8"
-     - "gfpgan==1.3.8"
-     - "dlib-bin==19.24.1"
-     - "av==10.0.0"
-     - "trimesh==3.9.20"
-   run:
-     - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth" "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth"
-     - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip" "https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip"
-
- predict: "predict.py:Predictor"
 
SadTalker/gfpgan/weights/GFPGANv1.4.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:e2cd4703ab14f4d01fd1383a8a8b266f9a5833dacee8e6a79d3bf21a1b6be5ad
- size 348632874
 
SadTalker/gfpgan/weights/alignment_WFLW_4HG.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:bbfd137307a4c7debd5c283b9b0ce539466cee417ac0a155e184d857f9f2899c
- size 193670248
 
SadTalker/gfpgan/weights/detection_Resnet50_Final.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
- size 109497761
 
SadTalker/gfpgan/weights/parsing_parsenet.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
- size 85331193
 
SadTalker/inference.py DELETED
@@ -1,145 +0,0 @@
1
- from glob import glob
2
- import shutil
3
- import torch
4
- from time import strftime
5
- import os, sys, time
6
- from argparse import ArgumentParser
7
-
8
- from src.utils.preprocess import CropAndExtract
9
- from src.test_audio2coeff import Audio2Coeff
10
- from src.facerender.animate import AnimateFromCoeff
11
- from src.generate_batch import get_data
12
- from src.generate_facerender_batch import get_facerender_data
13
- from src.utils.init_path import init_path
14
-
15
- def main(args):
16
- #torch.backends.cudnn.enabled = False
17
-
18
- pic_path = args.source_image
19
- audio_path = args.driven_audio
20
- save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
21
- os.makedirs(save_dir, exist_ok=True)
22
- pose_style = args.pose_style
23
- device = args.device
24
- batch_size = args.batch_size
25
- input_yaw_list = args.input_yaw
26
- input_pitch_list = args.input_pitch
27
- input_roll_list = args.input_roll
28
- ref_eyeblink = args.ref_eyeblink
29
- ref_pose = args.ref_pose
30
-
31
- current_root_path = os.path.split(sys.argv[0])[0]
32
-
33
- sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)
34
-
35
- #init model
36
- preprocess_model = CropAndExtract(sadtalker_paths, device)
37
-
38
- audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
39
-
40
- animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)
41
-
42
- #crop image and extract 3dmm from image
43
- first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
44
- os.makedirs(first_frame_dir, exist_ok=True)
45
- print('3DMM Extraction for source image')
46
- first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\
47
- source_image_flag=True, pic_size=args.size)
48
- if first_coeff_path is None:
49
- print("Can't get the coeffs of the input")
50
- return
51
-
52
- if ref_eyeblink is not None:
53
- ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]
54
- ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
55
- os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
56
- print('3DMM Extraction for the reference video providing eye blinking')
57
- ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False)
58
- else:
59
- ref_eyeblink_coeff_path=None
60
-
61
- if ref_pose is not None:
62
- if ref_pose == ref_eyeblink:
63
- ref_pose_coeff_path = ref_eyeblink_coeff_path
64
- else:
65
- ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
66
- ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
67
- os.makedirs(ref_pose_frame_dir, exist_ok=True)
68
- print('3DMM Extraction for the reference video providing pose')
69
- ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False)
70
- else:
71
- ref_pose_coeff_path=None
72
-
73
- #audio2ceoff
74
- batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
75
- coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
76
-
77
- # 3dface render
78
- if args.face3dvis:
79
- from src.face3d.visualize import gen_composed_video
80
- gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))
81
-
82
- #coeff2video
83
- data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
84
- batch_size, input_yaw_list, input_pitch_list, input_roll_list,
85
- expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size)
86
-
87
- result = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
88
- enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)
89
-
90
- shutil.move(result, save_dir+'.mp4')
91
- print('The generated video is named:', save_dir+'.mp4')
92
-
93
- if not args.verbose:
94
- shutil.rmtree(save_dir)
95
-
96
-
97
- if __name__ == '__main__':
98
-
99
- parser = ArgumentParser()
100
- parser.add_argument("--driven_audio", default='./examples/driven_audio/bus_chinese.wav', help="path to driven audio")
101
- parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image")
102
- parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
103
- parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
104
- parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to output")
105
- parser.add_argument("--result_dir", default='./results', help="path to output")
106
- parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
107
- parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
108
- parser.add_argument("--size", type=int, default=256, help="the image size of the facerender")
109
- parser.add_argument("--expression_scale", type=float, default=1., help="the batch size of facerender")
110
- parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ")
111
- parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
112
- parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
113
- parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]")
114
- parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]")
115
- parser.add_argument("--cpu", dest="cpu", action="store_true")
116
- parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
117
- parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body aniamtion")
118
- parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" )
119
- parser.add_argument("--verbose",action="store_true", help="saving the intermedia output or not" )
120
- parser.add_argument("--old_version",action="store_true", help="use the pth other than safetensor version" )
121
-
122
-
123
- # net structure and parameters
124
- parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless')
125
- parser.add_argument('--init_path', type=str, default=None, help='Useless')
126
- parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc')
127
- parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
128
- parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
129
-
130
- # default renderer parameters
131
- parser.add_argument('--focal', type=float, default=1015.)
132
- parser.add_argument('--center', type=float, default=112.)
133
- parser.add_argument('--camera_d', type=float, default=10.)
134
- parser.add_argument('--z_near', type=float, default=5.)
135
- parser.add_argument('--z_far', type=float, default=15.)
136
-
137
- args = parser.parse_args()
138
-
139
- if torch.cuda.is_available() and not args.cpu:
140
- args.device = "cuda"
141
- else:
142
- args.device = "cpu"
143
-
144
- main(args)
145
-
 
 
 
 
SadTalker/launcher.py DELETED
@@ -1,204 +0,0 @@
1
- # this scripts installs necessary requirements and launches main program in webui.py
2
- # borrow from : https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/launch.py
3
- import subprocess
4
- import os
5
- import sys
6
- import importlib.util
7
- import shlex
8
- import platform
9
- import json
10
-
11
- python = sys.executable
12
- git = os.environ.get('GIT', "git")
13
- index_url = os.environ.get('INDEX_URL', "")
14
- stored_commit_hash = None
15
- skip_install = False
16
- dir_repos = "repositories"
17
- script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
18
-
19
- if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
20
- os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
21
-
22
-
23
- def check_python_version():
24
- is_windows = platform.system() == "Windows"
25
- major = sys.version_info.major
26
- minor = sys.version_info.minor
27
- micro = sys.version_info.micro
28
-
29
- if is_windows:
30
- supported_minors = [10]
31
- else:
32
- supported_minors = [7, 8, 9, 10, 11]
33
-
34
- if not (major == 3 and minor in supported_minors):
35
-
36
- raise (f"""
37
- INCOMPATIBLE PYTHON VERSION
38
- This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
39
- If you encounter an error with "RuntimeError: Couldn't install torch." message,
40
- or any other error regarding unsuccessful package (library) installation,
41
- please downgrade (or upgrade) to the latest version of 3.10 Python
42
- and delete current Python and "venv" folder in WebUI's directory.
43
- You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
44
- {"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
45
- Use --skip-python-version-check to suppress this warning.
46
- """)
47
-
48
-
49
- def commit_hash():
50
- global stored_commit_hash
51
-
52
- if stored_commit_hash is not None:
53
- return stored_commit_hash
54
-
55
- try:
56
- stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
57
- except Exception:
58
- stored_commit_hash = "<none>"
59
-
60
- return stored_commit_hash
61
-
62
-
63
- def run(command, desc=None, errdesc=None, custom_env=None, live=False):
64
- if desc is not None:
65
- print(desc)
66
-
67
- if live:
68
- result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
69
- if result.returncode != 0:
70
- raise RuntimeError(f"""{errdesc or 'Error running command'}.
71
- Command: {command}
72
- Error code: {result.returncode}""")
73
-
74
- return ""
75
-
76
- result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
77
-
78
- if result.returncode != 0:
79
-
80
- message = f"""{errdesc or 'Error running command'}.
81
- Command: {command}
82
- Error code: {result.returncode}
83
- stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
84
- stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
85
- """
86
- raise RuntimeError(message)
87
-
88
- return result.stdout.decode(encoding="utf8", errors="ignore")
89
-
90
-
91
- def check_run(command):
92
- result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
93
- return result.returncode == 0
94
-
95
-
96
- def is_installed(package):
97
- try:
98
- spec = importlib.util.find_spec(package)
99
- except ModuleNotFoundError:
100
- return False
101
-
102
- return spec is not None
103
-
104
-
105
- def repo_dir(name):
106
- return os.path.join(script_path, dir_repos, name)
107
-
108
-
109
- def run_python(code, desc=None, errdesc=None):
110
- return run(f'"{python}" -c "{code}"', desc, errdesc)
111
-
112
-
113
- def run_pip(args, desc=None):
114
- if skip_install:
115
- return
116
-
117
- index_url_line = f' --index-url {index_url}' if index_url != '' else ''
118
- return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
119
-
120
-
121
- def check_run_python(code):
122
- return check_run(f'"{python}" -c "{code}"')
123
-
124
-
125
- def git_clone(url, dir, name, commithash=None):
126
- # TODO clone into temporary dir and move if successful
127
-
128
- if os.path.exists(dir):
129
- if commithash is None:
130
- return
131
-
132
- current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
133
- if current_hash == commithash:
134
- return
135
-
136
- run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
137
- run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
138
- return
139
-
140
- run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
141
-
142
- if commithash is not None:
143
- run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
144
-
145
-
146
- def git_pull_recursive(dir):
147
- for subdir, _, _ in os.walk(dir):
148
- if os.path.exists(os.path.join(subdir, '.git')):
149
- try:
150
- output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
151
- print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
152
- except subprocess.CalledProcessError as e:
153
- print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
154
-
155
-
156
- def run_extension_installer(extension_dir):
157
- path_installer = os.path.join(extension_dir, "install.py")
158
- if not os.path.isfile(path_installer):
159
- return
160
-
161
- try:
162
- env = os.environ.copy()
163
- env['PYTHONPATH'] = os.path.abspath(".")
164
-
165
- print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
166
- except Exception as e:
167
- print(e, file=sys.stderr)
168
-
169
-
170
- def prepare_environment():
171
- global skip_install
172
-
173
- torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113")
174
-
175
- ## check windows
176
- if sys.platform != 'win32':
177
- requirements_file = os.environ.get('REQS_FILE', "req.txt")
178
- else:
179
- requirements_file = os.environ.get('REQS_FILE', "requirements.txt")
180
-
181
- commit = commit_hash()
182
-
183
- print(f"Python {sys.version}")
184
- print(f"Commit hash: {commit}")
185
-
186
- if not is_installed("torch") or not is_installed("torchvision"):
187
- run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
188
-
189
- run_pip(f"install -r \"{requirements_file}\"", "requirements for SadTalker WebUI (may take longer time in first time)")
190
-
191
- if sys.platform != 'win32' and not is_installed('tts'):
192
- run_pip(f"install TTS", "install TTS individually in SadTalker, which might not work on windows.")
193
-
194
-
195
- def start():
196
- print(f"Launching SadTalker Web UI")
197
- from app_sadtalker import sadtalker_demo
198
- demo = sadtalker_demo()
199
- demo.queue()
200
- demo.launch()
201
-
202
- if __name__ == "__main__":
203
- prepare_environment()
204
- start()
 
 
 
 
SadTalker/predict.py DELETED
@@ -1,192 +0,0 @@
1
- """run bash scripts/download_models.sh first to prepare the weights file"""
2
- import os
3
- import shutil
4
- from argparse import Namespace
5
- from src.utils.preprocess import CropAndExtract
6
- from src.test_audio2coeff import Audio2Coeff
7
- from src.facerender.animate import AnimateFromCoeff
8
- from src.generate_batch import get_data
9
- from src.generate_facerender_batch import get_facerender_data
10
- from src.utils.init_path import init_path
11
- from cog import BasePredictor, Input, Path
12
-
13
- checkpoints = "checkpoints"
14
-
15
-
16
- class Predictor(BasePredictor):
17
- def setup(self):
18
- """Load the model into memory to make running multiple predictions efficient"""
19
- device = "cuda"
20
-
21
-
22
- sadtalker_paths = init_path(checkpoints,os.path.join("src","config"))
23
-
24
- # init model
25
- self.preprocess_model = CropAndExtract(sadtalker_paths, device
26
- )
27
-
28
- self.audio_to_coeff = Audio2Coeff(
29
- sadtalker_paths,
30
- device,
31
- )
32
-
33
- self.animate_from_coeff = {
34
- "full": AnimateFromCoeff(
35
- sadtalker_paths,
36
- device,
37
- ),
38
- "others": AnimateFromCoeff(
39
- sadtalker_paths,
40
- device,
41
- ),
42
- }
43
-
44
- def predict(
45
- self,
46
- source_image: Path = Input(
47
- description="Upload the source image, it can be video.mp4 or picture.png",
48
- ),
49
- driven_audio: Path = Input(
50
- description="Upload the driven audio, accepts .wav and .mp4 file",
51
- ),
52
- enhancer: str = Input(
53
- description="Choose a face enhancer",
54
- choices=["gfpgan", "RestoreFormer"],
55
- default="gfpgan",
56
- ),
57
- preprocess: str = Input(
58
- description="how to preprocess the images",
59
- choices=["crop", "resize", "full"],
60
- default="full",
61
- ),
62
- ref_eyeblink: Path = Input(
63
- description="path to reference video providing eye blinking",
64
- default=None,
65
- ),
66
- ref_pose: Path = Input(
67
- description="path to reference video providing pose",
68
- default=None,
69
- ),
70
- still: bool = Input(
71
- description="can crop back to the original videos for the full body aniamtion when preprocess is full",
72
- default=True,
73
- ),
74
- ) -> Path:
75
- """Run a single prediction on the model"""
76
-
77
- animate_from_coeff = (
78
- self.animate_from_coeff["full"]
79
- if preprocess == "full"
80
- else self.animate_from_coeff["others"]
81
- )
82
-
83
- args = load_default()
84
- args.pic_path = str(source_image)
85
- args.audio_path = str(driven_audio)
86
- device = "cuda"
87
- args.still = still
88
- args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
89
- args.ref_pose = None if ref_pose is None else str(ref_pose)
90
-
91
- # crop image and extract 3dmm from image
92
- results_dir = "results"
93
- if os.path.exists(results_dir):
94
- shutil.rmtree(results_dir)
95
- os.makedirs(results_dir)
96
- first_frame_dir = os.path.join(results_dir, "first_frame_dir")
97
- os.makedirs(first_frame_dir)
98
-
99
- print("3DMM Extraction for source image")
100
- first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
101
- args.pic_path, first_frame_dir, preprocess, source_image_flag=True
102
- )
103
- if first_coeff_path is None:
104
- print("Can't get the coeffs of the input")
105
- return
106
-
107
- if ref_eyeblink is not None:
108
- ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[
109
- 0
110
- ]
111
- ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)
112
- os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
113
- print("3DMM Extraction for the reference video providing eye blinking")
114
- ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(
115
- ref_eyeblink, ref_eyeblink_frame_dir
116
- )
117
- else:
118
- ref_eyeblink_coeff_path = None
119
-
120
- if ref_pose is not None:
121
- if ref_pose == ref_eyeblink:
122
- ref_pose_coeff_path = ref_eyeblink_coeff_path
123
- else:
124
- ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
125
- ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)
126
- os.makedirs(ref_pose_frame_dir, exist_ok=True)
127
- print("3DMM Extraction for the reference video providing pose")
128
- ref_pose_coeff_path, _, _ = self.preprocess_model.generate(
129
- ref_pose, ref_pose_frame_dir
130
- )
131
- else:
132
- ref_pose_coeff_path = None
133
-
134
- # audio2ceoff
135
- batch = get_data(
136
- first_coeff_path,
137
- args.audio_path,
138
- device,
139
- ref_eyeblink_coeff_path,
140
- still=still,
141
- )
142
- coeff_path = self.audio_to_coeff.generate(
143
- batch, results_dir, args.pose_style, ref_pose_coeff_path
144
- )
145
- # coeff2video
146
- print("coeff2video")
147
- data = get_facerender_data(
148
- coeff_path,
149
- crop_pic_path,
150
- first_coeff_path,
151
- args.audio_path,
152
- args.batch_size,
153
- args.input_yaw,
154
- args.input_pitch,
155
- args.input_roll,
156
- expression_scale=args.expression_scale,
157
- still_mode=still,
158
- preprocess=preprocess,
159
- )
160
- animate_from_coeff.generate(
161
- data, results_dir, args.pic_path, crop_info,
162
- enhancer=enhancer, background_enhancer=args.background_enhancer,
163
- preprocess=preprocess)
164
-
165
- output = "/tmp/out.mp4"
166
- mp4_path = os.path.join(results_dir, [f for f in os.listdir(results_dir) if "enhanced.mp4" in f][0])
167
- shutil.copy(mp4_path, output)
168
-
169
- return Path(output)
170
-
171
-
172
- def load_default():
173
- return Namespace(
174
- pose_style=0,
175
- batch_size=2,
176
- expression_scale=1.0,
177
- input_yaw=None,
178
- input_pitch=None,
179
- input_roll=None,
180
- background_enhancer=None,
181
- face3dvis=False,
182
- net_recon="resnet50",
183
- init_path=None,
184
- use_last_fc=False,
185
- bfm_folder="./src/config/",
186
- bfm_model="BFM_model_front.mat",
187
- focal=1015.0,
188
- center=112.0,
189
- camera_d=10.0,
190
- z_near=5.0,
191
- z_far=15.0,
192
- )
 
 
 
 
SadTalker/req.txt DELETED
@@ -1,22 +0,0 @@
- llvmlite==0.38.1
- numpy==1.21.6
- face_alignment==1.3.5
- imageio==2.19.3
- imageio-ffmpeg==0.4.7
- librosa==0.10.0.post2
- numba==0.55.1
- resampy==0.3.1
- pydub==0.25.1
- scipy==1.10.1
- kornia==0.6.8
- tqdm
- yacs==0.1.8
- pyyaml
- joblib==1.1.0
- scikit-image==0.19.3
- basicsr==1.4.2
- facexlib==0.3.0
- gradio
- gfpgan
- av
- safetensors
 
 
SadTalker/requirements.txt DELETED
@@ -1,37 +0,0 @@
- numpy==1.23.5
- matplotlib
- moviepy
- yt-dlp
- pydub
- demucs
- gradio
- torch
- flask
- flask-cors
- torchaudio
- fairseq==0.12.2
- scipy==1.10.1
- pyworld>=0.3.2
- faiss-cpu==1.7.3
- praat-parselmouth>=0.4.2
- librosa==0.9.2
- edge-tts
- torchcrepe
- Pillow==9.5.0
-
- face_alignment==1.3.5
- imageio==2.19.3
- imageio-ffmpeg==0.4.7
- numba
- resampy==0.3.1
- kornia==0.6.8
- tqdm
- yacs==0.1.8
- pyyaml
- joblib==1.1.0
- scikit-image==0.19.3
- basicsr==1.4.2
- facexlib==0.3.0
- gfpgan
- av
- safetensors
 
 
 
SadTalker/scripts/download_models.sh DELETED
@@ -1,32 +0,0 @@
- mkdir ./checkpoints
-
- # lagency download link
- # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth
- # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth
- # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth
- # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar
- # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat
- # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth
- # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
- # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
- # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip
- # unzip -n ./checkpoints/hub.zip -d ./checkpoints/
-
-
- #### download the new links.
- wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
- wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
- wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors -O ./checkpoints/SadTalker_V0.0.2_256.safetensors
- wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_512.safetensors -O ./checkpoints/SadTalker_V0.0.2_512.safetensors
-
-
- # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip
- # unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/
-
- ### enhancer
- mkdir -p ./gfpgan/weights
- wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth
- wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth
- wget -nc https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -O ./gfpgan/weights/GFPGANv1.4.pth
- wget -nc https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth -O ./gfpgan/weights/parsing_parsenet.pth
-
 
 
 
SadTalker/scripts/extension.py DELETED
@@ -1,189 +0,0 @@
1
- import os, sys
2
- from pathlib import Path
3
- import tempfile
4
- import gradio as gr
5
- from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call
6
- from modules.shared import opts, OptionInfo
7
- from modules import shared, paths, script_callbacks
8
- import launch
9
- import glob
10
- from huggingface_hub import snapshot_download
11
-
12
-
13
-
14
- def check_all_files_safetensor(current_dir):
15
- kv = {
16
- "SadTalker_V0.0.2_256.safetensors": "sadtalker-256",
17
- "SadTalker_V0.0.2_512.safetensors": "sadtalker-512",
18
- "mapping_00109-model.pth.tar" : "mapping-109" ,
19
- "mapping_00229-model.pth.tar" : "mapping-229" ,
20
- }
21
-
22
- if not os.path.isdir(current_dir):
23
- return False
24
-
25
- dirs = os.listdir(current_dir)
26
-
27
- for f in dirs:
28
- if f in kv.keys():
29
- del kv[f]
30
-
31
- return len(kv.keys()) == 0
32
-
33
- def check_all_files(current_dir):
34
- kv = {
35
- "auido2exp_00300-model.pth": "audio2exp",
36
- "auido2pose_00140-model.pth": "audio2pose",
37
- "epoch_20.pth": "face_recon",
38
- "facevid2vid_00189-model.pth.tar": "face-render",
39
- "mapping_00109-model.pth.tar" : "mapping-109" ,
40
- "mapping_00229-model.pth.tar" : "mapping-229" ,
41
- "wav2lip.pth": "wav2lip",
42
- "shape_predictor_68_face_landmarks.dat": "dlib",
43
- }
44
-
45
- if not os.path.isdir(current_dir):
46
- return False
47
-
48
- dirs = os.listdir(current_dir)
49
-
50
- for f in dirs:
51
- if f in kv.keys():
52
- del kv[f]
53
-
54
- return len(kv.keys()) == 0
55
-
56
-
57
-
58
- def download_model(local_dir='./checkpoints'):
59
- REPO_ID = 'vinthony/SadTalker'
60
- snapshot_download(repo_id=REPO_ID, local_dir=local_dir, local_dir_use_symlinks=False)
61
-
62
- def get_source_image(image):
63
- return image
64
-
65
- def get_img_from_txt2img(x):
66
- talker_path = Path(paths.script_path) / "outputs"
67
- imgs_from_txt_dir = str(talker_path / "txt2img-images/")
68
- imgs = glob.glob(imgs_from_txt_dir+'/*/*.png')
69
- imgs.sort(key=lambda x:os.path.getmtime(os.path.join(imgs_from_txt_dir, x)))
70
- img_from_txt_path = os.path.join(imgs_from_txt_dir, imgs[-1])
71
- return img_from_txt_path, img_from_txt_path
72
-
73
- def get_img_from_img2img(x):
74
- talker_path = Path(paths.script_path) / "outputs"
75
- imgs_from_img_dir = str(talker_path / "img2img-images/")
76
- imgs = glob.glob(imgs_from_img_dir+'/*/*.png')
77
- imgs.sort(key=lambda x:os.path.getmtime(os.path.join(imgs_from_img_dir, x)))
78
- img_from_img_path = os.path.join(imgs_from_img_dir, imgs[-1])
79
- return img_from_img_path, img_from_img_path
80
-
81
- def get_default_checkpoint_path():
82
- # check the path of models/checkpoints and extensions/
83
- checkpoint_path = Path(paths.script_path) / "models"/ "SadTalker"
84
- extension_checkpoint_path = Path(paths.script_path) / "extensions"/ "SadTalker" / "checkpoints"
85
-
86
- if check_all_files_safetensor(checkpoint_path):
87
- # print('founding sadtalker checkpoint in ' + str(checkpoint_path))
88
- return checkpoint_path
89
-
90
- if check_all_files_safetensor(extension_checkpoint_path):
91
- # print('founding sadtalker checkpoint in ' + str(extension_checkpoint_path))
92
- return extension_checkpoint_path
93
-
94
- if check_all_files(checkpoint_path):
95
- # print('founding sadtalker checkpoint in ' + str(checkpoint_path))
96
- return checkpoint_path
97
-
98
- if check_all_files(extension_checkpoint_path):
99
- # print('founding sadtalker checkpoint in ' + str(extension_checkpoint_path))
100
- return extension_checkpoint_path
101
-
102
- return None
103
-
104
-
105
-
106
- def install():
107
-
108
- kv = {
109
- "face_alignment": "face-alignment==1.3.5",
110
- "imageio": "imageio==2.19.3",
111
- "imageio_ffmpeg": "imageio-ffmpeg==0.4.7",
112
- "librosa":"librosa==0.8.0",
113
- "pydub":"pydub==0.25.1",
114
- "scipy":"scipy==1.8.1",
115
- "tqdm": "tqdm",
116
- "yacs":"yacs==0.1.8",
117
- "yaml": "pyyaml",
118
- "av":"av",
119
- "gfpgan": "gfpgan",
120
- }
121
-
122
- # # dlib is not necessary currently
123
- # if 'darwin' in sys.platform:
124
- # kv['dlib'] = "dlib"
125
- # else:
126
- # kv['dlib'] = 'dlib-bin'
127
-
128
- # #### we need to have a newer version of imageio for our method.
129
- # launch.run_pip("install imageio==2.19.3", "requirements for SadTalker")
130
-
131
- for k,v in kv.items():
132
- if not launch.is_installed(k):
133
- print(k, launch.is_installed(k))
134
- launch.run_pip("install "+ v, "requirements for SadTalker")
135
-
136
- if os.getenv('SADTALKER_CHECKPOINTS'):
137
- print('load Sadtalker Checkpoints from '+ os.getenv('SADTALKER_CHECKPOINTS'))
138
-
139
- elif get_default_checkpoint_path() is not None:
140
- os.environ['SADTALKER_CHECKPOINTS'] = str(get_default_checkpoint_path())
141
- else:
142
-
143
- print(
144
- """"
145
- SadTalker will not support download all the files from hugging face, which will take a long time.
146
-
147
- please manually set the SADTALKER_CHECKPOINTS in `webui_user.bat`(windows) or `webui_user.sh`(linux)
148
- """
149
- )
150
-
151
- # python = sys.executable
152
-
153
- # launch.run(f'"{python}" -m pip uninstall -y huggingface_hub', live=True)
154
- # launch.run(f'"{python}" -m pip install --upgrade git+https://github.com/huggingface/huggingface_hub@main', live=True)
155
- # ### run the scripts to downlod models to correct localtion.
156
- # # print('download models for SadTalker')
157
- # # launch.run("cd " + paths.script_path+"/extensions/SadTalker && bash ./scripts/download_models.sh", live=True)
158
- # # print('SadTalker is successfully installed!')
159
- # download_model(paths.script_path+'/extensions/SadTalker/checkpoints')
160
-
161
-
162
- def on_ui_tabs():
163
- install()
164
-
165
- sys.path.extend([paths.script_path+'/extensions/SadTalker'])
166
-
167
- repo_dir = paths.script_path+'/extensions/SadTalker/'
168
-
169
- result_dir = opts.sadtalker_result_dir
170
- os.makedirs(result_dir, exist_ok=True)
171
-
172
- from app_sadtalker import sadtalker_demo
173
-
174
- if os.getenv('SADTALKER_CHECKPOINTS'):
175
- checkpoint_path = os.getenv('SADTALKER_CHECKPOINTS')
176
- else:
177
- checkpoint_path = repo_dir+'checkpoints/'
178
-
179
- audio_to_video = sadtalker_demo(checkpoint_path=checkpoint_path, config_path=repo_dir+'src/config', warpfn = wrap_queued_call)
180
-
181
- return [(audio_to_video, "SadTalker", "extension")]
182
-
183
- def on_ui_settings():
184
- talker_path = Path(paths.script_path) / "outputs"
185
- section = ('extension', "SadTalker")
186
- opts.add_option("sadtalker_result_dir", OptionInfo(str(talker_path / "SadTalker/"), "Path to save results of sadtalker", section=section))
187
-
188
- script_callbacks.on_ui_settings(on_ui_settings)
189
- script_callbacks.on_ui_tabs(on_ui_tabs)
 
 
 
 
SadTalker/scripts/test.sh DELETED
@@ -1,21 +0,0 @@
- # ### some test command before commit.
- # python inference.py --preprocess crop --size 256
- # python inference.py --preprocess crop --size 512
-
- # python inference.py --preprocess extcrop --size 256
- # python inference.py --preprocess extcrop --size 512
-
- # python inference.py --preprocess resize --size 256
- # python inference.py --preprocess resize --size 512
-
- # python inference.py --preprocess full --size 256
- # python inference.py --preprocess full --size 512
-
- # python inference.py --preprocess extfull --size 256
- # python inference.py --preprocess extfull --size 512
-
- python inference.py --preprocess full --size 256 --enhancer gfpgan
- python inference.py --preprocess full --size 512 --enhancer gfpgan
-
- python inference.py --preprocess full --size 256 --enhancer gfpgan --still
- python inference.py --preprocess full --size 512 --enhancer gfpgan --still
 
 
SadTalker/src/audio2exp_models/audio2exp.py DELETED
@@ -1,41 +0,0 @@
1
- from tqdm import tqdm
2
- import torch
3
- from torch import nn
4
-
5
-
6
- class Audio2Exp(nn.Module):
7
- def __init__(self, netG, cfg, device, prepare_training_loss=False):
8
- super(Audio2Exp, self).__init__()
9
- self.cfg = cfg
10
- self.device = device
11
- self.netG = netG.to(device)
12
-
13
- def test(self, batch):
14
-
15
- mel_input = batch['indiv_mels'] # bs T 1 80 16
16
- bs = mel_input.shape[0]
17
- T = mel_input.shape[1]
18
-
19
- exp_coeff_pred = []
20
-
21
- for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
22
-
23
- current_mel_input = mel_input[:,i:i+10]
24
-
25
- #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
26
- ref = batch['ref'][:, :, :64][:, i:i+10]
27
- ratio = batch['ratio_gt'][:, i:i+10] #bs T
28
-
29
- audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
30
-
31
- curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
32
-
33
- exp_coeff_pred += [curr_exp_coeff_pred]
34
-
35
- # BS x T x 64
36
- results_dict = {
37
- 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
38
- }
39
- return results_dict
40
-
41
-
 
 
 
 
SadTalker/src/audio2exp_models/networks.py DELETED
@@ -1,74 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
-
5
- class Conv2d(nn.Module):
6
- def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
7
- super().__init__(*args, **kwargs)
8
- self.conv_block = nn.Sequential(
9
- nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
- nn.BatchNorm2d(cout)
11
- )
12
- self.act = nn.ReLU()
13
- self.residual = residual
14
- self.use_act = use_act
15
-
16
- def forward(self, x):
17
- out = self.conv_block(x)
18
- if self.residual:
19
- out += x
20
-
21
- if self.use_act:
22
- return self.act(out)
23
- else:
24
- return out
25
-
26
- class SimpleWrapperV2(nn.Module):
27
- def __init__(self) -> None:
28
- super().__init__()
29
- self.audio_encoder = nn.Sequential(
30
- Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
31
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
32
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
33
-
34
- Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
35
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
36
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
37
-
38
- Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
39
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
40
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
41
-
42
- Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
43
- Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
44
-
45
- Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
46
- Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
47
- )
48
-
49
- #### load the pre-trained audio_encoder
50
- #self.audio_encoder = self.audio_encoder.to(device)
51
- '''
52
- wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
53
- state_dict = self.audio_encoder.state_dict()
54
-
55
- for k,v in wav2lip_state_dict.items():
56
- if 'audio_encoder' in k:
57
- print('init:', k)
58
- state_dict[k.replace('module.audio_encoder.', '')] = v
59
- self.audio_encoder.load_state_dict(state_dict)
60
- '''
61
-
62
- self.mapping1 = nn.Linear(512+64+1, 64)
63
- #self.mapping2 = nn.Linear(30, 64)
64
- #nn.init.constant_(self.mapping1.weight, 0.)
65
- nn.init.constant_(self.mapping1.bias, 0.)
66
-
67
- def forward(self, x, ref, ratio):
68
- x = self.audio_encoder(x).view(x.size(0), -1)
69
- ref_reshape = ref.reshape(x.size(0), -1)
70
- ratio = ratio.reshape(x.size(0), -1)
71
-
72
- y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
73
- out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
74
- return out
 
 
 
 
SadTalker/src/audio2pose_models/audio2pose.py DELETED
@@ -1,94 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from src.audio2pose_models.cvae import CVAE
4
- from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
5
- from src.audio2pose_models.audio_encoder import AudioEncoder
6
-
7
- class Audio2Pose(nn.Module):
8
- def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
9
- super().__init__()
10
- self.cfg = cfg
11
- self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
12
- self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
13
- self.device = device
14
-
15
- self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
16
- self.audio_encoder.eval()
17
- for param in self.audio_encoder.parameters():
18
- param.requires_grad = False
19
-
20
- self.netG = CVAE(cfg)
21
- self.netD_motion = PoseSequenceDiscriminator(cfg)
22
-
23
-
24
- def forward(self, x):
25
-
26
- batch = {}
27
- coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
28
- batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
29
- batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
30
- batch['class'] = x['class'].squeeze(0).cuda() # bs
31
- indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
32
-
33
- # forward
34
- audio_emb_list = []
35
- audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
36
- batch['audio_emb'] = audio_emb
37
- batch = self.netG(batch)
38
-
39
- pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
40
- pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
41
- pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
42
-
43
- batch['pose_pred'] = pose_pred
44
- batch['pose_gt'] = pose_gt
45
-
46
- return batch
47
-
48
- def test(self, x):
49
-
50
- batch = {}
51
- ref = x['ref'] #bs 1 70
52
- batch['ref'] = x['ref'][:,0,-6:]
53
- batch['class'] = x['class']
54
- bs = ref.shape[0]
55
-
56
- indiv_mels= x['indiv_mels'] # bs T 1 80 16
57
- indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
58
- num_frames = x['num_frames']
59
- num_frames = int(num_frames) - 1
60
-
61
- #
62
- div = num_frames//self.seq_len
63
- re = num_frames%self.seq_len
64
- audio_emb_list = []
65
- pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
66
- device=batch['ref'].device)]
67
-
68
- for i in range(div):
69
- z = torch.randn(bs, self.latent_dim).to(ref.device)
70
- batch['z'] = z
71
- audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
72
- batch['audio_emb'] = audio_emb
73
- batch = self.netG.test(batch)
74
- pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
75
-
76
- if re != 0:
77
- z = torch.randn(bs, self.latent_dim).to(ref.device)
78
- batch['z'] = z
79
- audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
80
- if audio_emb.shape[1] != self.seq_len:
81
- pad_dim = self.seq_len-audio_emb.shape[1]
82
- pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
83
- audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
84
- batch['audio_emb'] = audio_emb
85
- batch = self.netG.test(batch)
86
- pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
87
-
88
- pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
89
- batch['pose_motion_pred'] = pose_motion_pred
90
-
91
- pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
92
-
93
- batch['pose_pred'] = pose_pred
94
- return batch
 
 
 
 
SadTalker/src/audio2pose_models/audio_encoder.py DELETED
@@ -1,64 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn import functional as F
4
-
5
- class Conv2d(nn.Module):
6
- def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7
- super().__init__(*args, **kwargs)
8
- self.conv_block = nn.Sequential(
9
- nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
- nn.BatchNorm2d(cout)
11
- )
12
- self.act = nn.ReLU()
13
- self.residual = residual
14
-
15
- def forward(self, x):
16
- out = self.conv_block(x)
17
- if self.residual:
18
- out += x
19
- return self.act(out)
20
-
21
- class AudioEncoder(nn.Module):
22
- def __init__(self, wav2lip_checkpoint, device):
23
- super(AudioEncoder, self).__init__()
24
-
25
- self.audio_encoder = nn.Sequential(
26
- Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
27
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
28
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
29
-
30
- Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
31
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
32
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
33
-
34
- Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
35
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
36
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
37
-
38
- Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
39
- Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
40
-
41
- Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
42
- Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43
-
44
- #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
45
- # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
46
- # state_dict = self.audio_encoder.state_dict()
47
-
48
- # for k,v in wav2lip_state_dict.items():
49
- # if 'audio_encoder' in k:
50
- # state_dict[k.replace('module.audio_encoder.', '')] = v
51
- # self.audio_encoder.load_state_dict(state_dict)
52
-
53
-
54
- def forward(self, audio_sequences):
55
- # audio_sequences = (B, T, 1, 80, 16)
56
- B = audio_sequences.size(0)
57
-
58
- audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
59
-
60
- audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
61
- dim = audio_embedding.shape[1]
62
- audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
63
-
64
- return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
 
 
 
 
SadTalker/src/audio2pose_models/cvae.py DELETED
@@ -1,149 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
- from src.audio2pose_models.res_unet import ResUnet
5
-
6
- def class2onehot(idx, class_num):
7
-
8
- assert torch.max(idx).item() < class_num
9
- onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
10
- onehot.scatter_(1, idx, 1)
11
- return onehot
12
-
13
- class CVAE(nn.Module):
14
- def __init__(self, cfg):
15
- super().__init__()
16
- encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
17
- decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
18
- latent_size = cfg.MODEL.CVAE.LATENT_SIZE
19
- num_classes = cfg.DATASET.NUM_CLASSES
20
- audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
21
- audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
22
- seq_len = cfg.MODEL.CVAE.SEQ_LEN
23
-
24
- self.latent_size = latent_size
25
-
26
- self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
27
- audio_emb_in_size, audio_emb_out_size, seq_len)
28
- self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
29
- audio_emb_in_size, audio_emb_out_size, seq_len)
30
- def reparameterize(self, mu, logvar):
31
- std = torch.exp(0.5 * logvar)
32
- eps = torch.randn_like(std)
33
- return mu + eps * std
34
-
35
- def forward(self, batch):
36
- batch = self.encoder(batch)
37
- mu = batch['mu']
38
- logvar = batch['logvar']
39
- z = self.reparameterize(mu, logvar)
40
- batch['z'] = z
41
- return self.decoder(batch)
42
-
43
- def test(self, batch):
44
- '''
45
- class_id = batch['class']
46
- z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
47
- batch['z'] = z
48
- '''
49
- return self.decoder(batch)
50
-
51
- class ENCODER(nn.Module):
52
- def __init__(self, layer_sizes, latent_size, num_classes,
53
- audio_emb_in_size, audio_emb_out_size, seq_len):
54
- super().__init__()
55
-
56
- self.resunet = ResUnet()
57
- self.num_classes = num_classes
58
- self.seq_len = seq_len
59
-
60
- self.MLP = nn.Sequential()
61
- layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
62
- for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
63
- self.MLP.add_module(
64
- name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
65
- self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
66
-
67
- self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
68
- self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
69
- self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
70
-
71
- self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
72
-
73
- def forward(self, batch):
74
- class_id = batch['class']
75
- pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
76
- ref = batch['ref'] #bs 6
77
- bs = pose_motion_gt.shape[0]
78
- audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
79
-
80
- #pose encode
81
- pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
82
- pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
83
-
84
- #audio mapping
85
- print(audio_in.shape)
86
- audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
87
- audio_out = audio_out.reshape(bs, -1)
88
-
89
- class_bias = self.classbias[class_id] #bs latent_size
90
- x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
91
- x_out = self.MLP(x_in)
92
-
93
- mu = self.linear_means(x_out)
94
- logvar = self.linear_logvar(x_out) #bs latent_size
95
-
96
- batch.update({'mu':mu, 'logvar':logvar})
97
- return batch
98
-
99
- class DECODER(nn.Module):
100
- def __init__(self, layer_sizes, latent_size, num_classes,
101
- audio_emb_in_size, audio_emb_out_size, seq_len):
102
- super().__init__()
103
-
104
- self.resunet = ResUnet()
105
- self.num_classes = num_classes
106
- self.seq_len = seq_len
107
-
108
- self.MLP = nn.Sequential()
109
- input_size = latent_size + seq_len*audio_emb_out_size + 6
110
- for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
111
- self.MLP.add_module(
112
- name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
113
- if i+1 < len(layer_sizes):
114
- self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
115
- else:
116
- self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
117
-
118
- self.pose_linear = nn.Linear(6, 6)
119
- self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
120
-
121
- self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
122
-
123
- def forward(self, batch):
124
-
125
- z = batch['z'] #bs latent_size
126
- bs = z.shape[0]
127
- class_id = batch['class']
128
- ref = batch['ref'] #bs 6
129
- audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
130
- #print('audio_in: ', audio_in[:, :, :10])
131
-
132
- audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
133
- #print('audio_out: ', audio_out[:, :, :10])
134
- audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
135
- class_bias = self.classbias[class_id] #bs latent_size
136
-
137
- z = z + class_bias
138
- x_in = torch.cat([ref, z, audio_out], dim=-1)
139
- x_out = self.MLP(x_in) # bs layer_sizes[-1]
140
- x_out = x_out.reshape((bs, self.seq_len, -1))
141
-
142
- #print('x_out: ', x_out)
143
-
144
- pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
145
-
146
- pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
147
-
148
- batch.update({'pose_motion_pred':pose_motion_pred})
149
- return batch
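The core of the deleted CVAE is the standard reparameterization trick: the latent sample is drawn as `mu + eps * std`, so gradients can flow through the sampling step. A standalone illustration with arbitrary sizes, independent of the deleted config:

```python
import torch

def reparameterize(mu, logvar):
    # Sample z ~ N(mu, sigma^2) in a differentiable way.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

mu = torch.zeros(4, 64)        # bs x latent_size
logvar = torch.zeros(4, 64)    # log-variance 0 -> unit variance
z = reparameterize(mu, logvar)
print(z.shape)                 # torch.Size([4, 64])
```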
 
SadTalker/src/audio2pose_models/discriminator.py DELETED
@@ -1,76 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
-
5
- class ConvNormRelu(nn.Module):
6
- def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
7
- kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
8
- super().__init__()
9
- if kernel_size is None:
10
- if downsample:
11
- kernel_size, stride, padding = 4, 2, 1
12
- else:
13
- kernel_size, stride, padding = 3, 1, 1
14
-
15
- if conv_type == '2d':
16
- self.conv = nn.Conv2d(
17
- in_channels,
18
- out_channels,
19
- kernel_size,
20
- stride,
21
- padding,
22
- bias=False,
23
- )
24
- if norm == 'BN':
25
- self.norm = nn.BatchNorm2d(out_channels)
26
- elif norm == 'IN':
27
- self.norm = nn.InstanceNorm2d(out_channels)
28
- else:
29
- raise NotImplementedError
30
- elif conv_type == '1d':
31
- self.conv = nn.Conv1d(
32
- in_channels,
33
- out_channels,
34
- kernel_size,
35
- stride,
36
- padding,
37
- bias=False,
38
- )
39
- if norm == 'BN':
40
- self.norm = nn.BatchNorm1d(out_channels)
41
- elif norm == 'IN':
42
- self.norm = nn.InstanceNorm1d(out_channels)
43
- else:
44
- raise NotImplementedError
45
- nn.init.kaiming_normal_(self.conv.weight)
46
-
47
- self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
48
-
49
- def forward(self, x):
50
- x = self.conv(x)
51
- if isinstance(self.norm, nn.InstanceNorm1d):
52
- x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
53
- else:
54
- x = self.norm(x)
55
- x = self.act(x)
56
- return x
57
-
58
-
59
- class PoseSequenceDiscriminator(nn.Module):
60
- def __init__(self, cfg):
61
- super().__init__()
62
- self.cfg = cfg
63
- leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
64
-
65
- self.seq = nn.Sequential(
66
- ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
67
- ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
68
- ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
69
- nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
70
- )
71
-
72
- def forward(self, x):
73
- x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
74
- x = self.seq(x)
75
- x = x.squeeze(1)
76
- return x
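The discriminator consumes a pose-coefficient sequence of shape (bs, seq_len, 6), transposes it to channels-first, and returns one realism score per downsampled time step. A hedged usage sketch with a `SimpleNamespace` standing in for the yacs config (field values assumed):

```python
import torch
from types import SimpleNamespace
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator

# Minimal stand-in for the yacs config node the class expects.
cfg = SimpleNamespace(MODEL=SimpleNamespace(
    DISCRIMINATOR=SimpleNamespace(LEAKY_RELU=False, INPUT_CHANNELS=6)))

disc = PoseSequenceDiscriminator(cfg)
poses = torch.randn(2, 32, 6)     # bs x seq_len x 6 pose coefficients
score = disc(poses)
print(score.shape)                # torch.Size([2, 8]): two stride-2 convs halve 32 twice
```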
 
SadTalker/src/audio2pose_models/networks.py DELETED
@@ -1,140 +0,0 @@
1
- import torch.nn as nn
2
- import torch
3
-
4
-
5
- class ResidualConv(nn.Module):
6
- def __init__(self, input_dim, output_dim, stride, padding):
7
- super(ResidualConv, self).__init__()
8
-
9
- self.conv_block = nn.Sequential(
10
- nn.BatchNorm2d(input_dim),
11
- nn.ReLU(),
12
- nn.Conv2d(
13
- input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
14
- ),
15
- nn.BatchNorm2d(output_dim),
16
- nn.ReLU(),
17
- nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
18
- )
19
- self.conv_skip = nn.Sequential(
20
- nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
21
- nn.BatchNorm2d(output_dim),
22
- )
23
-
24
- def forward(self, x):
25
-
26
- return self.conv_block(x) + self.conv_skip(x)
27
-
28
-
29
- class Upsample(nn.Module):
30
- def __init__(self, input_dim, output_dim, kernel, stride):
31
- super(Upsample, self).__init__()
32
-
33
- self.upsample = nn.ConvTranspose2d(
34
- input_dim, output_dim, kernel_size=kernel, stride=stride
35
- )
36
-
37
- def forward(self, x):
38
- return self.upsample(x)
39
-
40
-
41
- class Squeeze_Excite_Block(nn.Module):
42
- def __init__(self, channel, reduction=16):
43
- super(Squeeze_Excite_Block, self).__init__()
44
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
45
- self.fc = nn.Sequential(
46
- nn.Linear(channel, channel // reduction, bias=False),
47
- nn.ReLU(inplace=True),
48
- nn.Linear(channel // reduction, channel, bias=False),
49
- nn.Sigmoid(),
50
- )
51
-
52
- def forward(self, x):
53
- b, c, _, _ = x.size()
54
- y = self.avg_pool(x).view(b, c)
55
- y = self.fc(y).view(b, c, 1, 1)
56
- return x * y.expand_as(x)
57
-
58
-
59
- class ASPP(nn.Module):
60
- def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
61
- super(ASPP, self).__init__()
62
-
63
- self.aspp_block1 = nn.Sequential(
64
- nn.Conv2d(
65
- in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
66
- ),
67
- nn.ReLU(inplace=True),
68
- nn.BatchNorm2d(out_dims),
69
- )
70
- self.aspp_block2 = nn.Sequential(
71
- nn.Conv2d(
72
- in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
73
- ),
74
- nn.ReLU(inplace=True),
75
- nn.BatchNorm2d(out_dims),
76
- )
77
- self.aspp_block3 = nn.Sequential(
78
- nn.Conv2d(
79
- in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
80
- ),
81
- nn.ReLU(inplace=True),
82
- nn.BatchNorm2d(out_dims),
83
- )
84
-
85
- self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
86
- self._init_weights()
87
-
88
- def forward(self, x):
89
- x1 = self.aspp_block1(x)
90
- x2 = self.aspp_block2(x)
91
- x3 = self.aspp_block3(x)
92
- out = torch.cat([x1, x2, x3], dim=1)
93
- return self.output(out)
94
-
95
- def _init_weights(self):
96
- for m in self.modules():
97
- if isinstance(m, nn.Conv2d):
98
- nn.init.kaiming_normal_(m.weight)
99
- elif isinstance(m, nn.BatchNorm2d):
100
- m.weight.data.fill_(1)
101
- m.bias.data.zero_()
102
-
103
-
104
- class Upsample_(nn.Module):
105
- def __init__(self, scale=2):
106
- super(Upsample_, self).__init__()
107
-
108
- self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
109
-
110
- def forward(self, x):
111
- return self.upsample(x)
112
-
113
-
114
- class AttentionBlock(nn.Module):
115
- def __init__(self, input_encoder, input_decoder, output_dim):
116
- super(AttentionBlock, self).__init__()
117
-
118
- self.conv_encoder = nn.Sequential(
119
- nn.BatchNorm2d(input_encoder),
120
- nn.ReLU(),
121
- nn.Conv2d(input_encoder, output_dim, 3, padding=1),
122
- nn.MaxPool2d(2, 2),
123
- )
124
-
125
- self.conv_decoder = nn.Sequential(
126
- nn.BatchNorm2d(input_decoder),
127
- nn.ReLU(),
128
- nn.Conv2d(input_decoder, output_dim, 3, padding=1),
129
- )
130
-
131
- self.conv_attn = nn.Sequential(
132
- nn.BatchNorm2d(output_dim),
133
- nn.ReLU(),
134
- nn.Conv2d(output_dim, 1, 1),
135
- )
136
-
137
- def forward(self, x1, x2):
138
- out = self.conv_encoder(x1) + self.conv_decoder(x2)
139
- out = self.conv_attn(out)
140
- return out * x2
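These blocks are generic 2-D building bricks (pre-activation residual conv, squeeze-and-excitation, ASPP, attention gate) reused by the pose ResUnet below. A small shape check of the two most-used pieces, assuming the module path is still importable:

```python
import torch
from src.audio2pose_models.networks import ResidualConv, Squeeze_Excite_Block

x = torch.randn(2, 32, 16, 6)                       # NCHW feature map (shapes assumed)

block = ResidualConv(32, 64, stride=(2, 1), padding=1)
se = Squeeze_Excite_Block(64)

y = se(block(x))   # strided residual conv halves the height, SE reweights channels
print(y.shape)     # torch.Size([2, 64, 8, 6])
```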
 
SadTalker/src/audio2pose_models/res_unet.py DELETED
@@ -1,65 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from src.audio2pose_models.networks import ResidualConv, Upsample
4
-
5
-
6
- class ResUnet(nn.Module):
7
- def __init__(self, channel=1, filters=[32, 64, 128, 256]):
8
- super(ResUnet, self).__init__()
9
-
10
- self.input_layer = nn.Sequential(
11
- nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
12
- nn.BatchNorm2d(filters[0]),
13
- nn.ReLU(),
14
- nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
15
- )
16
- self.input_skip = nn.Sequential(
17
- nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
18
- )
19
-
20
- self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
21
- self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
22
-
23
- self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
24
-
25
- self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
26
- self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
27
-
28
- self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
29
- self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
30
-
31
- self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
32
- self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
33
-
34
- self.output_layer = nn.Sequential(
35
- nn.Conv2d(filters[0], 1, 1, 1),
36
- nn.Sigmoid(),
37
- )
38
-
39
- def forward(self, x):
40
- # Encode
41
- x1 = self.input_layer(x) + self.input_skip(x)
42
- x2 = self.residual_conv_1(x1)
43
- x3 = self.residual_conv_2(x2)
44
- # Bridge
45
- x4 = self.bridge(x3)
46
-
47
- # Decode
48
- x4 = self.upsample_1(x4)
49
- x5 = torch.cat([x4, x3], dim=1)
50
-
51
- x6 = self.up_residual_conv1(x5)
52
-
53
- x6 = self.upsample_2(x6)
54
- x7 = torch.cat([x6, x2], dim=1)
55
-
56
- x8 = self.up_residual_conv2(x7)
57
-
58
- x8 = self.upsample_3(x8)
59
- x9 = torch.cat([x8, x1], dim=1)
60
-
61
- x10 = self.up_residual_conv3(x9)
62
-
63
- output = self.output_layer(x10)
64
-
65
- return output
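The ResUnet treats a pose-motion sequence as a (bs, 1, seq_len, 6) "image", downsampling only along the time axis and mirroring the encoder on the way back up, so the output keeps the input shape. A quick sanity check (seq_len must be divisible by 8 for the three stride-2 stages):

```python
import torch
from src.audio2pose_models.res_unet import ResUnet

net = ResUnet(channel=1)
pose_seq = torch.randn(2, 1, 32, 6)   # bs x 1 x seq_len x 6
out = net(pose_seq)
print(out.shape)                      # torch.Size([2, 1, 32, 6]), sigmoid-activated
```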
 
SadTalker/src/config/auido2exp.yaml DELETED
@@ -1,58 +0,0 @@
1
- DATASET:
2
- TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
3
- EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
4
- TRAIN_BATCH_SIZE: 32
5
- EVAL_BATCH_SIZE: 32
6
- EXP: True
7
- EXP_DIM: 64
8
- FRAME_LEN: 32
9
- COEFF_LEN: 73
10
- NUM_CLASSES: 46
11
- AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
- COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
13
- LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
14
- DEBUG: True
15
- NUM_REPEATS: 2
16
- T: 40
17
-
18
-
19
- MODEL:
20
- FRAMEWORK: V2
21
- AUDIOENCODER:
22
- LEAKY_RELU: True
23
- NORM: 'IN'
24
- DISCRIMINATOR:
25
- LEAKY_RELU: False
26
- INPUT_CHANNELS: 6
27
- CVAE:
28
- AUDIO_EMB_IN_SIZE: 512
29
- AUDIO_EMB_OUT_SIZE: 128
30
- SEQ_LEN: 32
31
- LATENT_SIZE: 256
32
- ENCODER_LAYER_SIZES: [192, 1024]
33
- DECODER_LAYER_SIZES: [1024, 192]
34
-
35
-
36
- TRAIN:
37
- MAX_EPOCH: 300
38
- GENERATOR:
39
- LR: 2.0e-5
40
- DISCRIMINATOR:
41
- LR: 1.0e-5
42
- LOSS:
43
- W_FEAT: 0
44
- W_COEFF_EXP: 2
45
- W_LM: 1.0e-2
46
- W_LM_MOUTH: 0
47
- W_REG: 0
48
- W_SYNC: 0
49
- W_COLOR: 0
50
- W_EXPRESSION: 0
51
- W_LIPREADING: 0.01
52
- W_LIPREADING_VV: 0
53
- W_EYE_BLINK: 4
54
-
55
- TAG:
56
- NAME: small_dataset
57
-
58
-
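This training config, like the similar YAML files that follow, is plain YAML, so the values above can be inspected with PyYAML alone (the actual pipeline wraps them in a yacs `CfgNode`, but that is not needed just to read them; the path below is illustrative):

```python
import yaml

with open('src/config/auido2exp.yaml') as f:   # path as it appeared in the deleted tree
    cfg = yaml.safe_load(f)

print(cfg['MODEL']['CVAE']['SEQ_LEN'])   # 32
print(cfg['TRAIN']['GENERATOR']['LR'])   # 2e-05
```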
 
SadTalker/src/config/auido2pose.yaml DELETED
@@ -1,49 +0,0 @@
1
- DATASET:
2
- TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
3
- EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
4
- TRAIN_BATCH_SIZE: 64
5
- EVAL_BATCH_SIZE: 1
6
- EXP: True
7
- EXP_DIM: 64
8
- FRAME_LEN: 32
9
- COEFF_LEN: 73
10
- NUM_CLASSES: 46
11
- AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
- COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
13
- DEBUG: True
14
-
15
-
16
- MODEL:
17
- AUDIOENCODER:
18
- LEAKY_RELU: True
19
- NORM: 'IN'
20
- DISCRIMINATOR:
21
- LEAKY_RELU: False
22
- INPUT_CHANNELS: 6
23
- CVAE:
24
- AUDIO_EMB_IN_SIZE: 512
25
- AUDIO_EMB_OUT_SIZE: 6
26
- SEQ_LEN: 32
27
- LATENT_SIZE: 64
28
- ENCODER_LAYER_SIZES: [192, 128]
29
- DECODER_LAYER_SIZES: [128, 192]
30
-
31
-
32
- TRAIN:
33
- MAX_EPOCH: 150
34
- GENERATOR:
35
- LR: 1.0e-4
36
- DISCRIMINATOR:
37
- LR: 1.0e-4
38
- LOSS:
39
- LAMBDA_REG: 1
40
- LAMBDA_LANDMARKS: 0
41
- LAMBDA_VERTICES: 0
42
- LAMBDA_GAN_MOTION: 0.7
43
- LAMBDA_GAN_COEFF: 0
44
- LAMBDA_KL: 1
45
-
46
- TAG:
47
- NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
48
-
49
-
 
SadTalker/src/config/facerender.yaml DELETED
@@ -1,45 +0,0 @@
1
- model_params:
2
- common_params:
3
- num_kp: 15
4
- image_channel: 3
5
- feature_channel: 32
6
- estimate_jacobian: False # True
7
- kp_detector_params:
8
- temperature: 0.1
9
- block_expansion: 32
10
- max_features: 1024
11
- scale_factor: 0.25 # 0.25
12
- num_blocks: 5
13
- reshape_channel: 16384 # 16384 = 1024 * 16
14
- reshape_depth: 16
15
- he_estimator_params:
16
- block_expansion: 64
17
- max_features: 2048
18
- num_bins: 66
19
- generator_params:
20
- block_expansion: 64
21
- max_features: 512
22
- num_down_blocks: 2
23
- reshape_channel: 32
24
- reshape_depth: 16 # 512 = 32 * 16
25
- num_resblocks: 6
26
- estimate_occlusion_map: True
27
- dense_motion_params:
28
- block_expansion: 32
29
- max_features: 1024
30
- num_blocks: 5
31
- reshape_depth: 16
32
- compress: 4
33
- discriminator_params:
34
- scales: [1]
35
- block_expansion: 32
36
- max_features: 512
37
- num_blocks: 4
38
- sn: True
39
- mapping_params:
40
- coeff_nc: 70
41
- descriptor_nc: 1024
42
- layer: 3
43
- num_kp: 15
44
- num_bins: 66
45
-
 
SadTalker/src/config/facerender_still.yaml DELETED
@@ -1,45 +0,0 @@
1
- model_params:
2
- common_params:
3
- num_kp: 15
4
- image_channel: 3
5
- feature_channel: 32
6
- estimate_jacobian: False # True
7
- kp_detector_params:
8
- temperature: 0.1
9
- block_expansion: 32
10
- max_features: 1024
11
- scale_factor: 0.25 # 0.25
12
- num_blocks: 5
13
- reshape_channel: 16384 # 16384 = 1024 * 16
14
- reshape_depth: 16
15
- he_estimator_params:
16
- block_expansion: 64
17
- max_features: 2048
18
- num_bins: 66
19
- generator_params:
20
- block_expansion: 64
21
- max_features: 512
22
- num_down_blocks: 2
23
- reshape_channel: 32
24
- reshape_depth: 16 # 512 = 32 * 16
25
- num_resblocks: 6
26
- estimate_occlusion_map: True
27
- dense_motion_params:
28
- block_expansion: 32
29
- max_features: 1024
30
- num_blocks: 5
31
- reshape_depth: 16
32
- compress: 4
33
- discriminator_params:
34
- scales: [1]
35
- block_expansion: 32
36
- max_features: 512
37
- num_blocks: 4
38
- sn: True
39
- mapping_params:
40
- coeff_nc: 73
41
- descriptor_nc: 1024
42
- layer: 3
43
- num_kp: 15
44
- num_bins: 66
45
-
 
SadTalker/src/config/similarity_Lm3D_all.mat DELETED
Binary file (994 Bytes)
 
SadTalker/src/face3d/data/__init__.py DELETED
@@ -1,116 +0,0 @@
1
- """This package includes all the modules related to data loading and preprocessing
2
-
3
- To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4
- You need to implement four functions:
5
- -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6
- -- <__len__>: return the size of dataset.
7
- -- <__getitem__>: get a data point from data loader.
8
- -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9
-
10
- Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11
- See our template dataset class 'template_dataset.py' for more details.
12
- """
13
- import numpy as np
14
- import importlib
15
- import torch.utils.data
16
- from face3d.data.base_dataset import BaseDataset
17
-
18
-
19
- def find_dataset_using_name(dataset_name):
20
- """Import the module "data/[dataset_name]_dataset.py".
21
-
22
- In the file, the class called DatasetNameDataset() will
23
- be instantiated. It has to be a subclass of BaseDataset,
24
- and it is case-insensitive.
25
- """
26
- dataset_filename = "data." + dataset_name + "_dataset"
27
- datasetlib = importlib.import_module(dataset_filename)
28
-
29
- dataset = None
30
- target_dataset_name = dataset_name.replace('_', '') + 'dataset'
31
- for name, cls in datasetlib.__dict__.items():
32
- if name.lower() == target_dataset_name.lower() \
33
- and issubclass(cls, BaseDataset):
34
- dataset = cls
35
-
36
- if dataset is None:
37
- raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
38
-
39
- return dataset
40
-
41
-
42
- def get_option_setter(dataset_name):
43
- """Return the static method <modify_commandline_options> of the dataset class."""
44
- dataset_class = find_dataset_using_name(dataset_name)
45
- return dataset_class.modify_commandline_options
46
-
47
-
48
- def create_dataset(opt, rank=0):
49
- """Create a dataset given the option.
50
-
51
- This function wraps the class CustomDatasetDataLoader.
52
- This is the main interface between this package and 'train.py'/'test.py'
53
-
54
- Example:
55
- >>> from data import create_dataset
56
- >>> dataset = create_dataset(opt)
57
- """
58
- data_loader = CustomDatasetDataLoader(opt, rank=rank)
59
- dataset = data_loader.load_data()
60
- return dataset
61
-
62
- class CustomDatasetDataLoader():
63
- """Wrapper class of Dataset class that performs multi-threaded data loading"""
64
-
65
- def __init__(self, opt, rank=0):
66
- """Initialize this class
67
-
68
- Step 1: create a dataset instance given the name [dataset_mode]
69
- Step 2: create a multi-threaded data loader.
70
- """
71
- self.opt = opt
72
- dataset_class = find_dataset_using_name(opt.dataset_mode)
73
- self.dataset = dataset_class(opt)
74
- self.sampler = None
75
- print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
76
- if opt.use_ddp and opt.isTrain:
77
- world_size = opt.world_size
78
- self.sampler = torch.utils.data.distributed.DistributedSampler(
79
- self.dataset,
80
- num_replicas=world_size,
81
- rank=rank,
82
- shuffle=not opt.serial_batches
83
- )
84
- self.dataloader = torch.utils.data.DataLoader(
85
- self.dataset,
86
- sampler=self.sampler,
87
- num_workers=int(opt.num_threads / world_size),
88
- batch_size=int(opt.batch_size / world_size),
89
- drop_last=True)
90
- else:
91
- self.dataloader = torch.utils.data.DataLoader(
92
- self.dataset,
93
- batch_size=opt.batch_size,
94
- shuffle=(not opt.serial_batches) and opt.isTrain,
95
- num_workers=int(opt.num_threads),
96
- drop_last=True
97
- )
98
-
99
- def set_epoch(self, epoch):
100
- self.dataset.current_epoch = epoch
101
- if self.sampler is not None:
102
- self.sampler.set_epoch(epoch)
103
-
104
- def load_data(self):
105
- return self
106
-
107
- def __len__(self):
108
- """Return the number of data in the dataset"""
109
- return min(len(self.dataset), self.opt.max_dataset_size)
110
-
111
- def __iter__(self):
112
- """Return a batch of data"""
113
- for i, data in enumerate(self.dataloader):
114
- if i * self.opt.batch_size >= self.opt.max_dataset_size:
115
- break
116
- yield data
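The deleted data package resolves dataset classes by name through `importlib`, the usual CycleGAN-style plugin lookup: `<mode>_dataset.py` must define a `<Mode>Dataset` subclass of `BaseDataset`. A condensed, standalone version of that lookup logic for reference:

```python
import importlib

def find_class_by_name(module_name, target_name, base_cls):
    """Import `module_name` and return its class whose lowercase name equals `target_name`."""
    module = importlib.import_module(module_name)
    for name, obj in module.__dict__.items():
        if name.lower() == target_name.lower() and isinstance(obj, type) and issubclass(obj, base_cls):
            return obj
    raise NotImplementedError(
        f"no subclass of {base_cls.__name__} named {target_name} found in {module_name}")

# e.g. find_class_by_name('data.flist_dataset', 'flistdataset', BaseDataset)  # hypothetical call
```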
 
SadTalker/src/face3d/data/base_dataset.py DELETED
@@ -1,125 +0,0 @@
1
- """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2
-
3
- It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4
- """
5
- import random
6
- import numpy as np
7
- import torch.utils.data as data
8
- from PIL import Image
9
- import torchvision.transforms as transforms
10
- from abc import ABC, abstractmethod
11
-
12
-
13
- class BaseDataset(data.Dataset, ABC):
14
- """This class is an abstract base class (ABC) for datasets.
15
-
16
- To create a subclass, you need to implement the following four functions:
17
- -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18
- -- <__len__>: return the size of dataset.
19
- -- <__getitem__>: get a data point.
20
- -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21
- """
22
-
23
- def __init__(self, opt):
24
- """Initialize the class; save the options in the class
25
-
26
- Parameters:
27
- opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28
- """
29
- self.opt = opt
30
- # self.root = opt.dataroot
31
- self.current_epoch = 0
32
-
33
- @staticmethod
34
- def modify_commandline_options(parser, is_train):
35
- """Add new dataset-specific options, and rewrite default values for existing options.
36
-
37
- Parameters:
38
- parser -- original option parser
39
- is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40
-
41
- Returns:
42
- the modified parser.
43
- """
44
- return parser
45
-
46
- @abstractmethod
47
- def __len__(self):
48
- """Return the total number of images in the dataset."""
49
- return 0
50
-
51
- @abstractmethod
52
- def __getitem__(self, index):
53
- """Return a data point and its metadata information.
54
-
55
- Parameters:
56
- index - - a random integer for data indexing
57
-
58
- Returns:
59
- a dictionary of data with their names. It usually contains the data itself and its metadata information.
60
- """
61
- pass
62
-
63
-
64
- def get_transform(grayscale=False):
65
- transform_list = []
66
- if grayscale:
67
- transform_list.append(transforms.Grayscale(1))
68
- transform_list += [transforms.ToTensor()]
69
- return transforms.Compose(transform_list)
70
-
71
- def get_affine_mat(opt, size):
72
- shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
73
- w, h = size
74
-
75
- if 'shift' in opt.preprocess:
76
- shift_pixs = int(opt.shift_pixs)
77
- shift_x = random.randint(-shift_pixs, shift_pixs)
78
- shift_y = random.randint(-shift_pixs, shift_pixs)
79
- if 'scale' in opt.preprocess:
80
- scale = 1 + opt.scale_delta * (2 * random.random() - 1)
81
- if 'rot' in opt.preprocess:
82
- rot_angle = opt.rot_angle * (2 * random.random() - 1)
83
- rot_rad = -rot_angle * np.pi/180
84
- if 'flip' in opt.preprocess:
85
- flip = random.random() > 0.5
86
-
87
- shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
88
- flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
89
- shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
90
- rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
91
- scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
92
- shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
93
-
94
- affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
95
- affine_inv = np.linalg.inv(affine)
96
- return affine, affine_inv, flip
97
-
98
- def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
99
- return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
100
-
101
- def apply_lm_affine(landmark, affine, flip, size):
102
- _, h = size
103
- lm = landmark.copy()
104
- lm[:, 1] = h - 1 - lm[:, 1]
105
- lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
106
- lm = lm @ np.transpose(affine)
107
- lm[:, :2] = lm[:, :2] / lm[:, 2:]
108
- lm = lm[:, :2]
109
- lm[:, 1] = h - 1 - lm[:, 1]
110
- if flip:
111
- lm_ = lm.copy()
112
- lm_[:17] = lm[16::-1]
113
- lm_[17:22] = lm[26:21:-1]
114
- lm_[22:27] = lm[21:16:-1]
115
- lm_[31:36] = lm[35:30:-1]
116
- lm_[36:40] = lm[45:41:-1]
117
- lm_[40:42] = lm[47:45:-1]
118
- lm_[42:46] = lm[39:35:-1]
119
- lm_[46:48] = lm[41:39:-1]
120
- lm_[48:55] = lm[54:47:-1]
121
- lm_[55:60] = lm[59:54:-1]
122
- lm_[60:65] = lm[64:59:-1]
123
- lm_[65:68] = lm[67:64:-1]
124
- lm = lm_
125
- return lm
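`get_affine_mat` above builds the augmentation as a product of 3x3 homogeneous matrices: move the origin to the image centre, apply flip/shift/rotate/scale, then move back. A tiny numpy check of that composition order with assumed values:

```python
import numpy as np

w, h = 256, 256
shift_to_origin = np.array([1, 0, -w // 2, 0, 1, -h // 2, 0, 0, 1]).reshape(3, 3)
scale_mat = np.diag([1.1, 1.1, 1.0])               # 10% zoom about the centre
shift_to_center = np.array([1, 0, w // 2, 0, 1, h // 2, 0, 0, 1]).reshape(3, 3)

affine = shift_to_center @ scale_mat @ shift_to_origin
corner = affine @ np.array([0, 0, 1])              # where the top-left corner lands
print(corner[:2])                                  # [-12.8 -12.8]: zooming pushes it out of frame
```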
 
SadTalker/src/face3d/data/flist_dataset.py DELETED
@@ -1,125 +0,0 @@
1
- """This script defines the custom dataset for Deep3DFaceRecon_pytorch
2
- """
3
-
4
- import os.path
5
- from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
6
- from data.image_folder import make_dataset
7
- from PIL import Image
8
- import random
9
- import util.util as util
10
- import numpy as np
11
- import json
12
- import torch
13
- from scipy.io import loadmat, savemat
14
- import pickle
15
- from util.preprocess import align_img, estimate_norm
16
- from util.load_mats import load_lm3d
17
-
18
-
19
- def default_flist_reader(flist):
20
- """
21
- flist format: impath label\nimpath label\n ...(same to caffe's filelist)
22
- """
23
- imlist = []
24
- with open(flist, 'r') as rf:
25
- for line in rf.readlines():
26
- impath = line.strip()
27
- imlist.append(impath)
28
-
29
- return imlist
30
-
31
- def jason_flist_reader(flist):
32
- with open(flist, 'r') as fp:
33
- info = json.load(fp)
34
- return info
35
-
36
- def parse_label(label):
37
- return torch.tensor(np.array(label).astype(np.float32))
38
-
39
-
40
- class FlistDataset(BaseDataset):
41
- """
42
- It requires one directories to host training images '/path/to/data/train'
43
- You can train the model with the dataset flag '--dataroot /path/to/data'.
44
- """
45
-
46
- def __init__(self, opt):
47
- """Initialize this dataset class.
48
-
49
- Parameters:
50
- opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
51
- """
52
- BaseDataset.__init__(self, opt)
53
-
54
- self.lm3d_std = load_lm3d(opt.bfm_folder)
55
-
56
- msk_names = default_flist_reader(opt.flist)
57
- self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
58
-
59
- self.size = len(self.msk_paths)
60
- self.opt = opt
61
-
62
- self.name = 'train' if opt.isTrain else 'val'
63
- if '_' in opt.flist:
64
- self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
65
-
66
-
67
- def __getitem__(self, index):
68
- """Return a data point and its metadata information.
69
-
70
- Parameters:
71
- index (int) -- a random integer for data indexing
72
-
73
- Returns a dictionary that contains A, B, A_paths and B_paths
74
- img (tensor) -- an image in the input domain
75
- msk (tensor) -- its corresponding attention mask
76
- lm (tensor) -- its corresponding 3d landmarks
77
- im_paths (str) -- image paths
78
- aug_flag (bool) -- a flag used to tell whether its raw or augmented
79
- """
80
- msk_path = self.msk_paths[index % self.size] # make sure index is within the valid range
81
- img_path = msk_path.replace('mask/', '')
82
- lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
83
-
84
- raw_img = Image.open(img_path).convert('RGB')
85
- raw_msk = Image.open(msk_path).convert('RGB')
86
- raw_lm = np.loadtxt(lm_path).astype(np.float32)
87
-
88
- _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
89
-
90
- aug_flag = self.opt.use_aug and self.opt.isTrain
91
- if aug_flag:
92
- img, lm, msk = self._augmentation(img, lm, self.opt, msk)
93
-
94
- _, H = img.size
95
- M = estimate_norm(lm, H)
96
- transform = get_transform()
97
- img_tensor = transform(img)
98
- msk_tensor = transform(msk)[:1, ...]
99
- lm_tensor = parse_label(lm)
100
- M_tensor = parse_label(M)
101
-
102
-
103
- return {'imgs': img_tensor,
104
- 'lms': lm_tensor,
105
- 'msks': msk_tensor,
106
- 'M': M_tensor,
107
- 'im_paths': img_path,
108
- 'aug_flag': aug_flag,
109
- 'dataset': self.name}
110
-
111
- def _augmentation(self, img, lm, opt, msk=None):
112
- affine, affine_inv, flip = get_affine_mat(opt, img.size)
113
- img = apply_img_affine(img, affine_inv)
114
- lm = apply_lm_affine(lm, affine, flip, img.size)
115
- if msk is not None:
116
- msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
117
- return img, lm, msk
118
-
119
-
120
-
121
-
122
- def __len__(self):
123
- """Return the total number of images in the dataset.
124
- """
125
- return self.size
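The loader above derives the image and landmark paths from each mask path by simple string substitution, so the three files are expected to live side by side under `mask/`, the clip root, and `landmarks/`. A small illustration of that convention (the directory layout is hypothetical):

```python
msk_path = 'vox1/id00017/clip001/mask/000001.png'   # entry from the file list
img_path = msk_path.replace('mask/', '')            # paired RGB frame
lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'

print(img_path)   # vox1/id00017/clip001/000001.png
print(lm_path)    # vox1/id00017/clip001/landmarks/000001.txt
```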
 
SadTalker/src/face3d/data/image_folder.py DELETED
@@ -1,66 +0,0 @@
1
- """A modified image folder class
2
-
3
- We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
- so that this class can load images from both current directory and its subdirectories.
5
- """
6
- import numpy as np
7
- import torch.utils.data as data
8
-
9
- from PIL import Image
10
- import os
11
- import os.path
12
-
13
- IMG_EXTENSIONS = [
14
- '.jpg', '.JPG', '.jpeg', '.JPEG',
15
- '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16
- '.tif', '.TIF', '.tiff', '.TIFF',
17
- ]
18
-
19
-
20
- def is_image_file(filename):
21
- return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22
-
23
-
24
- def make_dataset(dir, max_dataset_size=float("inf")):
25
- images = []
26
- assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
27
-
28
- for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
29
- for fname in fnames:
30
- if is_image_file(fname):
31
- path = os.path.join(root, fname)
32
- images.append(path)
33
- return images[:min(max_dataset_size, len(images))]
34
-
35
-
36
- def default_loader(path):
37
- return Image.open(path).convert('RGB')
38
-
39
-
40
- class ImageFolder(data.Dataset):
41
-
42
- def __init__(self, root, transform=None, return_paths=False,
43
- loader=default_loader):
44
- imgs = make_dataset(root)
45
- if len(imgs) == 0:
46
- raise(RuntimeError("Found 0 images in: " + root + "\n"
47
- "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48
-
49
- self.root = root
50
- self.imgs = imgs
51
- self.transform = transform
52
- self.return_paths = return_paths
53
- self.loader = loader
54
-
55
- def __getitem__(self, index):
56
- path = self.imgs[index]
57
- img = self.loader(path)
58
- if self.transform is not None:
59
- img = self.transform(img)
60
- if self.return_paths:
61
- return img, path
62
- else:
63
- return img
64
-
65
- def __len__(self):
66
- return len(self.imgs)
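`make_dataset` above simply walks the directory tree (following symlinks) and keeps files with known image extensions. Equivalent behaviour using only the standard library, for reference:

```python
import os

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff')

def list_images(root):
    """Recursively collect image paths under `root`, case-insensitively, symlinks followed."""
    images = []
    for dirpath, _, fnames in sorted(os.walk(root, followlinks=True)):
        for fname in sorted(fnames):
            if fname.lower().endswith(IMG_EXTENSIONS):
                images.append(os.path.join(dirpath, fname))
    return images

# list_images('examples/source_image')  # hypothetical folder
```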
 
SadTalker/src/face3d/data/template_dataset.py DELETED
@@ -1,75 +0,0 @@
1
- """Dataset class template
2
-
3
- This module provides a template for users to implement custom datasets.
4
- You can specify '--dataset_mode template' to use this dataset.
5
- The class name should be consistent with both the filename and its dataset_mode option.
6
- The filename should be <dataset_mode>_dataset.py
7
- The class name should be <Dataset_mode>Dataset.py
8
- You need to implement the following functions:
9
- -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10
- -- <__init__>: Initialize this dataset class.
11
- -- <__getitem__>: Return a data point and its metadata information.
12
- -- <__len__>: Return the number of images.
13
- """
14
- from data.base_dataset import BaseDataset, get_transform
15
- # from data.image_folder import make_dataset
16
- # from PIL import Image
17
-
18
-
19
- class TemplateDataset(BaseDataset):
20
- """A template dataset class for you to implement custom datasets."""
21
- @staticmethod
22
- def modify_commandline_options(parser, is_train):
23
- """Add new dataset-specific options, and rewrite default values for existing options.
24
-
25
- Parameters:
26
- parser -- original option parser
27
- is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28
-
29
- Returns:
30
- the modified parser.
31
- """
32
- parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33
- parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34
- return parser
35
-
36
- def __init__(self, opt):
37
- """Initialize this dataset class.
38
-
39
- Parameters:
40
- opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41
-
42
- A few things can be done here.
43
- - save the options (have been done in BaseDataset)
44
- - get image paths and meta information of the dataset.
45
- - define the image transformation.
46
- """
47
- # save the option and dataset root
48
- BaseDataset.__init__(self, opt)
49
- # get the image paths of your dataset;
50
- self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51
- # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
52
- self.transform = get_transform(opt)
53
-
54
- def __getitem__(self, index):
55
- """Return a data point and its metadata information.
56
-
57
- Parameters:
58
- index -- a random integer for data indexing
59
-
60
- Returns:
61
- a dictionary of data with their names. It usually contains the data itself and its metadata information.
62
-
63
- Step 1: get a random image path: e.g., path = self.image_paths[index]
64
- Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65
- Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
66
- Step 4: return a data point as a dictionary.
67
- """
68
- path = 'temp' # needs to be a string
69
- data_A = None # needs to be a tensor
70
- data_B = None # needs to be a tensor
71
- return {'data_A': data_A, 'data_B': data_B, 'path': path}
72
-
73
- def __len__(self):
74
- """Return the total number of images."""
75
- return len(self.image_paths)
 
SadTalker/src/face3d/extract_kp_videos.py DELETED
@@ -1,108 +0,0 @@
1
- import os
2
- import cv2
3
- import time
4
- import glob
5
- import argparse
6
- import face_alignment
7
- import numpy as np
8
- from PIL import Image
9
- from tqdm import tqdm
10
- from itertools import cycle
11
-
12
- from torch.multiprocessing import Pool, Process, set_start_method
13
-
14
- class KeypointExtractor():
15
- def __init__(self, device):
16
- self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
17
- device=device)
18
-
19
- def extract_keypoint(self, images, name=None, info=True):
20
- if isinstance(images, list):
21
- keypoints = []
22
- if info:
23
- i_range = tqdm(images,desc='landmark Det:')
24
- else:
25
- i_range = images
26
-
27
- for image in i_range:
28
- current_kp = self.extract_keypoint(image)
29
- if np.mean(current_kp) == -1 and keypoints:
30
- keypoints.append(keypoints[-1])
31
- else:
32
- keypoints.append(current_kp[None])
33
-
34
- keypoints = np.concatenate(keypoints, 0)
35
- np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
36
- return keypoints
37
- else:
38
- while True:
39
- try:
40
- keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
41
- break
42
- except RuntimeError as e:
43
- if str(e).startswith('CUDA'):
44
- print("Warning: out of memory, sleep for 1s")
45
- time.sleep(1)
46
- else:
47
- print(e)
48
- break
49
- except TypeError:
50
- print('No face detected in this image')
51
- shape = [68, 2]
52
- keypoints = -1. * np.ones(shape)
53
- break
54
- if name is not None:
55
- np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
56
- return keypoints
57
-
58
- def read_video(filename):
59
- frames = []
60
- cap = cv2.VideoCapture(filename)
61
- while cap.isOpened():
62
- ret, frame = cap.read()
63
- if ret:
64
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65
- frame = Image.fromarray(frame)
66
- frames.append(frame)
67
- else:
68
- break
69
- cap.release()
70
- return frames
71
-
72
- def run(data):
73
- filename, opt, device = data
74
- os.environ['CUDA_VISIBLE_DEVICES'] = device
75
- kp_extractor = KeypointExtractor(device='cuda')
76
- images = read_video(filename)
77
- name = filename.split('/')[-2:]
78
- os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
79
- kp_extractor.extract_keypoint(
80
- images,
81
- name=os.path.join(opt.output_dir, name[-2], name[-1])
82
- )
83
-
84
- if __name__ == '__main__':
85
- set_start_method('spawn')
86
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
87
- parser.add_argument('--input_dir', type=str, help='the folder of the input files')
88
- parser.add_argument('--output_dir', type=str, help='the folder of the output files')
89
- parser.add_argument('--device_ids', type=str, default='0,1')
90
- parser.add_argument('--workers', type=int, default=4)
91
-
92
- opt = parser.parse_args()
93
- filenames = list()
94
- VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
95
- VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
96
- extensions = VIDEO_EXTENSIONS
97
-
98
- for ext in extensions:
99
- os.listdir(f'{opt.input_dir}')
100
- print(f'{opt.input_dir}/*.{ext}')
101
- filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
102
- print('Total number of videos:', len(filenames))
103
- pool = Pool(opt.workers)
104
- args_list = cycle([opt])
105
- device_ids = opt.device_ids.split(",")
106
- device_ids = cycle(device_ids)
107
- for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
108
- None
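Typical single-image use of the deleted extractor: open a frame with PIL and ask for its 68 landmarks. A hedged sketch that assumes the module and the pinned `face_alignment` package are still available (the input path is hypothetical):

```python
from PIL import Image
from src.face3d.extract_kp_videos import KeypointExtractor

extractor = KeypointExtractor(device='cpu')              # or 'cuda'
image = Image.open('examples/source_image/art_0.png')    # hypothetical input frame
kp = extractor.extract_keypoint(image, name=None)
print(kp.shape)   # (68, 2), or an array of -1 if no face was found
```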
 
SadTalker/src/face3d/extract_kp_videos_safe.py DELETED
@@ -1,151 +0,0 @@
1
- import os
2
- import cv2
3
- import time
4
- import glob
5
- import argparse
6
- import numpy as np
7
- from PIL import Image
8
- import torch
9
- from tqdm import tqdm
10
- from itertools import cycle
11
- from torch.multiprocessing import Pool, Process, set_start_method
12
-
13
- from facexlib.alignment import landmark_98_to_68
14
- from facexlib.detection import init_detection_model
15
-
16
- from facexlib.utils import load_file_from_url
17
- from src.face3d.util.my_awing_arch import FAN
18
-
19
- def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
20
- if model_name == 'awing_fan':
21
- model = FAN(num_modules=4, num_landmarks=98, device=device)
22
- model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
23
- else:
24
- raise NotImplementedError(f'{model_name} is not implemented.')
25
-
26
- model_path = load_file_from_url(
27
- url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
28
- model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
29
- model.eval()
30
- model = model.to(device)
31
- return model
32
-
33
-
34
- class KeypointExtractor():
35
- def __init__(self, device='cuda'):
36
-
37
- ### gfpgan/weights
38
- try:
39
- import webui # in webui
40
- root_path = 'extensions/SadTalker/gfpgan/weights'
41
-
42
- except:
43
- root_path = 'gfpgan/weights'
44
-
45
- self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)
46
- self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
47
-
48
- def extract_keypoint(self, images, name=None, info=True):
49
- if isinstance(images, list):
50
- keypoints = []
51
- if info:
52
- i_range = tqdm(images,desc='landmark Det:')
53
- else:
54
- i_range = images
55
-
56
- for image in i_range:
57
- current_kp = self.extract_keypoint(image)
58
- # current_kp = self.detector.get_landmarks(np.array(image))
59
- if np.mean(current_kp) == -1 and keypoints:
60
- keypoints.append(keypoints[-1])
61
- else:
62
- keypoints.append(current_kp[None])
63
-
64
- keypoints = np.concatenate(keypoints, 0)
65
- np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
66
- return keypoints
67
- else:
68
- while True:
69
- try:
70
- with torch.no_grad():
71
- # face detection -> face alignment.
72
- img = np.array(images)
73
- bboxes = self.det_net.detect_faces(images, 0.97)
74
-
75
- bboxes = bboxes[0]
76
- img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
77
-
78
- keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
79
-
80
- #### keypoints to the original location
81
- keypoints[:,0] += int(bboxes[0])
82
- keypoints[:,1] += int(bboxes[1])
83
-
84
- break
85
- except RuntimeError as e:
86
- if str(e).startswith('CUDA'):
87
- print("Warning: out of memory, sleep for 1s")
88
- time.sleep(1)
89
- else:
90
- print(e)
91
- break
92
- except TypeError:
93
- print('No face detected in this image')
94
- shape = [68, 2]
95
- keypoints = -1. * np.ones(shape)
96
- break
97
- if name is not None:
98
- np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
99
- return keypoints
100
-
101
- def read_video(filename):
102
- frames = []
103
- cap = cv2.VideoCapture(filename)
104
- while cap.isOpened():
105
- ret, frame = cap.read()
106
- if ret:
107
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
108
- frame = Image.fromarray(frame)
109
- frames.append(frame)
110
- else:
111
- break
112
- cap.release()
113
- return frames
114
-
115
- def run(data):
116
- filename, opt, device = data
117
- os.environ['CUDA_VISIBLE_DEVICES'] = device
118
- kp_extractor = KeypointExtractor()
119
- images = read_video(filename)
120
- name = filename.split('/')[-2:]
121
- os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
122
- kp_extractor.extract_keypoint(
123
- images,
124
- name=os.path.join(opt.output_dir, name[-2], name[-1])
125
- )
126
-
127
- if __name__ == '__main__':
128
- set_start_method('spawn')
129
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
130
- parser.add_argument('--input_dir', type=str, help='the folder of the input files')
131
- parser.add_argument('--output_dir', type=str, help='the folder of the output files')
132
- parser.add_argument('--device_ids', type=str, default='0,1')
133
- parser.add_argument('--workers', type=int, default=4)
134
-
135
- opt = parser.parse_args()
136
- filenames = list()
137
- VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
138
- VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
139
- extensions = VIDEO_EXTENSIONS
140
-
141
- for ext in extensions:
142
- os.listdir(f'{opt.input_dir}')
143
- print(f'{opt.input_dir}/*.{ext}')
144
- filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
145
- print('Total number of videos:', len(filenames))
146
- pool = Pool(opt.workers)
147
- args_list = cycle([opt])
148
- device_ids = opt.device_ids.split(",")
149
- device_ids = cycle(device_ids)
150
- for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
151
- None
 
SadTalker/src/face3d/models/__init__.py DELETED
@@ -1,67 +0,0 @@
1
- """This package contains modules related to objective functions, optimizations, and network architectures.
2
-
3
- To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
- You need to implement the following five functions:
5
- -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
- -- <set_input>: unpack data from dataset and apply preprocessing.
7
- -- <forward>: produce intermediate results.
8
- -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
- -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
-
11
- In the function <__init__>, you need to define four lists:
12
- -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
- -- self.model_names (str list): define networks used in our training.
14
- -- self.visual_names (str list): specify the images that you want to display and save.
15
- -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16
-
17
- Now you can use the model class by specifying flag '--model dummy'.
18
- See our template model class 'template_model.py' for more details.
19
- """
20
-
21
- import importlib
22
- from src.face3d.models.base_model import BaseModel
23
-
24
-
25
- def find_model_using_name(model_name):
26
- """Import the module "models/[model_name]_model.py".
27
-
28
- In the file, the class called DatasetNameModel() will
29
- be instantiated. It has to be a subclass of BaseModel,
30
- and it is case-insensitive.
31
- """
32
- model_filename = "face3d.models." + model_name + "_model"
33
- modellib = importlib.import_module(model_filename)
34
- model = None
35
- target_model_name = model_name.replace('_', '') + 'model'
36
- for name, cls in modellib.__dict__.items():
37
- if name.lower() == target_model_name.lower() \
38
- and issubclass(cls, BaseModel):
39
- model = cls
40
-
41
- if model is None:
42
- print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
- exit(0)
44
-
45
- return model
46
-
47
-
48
- def get_option_setter(model_name):
49
- """Return the static method <modify_commandline_options> of the model class."""
50
- model_class = find_model_using_name(model_name)
51
- return model_class.modify_commandline_options
52
-
53
-
54
- def create_model(opt):
55
- """Create a model given the option.
56
-
57
- This function wraps the class CustomDatasetDataLoader.
58
- This is the main interface between this package and 'train.py'/'test.py'
59
-
60
- Example:
61
- >>> from models import create_model
62
- >>> model = create_model(opt)
63
- """
64
- model = find_model_using_name(opt.model)
65
- instance = model(opt)
66
- print("model [%s] was created" % type(instance).__name__)
67
- return instance
 
SadTalker/src/face3d/models/arcface_torch/README.md DELETED
@@ -1,164 +0,0 @@
1
- # Distributed Arcface Training in Pytorch
2
-
3
- This is a deep learning library that makes face recognition efficient and effective, and it can train tens of millions of
4
- identities on a single server.
5
-
6
- ## Requirements
7
-
8
- - Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
9
- - `pip install -r requirements.txt`.
10
- - Download the dataset
11
- from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_)
12
- .
13
-
14
- ## How to Train
15
-
16
- To train a model, run `train.py` with the path to the configs:
17
-
18
- ### 1. Single node, 8 GPUs:
19
-
20
- ```shell
21
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
22
- ```
23
-
24
- ### 2. Multiple nodes, each node 8 GPUs:
25
-
26
- Node 0:
27
-
28
- ```shell
29
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
30
- ```
31
-
32
- Node 1:
33
-
34
- ```shell
35
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
36
- ```
37
-
38
- ### 3. Training resnet2060 with 8 GPUs:
39
-
40
- ```shell
41
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
42
- ```
43
-
44
- ## Model Zoo
45
-
46
- - The models are available for non-commercial research purposes only.
47
- - All models can be found here.
48
- - [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
49
- - [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
50
-
51
- ### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
52
-
53
- The ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face
54
- recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities.
55
- As a result, we can evaluate the FAIR performance for different algorithms.
56
-
57
- For the **ICCV2021-MFR-ALL** set, TAR is measured on an all-to-all 1:1 protocol, with FAR less than 0.000001(e-6). The
58
- globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
59
-
60
- For the **ICCV2021-MFR-MASK** set, TAR is measured on a mask-to-nonmask 1:1 protocol, with FAR less than 0.0001(e-4).
61
- Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
62
- There are totally 13,928 positive pairs and 96,983,824 negative pairs.
63
-
64
- | Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
65
- | :---: | :--- | :--- | :--- |:--- |:--- |
66
- | MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
67
- | Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
68
- | MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
69
- | Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
70
- | MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
71
- | Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
72
- | MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
73
- | Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
74
- | MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
75
- | Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
76
-
77
- ### Performance on IJB-C and Verification Datasets
78
-
79
- | Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
80
- | :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
81
- | MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
82
- | MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
83
- | MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
84
- | MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
85
- | MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
86
- | Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
87
- | Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
88
- | Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
89
- | Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
90
-
91
- [comment]: <> (More details see [model.md]&#40;docs/modelzoo.md&#41; in docs.)
92
-
93
-
94
- ## [Speed Benchmark](docs/speed_benchmark.md)
95
-
96
- **Arcface Torch** can train large-scale face recognition datasets efficiently and quickly. When the number of
97
- classes in the training set is greater than 300K and training is sufficient, the Partial FC sampling strategy reaches the
98
- same accuracy with several times faster training and a smaller GPU memory footprint.
99
- Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition. Partial FC uses a
100
- sparse softmax in which each batch dynamically samples a subset of class centers for training. In each iteration, only
101
- this sparse part of the parameters is updated, which greatly reduces GPU memory and computation. With Partial FC,
102
- we can scale the training set to 29 million identities, the largest to date. Partial FC also supports multi-machine
103
- distributed training and mixed-precision training.
104
-
105
- ![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
106
-
107
- For more details, see
108
- [speed_benchmark.md](docs/speed_benchmark.md) in docs.
109
-
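
The following is a minimal, hedged sketch of the sampling idea described above, not the actual Partial FC implementation in arcface_torch (which also shards the class centers across GPUs): for each batch, the classification layer only uses the positive class centers plus a random subset of negatives, so its memory and compute scale with the sample rate rather than with the full number of identities. All names below are illustrative.

```python
import torch
import torch.nn.functional as F

def partial_fc_loss(features, labels, class_centers, sample_rate=0.1):
    """Toy Partial FC step: softmax cross-entropy over a sampled subset of centers."""
    num_classes = class_centers.size(0)
    positive = labels.unique()
    num_sample = max(int(sample_rate * num_classes), positive.numel())
    # pad the positives with randomly chosen negative class indices
    mask = torch.ones(num_classes, dtype=torch.bool, device=labels.device)
    mask[positive] = False
    negatives = mask.nonzero(as_tuple=False).squeeze(1)
    perm = negatives[torch.randperm(negatives.numel(), device=labels.device)]
    sampled = torch.cat([positive, perm[: num_sample - positive.numel()]])
    # remap original labels to positions inside the sampled subset
    remap = torch.full((num_classes,), -1, dtype=torch.long, device=labels.device)
    remap[sampled] = torch.arange(sampled.numel(), device=labels.device)
    logits = features @ class_centers[sampled].t()   # (batch, num_sample)
    return F.cross_entropy(logits, remap[labels])
```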
110
- ### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
111
-
112
- `-` means training failed because of GPU memory limitations.
113
-
114
- | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
115
- | :--- | :--- | :--- | :--- |
116
- |125000 | 4681 | 4824 | 5004 |
117
- |1400000 | **1672** | 3043 | 4738 |
118
- |5500000 | **-** | **1389** | 3975 |
119
- |8000000 | **-** | **-** | 3565 |
120
- |16000000 | **-** | **-** | 2679 |
121
- |29000000 | **-** | **-** | **1855** |
122
-
123
- ### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
124
-
125
- | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
126
- | :--- | :--- | :--- | :--- |
127
- |125000 | 7358 | 5306 | 4868 |
128
- |1400000 | 32252 | 11178 | 6056 |
129
- |5500000 | **-** | 32188 | 9854 |
130
- |8000000 | **-** | **-** | 12310 |
131
- |16000000 | **-** | **-** | 19950 |
132
- |29000000 | **-** | **-** | 32324 |
133
-
134
- ## Evaluation on ICCV2021-MFR and IJB-C
135
-
136
- For more details, see [eval.md](docs/eval.md) in docs.
137
-
138
- ## Test
139
-
140
- We tested many versions of PyTorch. Please create an issue if you are having trouble.
141
-
142
- - [x] torch 1.6.0
143
- - [x] torch 1.7.1
144
- - [x] torch 1.8.0
145
- - [x] torch 1.9.0
146
-
147
- ## Citation
148
-
149
- ```
150
- @inproceedings{deng2019arcface,
151
- title={Arcface: Additive angular margin loss for deep face recognition},
152
- author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
153
- booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
154
- pages={4690--4699},
155
- year={2019}
156
- }
157
- @inproceedings{an2020partical_fc,
158
- title={Partial FC: Training 10 Million Identities on a Single Machine},
159
- author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
160
- Zhang, Debing and Fu, Ying},
161
- booktitle={Arxiv 2010.05222},
162
- year={2020}
163
- }
164
- ```
 
SadTalker/src/face3d/models/arcface_torch/backbones/__init__.py DELETED
@@ -1,25 +0,0 @@
1
- from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
2
- from .mobilefacenet import get_mbf
3
-
4
-
5
- def get_model(name, **kwargs):
6
- # resnet
7
- if name == "r18":
8
- return iresnet18(False, **kwargs)
9
- elif name == "r34":
10
- return iresnet34(False, **kwargs)
11
- elif name == "r50":
12
- return iresnet50(False, **kwargs)
13
- elif name == "r100":
14
- return iresnet100(False, **kwargs)
15
- elif name == "r200":
16
- return iresnet200(False, **kwargs)
17
- elif name == "r2060":
18
- from .iresnet2060 import iresnet2060
19
- return iresnet2060(False, **kwargs)
20
- elif name == "mbf":
21
- fp16 = kwargs.get("fp16", False)
22
- num_features = kwargs.get("num_features", 512)
23
- return get_mbf(fp16=fp16, num_features=num_features)
24
- else:
25
- raise ValueError()
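
A hedged usage sketch for the factory above (assuming the package layout makes it importable as `backbones.get_model`, and assuming the standard 112x112 aligned-face crop these backbones expect):

```python
import torch
from backbones import get_model  # assumed import path inside arcface_torch

net = get_model("r50", fp16=False, num_features=512)
net.eval()
with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))   # -> tensor of shape (1, 512)
```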
 
SadTalker/src/face3d/models/arcface_torch/backbones/iresnet.py DELETED
@@ -1,187 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
- __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
5
-
6
-
7
- def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8
- """3x3 convolution with padding"""
9
- return nn.Conv2d(in_planes,
10
- out_planes,
11
- kernel_size=3,
12
- stride=stride,
13
- padding=dilation,
14
- groups=groups,
15
- bias=False,
16
- dilation=dilation)
17
-
18
-
19
- def conv1x1(in_planes, out_planes, stride=1):
20
- """1x1 convolution"""
21
- return nn.Conv2d(in_planes,
22
- out_planes,
23
- kernel_size=1,
24
- stride=stride,
25
- bias=False)
26
-
27
-
28
- class IBasicBlock(nn.Module):
29
- expansion = 1
30
- def __init__(self, inplanes, planes, stride=1, downsample=None,
31
- groups=1, base_width=64, dilation=1):
32
- super(IBasicBlock, self).__init__()
33
- if groups != 1 or base_width != 64:
34
- raise ValueError('BasicBlock only supports groups=1 and base_width=64')
35
- if dilation > 1:
36
- raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
37
- self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
38
- self.conv1 = conv3x3(inplanes, planes)
39
- self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
40
- self.prelu = nn.PReLU(planes)
41
- self.conv2 = conv3x3(planes, planes, stride)
42
- self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
43
- self.downsample = downsample
44
- self.stride = stride
45
-
46
- def forward(self, x):
47
- identity = x
48
- out = self.bn1(x)
49
- out = self.conv1(out)
50
- out = self.bn2(out)
51
- out = self.prelu(out)
52
- out = self.conv2(out)
53
- out = self.bn3(out)
54
- if self.downsample is not None:
55
- identity = self.downsample(x)
56
- out += identity
57
- return out
58
-
59
-
60
- class IResNet(nn.Module):
61
- fc_scale = 7 * 7
62
- def __init__(self,
63
- block, layers, dropout=0, num_features=512, zero_init_residual=False,
64
- groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
65
- super(IResNet, self).__init__()
66
- self.fp16 = fp16
67
- self.inplanes = 64
68
- self.dilation = 1
69
- if replace_stride_with_dilation is None:
70
- replace_stride_with_dilation = [False, False, False]
71
- if len(replace_stride_with_dilation) != 3:
72
- raise ValueError("replace_stride_with_dilation should be None "
73
- "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
74
- self.groups = groups
75
- self.base_width = width_per_group
76
- self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
77
- self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
78
- self.prelu = nn.PReLU(self.inplanes)
79
- self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
80
- self.layer2 = self._make_layer(block,
81
- 128,
82
- layers[1],
83
- stride=2,
84
- dilate=replace_stride_with_dilation[0])
85
- self.layer3 = self._make_layer(block,
86
- 256,
87
- layers[2],
88
- stride=2,
89
- dilate=replace_stride_with_dilation[1])
90
- self.layer4 = self._make_layer(block,
91
- 512,
92
- layers[3],
93
- stride=2,
94
- dilate=replace_stride_with_dilation[2])
95
- self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
96
- self.dropout = nn.Dropout(p=dropout, inplace=True)
97
- self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
98
- self.features = nn.BatchNorm1d(num_features, eps=1e-05)
99
- nn.init.constant_(self.features.weight, 1.0)
100
- self.features.weight.requires_grad = False
101
-
102
- for m in self.modules():
103
- if isinstance(m, nn.Conv2d):
104
- nn.init.normal_(m.weight, 0, 0.1)
105
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
106
- nn.init.constant_(m.weight, 1)
107
- nn.init.constant_(m.bias, 0)
108
-
109
- if zero_init_residual:
110
- for m in self.modules():
111
- if isinstance(m, IBasicBlock):
112
- nn.init.constant_(m.bn2.weight, 0)
113
-
114
- def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
115
- downsample = None
116
- previous_dilation = self.dilation
117
- if dilate:
118
- self.dilation *= stride
119
- stride = 1
120
- if stride != 1 or self.inplanes != planes * block.expansion:
121
- downsample = nn.Sequential(
122
- conv1x1(self.inplanes, planes * block.expansion, stride),
123
- nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
124
- )
125
- layers = []
126
- layers.append(
127
- block(self.inplanes, planes, stride, downsample, self.groups,
128
- self.base_width, previous_dilation))
129
- self.inplanes = planes * block.expansion
130
- for _ in range(1, blocks):
131
- layers.append(
132
- block(self.inplanes,
133
- planes,
134
- groups=self.groups,
135
- base_width=self.base_width,
136
- dilation=self.dilation))
137
-
138
- return nn.Sequential(*layers)
139
-
140
- def forward(self, x):
141
- with torch.cuda.amp.autocast(self.fp16):
142
- x = self.conv1(x)
143
- x = self.bn1(x)
144
- x = self.prelu(x)
145
- x = self.layer1(x)
146
- x = self.layer2(x)
147
- x = self.layer3(x)
148
- x = self.layer4(x)
149
- x = self.bn2(x)
150
- x = torch.flatten(x, 1)
151
- x = self.dropout(x)
152
- x = self.fc(x.float() if self.fp16 else x)
153
- x = self.features(x)
154
- return x
155
-
156
-
157
- def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
158
- model = IResNet(block, layers, **kwargs)
159
- if pretrained:
160
- raise ValueError()
161
- return model
162
-
163
-
164
- def iresnet18(pretrained=False, progress=True, **kwargs):
165
- return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
166
- progress, **kwargs)
167
-
168
-
169
- def iresnet34(pretrained=False, progress=True, **kwargs):
170
- return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
171
- progress, **kwargs)
172
-
173
-
174
- def iresnet50(pretrained=False, progress=True, **kwargs):
175
- return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
176
- progress, **kwargs)
177
-
178
-
179
- def iresnet100(pretrained=False, progress=True, **kwargs):
180
- return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
181
- progress, **kwargs)
182
-
183
-
184
- def iresnet200(pretrained=False, progress=True, **kwargs):
185
- return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
186
- progress, **kwargs)
187
-
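
As a sanity check on the shapes above: with a 112x112 input, `conv1` keeps the resolution (stride 1) and each of the four stages halves it, giving 112 → 56 → 28 → 14 → 7, which is why `fc_scale = 7 * 7` and the final linear layer takes `512 * 49` input features. A minimal sketch:

```python
import torch

net = iresnet50(num_features=512)   # constructor defined above
net.eval()
with torch.no_grad():
    assert net(torch.randn(2, 3, 112, 112)).shape == (2, 512)
```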
 
SadTalker/src/face3d/models/arcface_torch/backbones/iresnet2060.py DELETED
@@ -1,176 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
- assert torch.__version__ >= "1.8.1"
5
- from torch.utils.checkpoint import checkpoint_sequential
6
-
7
- __all__ = ['iresnet2060']
8
-
9
-
10
- def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
11
- """3x3 convolution with padding"""
12
- return nn.Conv2d(in_planes,
13
- out_planes,
14
- kernel_size=3,
15
- stride=stride,
16
- padding=dilation,
17
- groups=groups,
18
- bias=False,
19
- dilation=dilation)
20
-
21
-
22
- def conv1x1(in_planes, out_planes, stride=1):
23
- """1x1 convolution"""
24
- return nn.Conv2d(in_planes,
25
- out_planes,
26
- kernel_size=1,
27
- stride=stride,
28
- bias=False)
29
-
30
-
31
- class IBasicBlock(nn.Module):
32
- expansion = 1
33
-
34
- def __init__(self, inplanes, planes, stride=1, downsample=None,
35
- groups=1, base_width=64, dilation=1):
36
- super(IBasicBlock, self).__init__()
37
- if groups != 1 or base_width != 64:
38
- raise ValueError('BasicBlock only supports groups=1 and base_width=64')
39
- if dilation > 1:
40
- raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
41
- self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
42
- self.conv1 = conv3x3(inplanes, planes)
43
- self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
44
- self.prelu = nn.PReLU(planes)
45
- self.conv2 = conv3x3(planes, planes, stride)
46
- self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
47
- self.downsample = downsample
48
- self.stride = stride
49
-
50
- def forward(self, x):
51
- identity = x
52
- out = self.bn1(x)
53
- out = self.conv1(out)
54
- out = self.bn2(out)
55
- out = self.prelu(out)
56
- out = self.conv2(out)
57
- out = self.bn3(out)
58
- if self.downsample is not None:
59
- identity = self.downsample(x)
60
- out += identity
61
- return out
62
-
63
-
64
- class IResNet(nn.Module):
65
- fc_scale = 7 * 7
66
-
67
- def __init__(self,
68
- block, layers, dropout=0, num_features=512, zero_init_residual=False,
69
- groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
70
- super(IResNet, self).__init__()
71
- self.fp16 = fp16
72
- self.inplanes = 64
73
- self.dilation = 1
74
- if replace_stride_with_dilation is None:
75
- replace_stride_with_dilation = [False, False, False]
76
- if len(replace_stride_with_dilation) != 3:
77
- raise ValueError("replace_stride_with_dilation should be None "
78
- "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
79
- self.groups = groups
80
- self.base_width = width_per_group
81
- self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
82
- self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
83
- self.prelu = nn.PReLU(self.inplanes)
84
- self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
85
- self.layer2 = self._make_layer(block,
86
- 128,
87
- layers[1],
88
- stride=2,
89
- dilate=replace_stride_with_dilation[0])
90
- self.layer3 = self._make_layer(block,
91
- 256,
92
- layers[2],
93
- stride=2,
94
- dilate=replace_stride_with_dilation[1])
95
- self.layer4 = self._make_layer(block,
96
- 512,
97
- layers[3],
98
- stride=2,
99
- dilate=replace_stride_with_dilation[2])
100
- self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
101
- self.dropout = nn.Dropout(p=dropout, inplace=True)
102
- self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
103
- self.features = nn.BatchNorm1d(num_features, eps=1e-05)
104
- nn.init.constant_(self.features.weight, 1.0)
105
- self.features.weight.requires_grad = False
106
-
107
- for m in self.modules():
108
- if isinstance(m, nn.Conv2d):
109
- nn.init.normal_(m.weight, 0, 0.1)
110
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
111
- nn.init.constant_(m.weight, 1)
112
- nn.init.constant_(m.bias, 0)
113
-
114
- if zero_init_residual:
115
- for m in self.modules():
116
- if isinstance(m, IBasicBlock):
117
- nn.init.constant_(m.bn2.weight, 0)
118
-
119
- def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
120
- downsample = None
121
- previous_dilation = self.dilation
122
- if dilate:
123
- self.dilation *= stride
124
- stride = 1
125
- if stride != 1 or self.inplanes != planes * block.expansion:
126
- downsample = nn.Sequential(
127
- conv1x1(self.inplanes, planes * block.expansion, stride),
128
- nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
129
- )
130
- layers = []
131
- layers.append(
132
- block(self.inplanes, planes, stride, downsample, self.groups,
133
- self.base_width, previous_dilation))
134
- self.inplanes = planes * block.expansion
135
- for _ in range(1, blocks):
136
- layers.append(
137
- block(self.inplanes,
138
- planes,
139
- groups=self.groups,
140
- base_width=self.base_width,
141
- dilation=self.dilation))
142
-
143
- return nn.Sequential(*layers)
144
-
145
- def checkpoint(self, func, num_seg, x):
146
- if self.training:
147
- return checkpoint_sequential(func, num_seg, x)
148
- else:
149
- return func(x)
150
-
151
- def forward(self, x):
152
- with torch.cuda.amp.autocast(self.fp16):
153
- x = self.conv1(x)
154
- x = self.bn1(x)
155
- x = self.prelu(x)
156
- x = self.layer1(x)
157
- x = self.checkpoint(self.layer2, 20, x)
158
- x = self.checkpoint(self.layer3, 100, x)
159
- x = self.layer4(x)
160
- x = self.bn2(x)
161
- x = torch.flatten(x, 1)
162
- x = self.dropout(x)
163
- x = self.fc(x.float() if self.fp16 else x)
164
- x = self.features(x)
165
- return x
166
-
167
-
168
- def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
169
- model = IResNet(block, layers, **kwargs)
170
- if pretrained:
171
- raise ValueError()
172
- return model
173
-
174
-
175
- def iresnet2060(pretrained=False, progress=True, **kwargs):
176
- return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
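
The only functional difference from `iresnet.py` is the `checkpoint` wrapper above: during training, the very deep `layer2` and `layer3` stages (128 and 896 blocks in the `[3, 128, 896, 3]` layout) run through `checkpoint_sequential`, so their intermediate activations are recomputed in the backward pass instead of being stored. A hedged, self-contained sketch of that trade-off:

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

stage = torch.nn.Sequential(*[torch.nn.Linear(64, 64) for _ in range(8)])
x = torch.randn(4, 64, requires_grad=True)

# split the stage into 2 segments; only segment boundaries keep activations,
# everything in between is recomputed during backward()
y = checkpoint_sequential(stage, 2, x)
y.sum().backward()
```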
 
SadTalker/src/face3d/models/arcface_torch/backbones/mobilefacenet.py DELETED
@@ -1,130 +0,0 @@
1
- '''
2
- Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
3
- Original author cavalleria
4
- '''
5
-
6
- import torch.nn as nn
7
- from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
8
- import torch
9
-
10
-
11
- class Flatten(Module):
12
- def forward(self, x):
13
- return x.view(x.size(0), -1)
14
-
15
-
16
- class ConvBlock(Module):
17
- def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
18
- super(ConvBlock, self).__init__()
19
- self.layers = nn.Sequential(
20
- Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
21
- BatchNorm2d(num_features=out_c),
22
- PReLU(num_parameters=out_c)
23
- )
24
-
25
- def forward(self, x):
26
- return self.layers(x)
27
-
28
-
29
- class LinearBlock(Module):
30
- def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
31
- super(LinearBlock, self).__init__()
32
- self.layers = nn.Sequential(
33
- Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
34
- BatchNorm2d(num_features=out_c)
35
- )
36
-
37
- def forward(self, x):
38
- return self.layers(x)
39
-
40
-
41
- class DepthWise(Module):
42
- def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
43
- super(DepthWise, self).__init__()
44
- self.residual = residual
45
- self.layers = nn.Sequential(
46
- ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
47
- ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
48
- LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
49
- )
50
-
51
- def forward(self, x):
52
- short_cut = None
53
- if self.residual:
54
- short_cut = x
55
- x = self.layers(x)
56
- if self.residual:
57
- output = short_cut + x
58
- else:
59
- output = x
60
- return output
61
-
62
-
63
- class Residual(Module):
64
- def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
65
- super(Residual, self).__init__()
66
- modules = []
67
- for _ in range(num_block):
68
- modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
69
- self.layers = Sequential(*modules)
70
-
71
- def forward(self, x):
72
- return self.layers(x)
73
-
74
-
75
- class GDC(Module):
76
- def __init__(self, embedding_size):
77
- super(GDC, self).__init__()
78
- self.layers = nn.Sequential(
79
- LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
80
- Flatten(),
81
- Linear(512, embedding_size, bias=False),
82
- BatchNorm1d(embedding_size))
83
-
84
- def forward(self, x):
85
- return self.layers(x)
86
-
87
-
88
- class MobileFaceNet(Module):
89
- def __init__(self, fp16=False, num_features=512):
90
- super(MobileFaceNet, self).__init__()
91
- scale = 2
92
- self.fp16 = fp16
93
- self.layers = nn.Sequential(
94
- ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
95
- ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
96
- DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
97
- Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
98
- DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
99
- Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
100
- DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
101
- Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
102
- )
103
- self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
104
- self.features = GDC(num_features)
105
- self._initialize_weights()
106
-
107
- def _initialize_weights(self):
108
- for m in self.modules():
109
- if isinstance(m, nn.Conv2d):
110
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
111
- if m.bias is not None:
112
- m.bias.data.zero_()
113
- elif isinstance(m, nn.BatchNorm2d):
114
- m.weight.data.fill_(1)
115
- m.bias.data.zero_()
116
- elif isinstance(m, nn.Linear):
117
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
118
- if m.bias is not None:
119
- m.bias.data.zero_()
120
-
121
- def forward(self, x):
122
- with torch.cuda.amp.autocast(self.fp16):
123
- x = self.layers(x)
124
- x = self.conv_sep(x.float() if self.fp16 else x)
125
- x = self.features(x)
126
- return x
127
-
128
-
129
- def get_mbf(fp16, num_features):
130
- return MobileFaceNet(fp16, num_features)
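
A brief usage sketch for the factory above (again assuming the standard 112x112 crop; after the four stride-2 stages the feature map is 7x7, which is exactly what the GDC head's 7x7 depth-wise kernel reduces to 1x1 before the final linear projection):

```python
import torch

net = get_mbf(fp16=False, num_features=512)   # factory defined above
net.eval()
with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))    # -> (1, 512)
```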
 
SadTalker/src/face3d/models/arcface_torch/configs/3millions.py DELETED
@@ -1,23 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # configs for test speed
4
-
5
- config = edict()
6
- config.loss = "arcface"
7
- config.network = "r50"
8
- config.resume = False
9
- config.output = None
10
- config.embedding_size = 512
11
- config.sample_rate = 1.0
12
- config.fp16 = True
13
- config.momentum = 0.9
14
- config.weight_decay = 5e-4
15
- config.batch_size = 128
16
- config.lr = 0.1 # batch size is 512
17
-
18
- config.rec = "synthetic"
19
- config.num_classes = 300 * 10000
20
- config.num_epoch = 30
21
- config.warmup_epoch = -1
22
- config.decay_epoch = [10, 16, 22]
23
- config.val_targets = []
 
SadTalker/src/face3d/models/arcface_torch/configs/3millions_pfc.py DELETED
@@ -1,23 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # configs for test speed
4
-
5
- config = edict()
6
- config.loss = "arcface"
7
- config.network = "r50"
8
- config.resume = False
9
- config.output = None
10
- config.embedding_size = 512
11
- config.sample_rate = 0.1
12
- config.fp16 = True
13
- config.momentum = 0.9
14
- config.weight_decay = 5e-4
15
- config.batch_size = 128
16
- config.lr = 0.1 # batch size is 512
17
-
18
- config.rec = "synthetic"
19
- config.num_classes = 300 * 10000
20
- config.num_epoch = 30
21
- config.warmup_epoch = -1
22
- config.decay_epoch = [10, 16, 22]
23
- config.val_targets = []
 
SadTalker/src/face3d/models/arcface_torch/configs/__init__.py DELETED
File without changes
SadTalker/src/face3d/models/arcface_torch/configs/base.py DELETED
@@ -1,56 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # make training faster
4
- # our RAM is 256G
5
- # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
-
7
- config = edict()
8
- config.loss = "arcface"
9
- config.network = "r50"
10
- config.resume = False
11
- config.output = "ms1mv3_arcface_r50"
12
-
13
- config.dataset = "ms1m-retinaface-t1"
14
- config.embedding_size = 512
15
- config.sample_rate = 1
16
- config.fp16 = False
17
- config.momentum = 0.9
18
- config.weight_decay = 5e-4
19
- config.batch_size = 128
20
- config.lr = 0.1 # batch size is 512
21
-
22
- if config.dataset == "emore":
23
- config.rec = "/train_tmp/faces_emore"
24
- config.num_classes = 85742
25
- config.num_image = 5822653
26
- config.num_epoch = 16
27
- config.warmup_epoch = -1
28
- config.decay_epoch = [8, 14, ]
29
- config.val_targets = ["lfw", ]
30
-
31
- elif config.dataset == "ms1m-retinaface-t1":
32
- config.rec = "/train_tmp/ms1m-retinaface-t1"
33
- config.num_classes = 93431
34
- config.num_image = 5179510
35
- config.num_epoch = 25
36
- config.warmup_epoch = -1
37
- config.decay_epoch = [11, 17, 22]
38
- config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
39
-
40
- elif config.dataset == "glint360k":
41
- config.rec = "/train_tmp/glint360k"
42
- config.num_classes = 360232
43
- config.num_image = 17091657
44
- config.num_epoch = 20
45
- config.warmup_epoch = -1
46
- config.decay_epoch = [8, 12, 15, 18]
47
- config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
48
-
49
- elif config.dataset == "webface":
50
- config.rec = "/train_tmp/faces_webface_112x112"
51
- config.num_classes = 10572
52
- config.num_image = "forget"
53
- config.num_epoch = 34
54
- config.warmup_epoch = -1
55
- config.decay_epoch = [20, 28, 32]
56
- config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
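
These configs are plain `EasyDict` objects, so a training script can import the module matching its command-line argument and read attributes off `config`. A hedged sketch of how one might be consumed (the module path and the 8-GPU world size are assumptions; the `# batch size is 512` comment above means the base learning rate is referenced to a global batch of 512 and is commonly scaled linearly for other sizes):

```python
from importlib import import_module

cfg = import_module("configs.base").config      # assumed run from the arcface_torch dir
world_size, reference_batch = 8, 512             # assumption: 8 GPUs
scaled_lr = cfg.lr * cfg.batch_size * world_size / reference_batch
print(cfg.network, cfg.num_classes, cfg.batch_size, scaled_lr)
```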
 
SadTalker/src/face3d/models/arcface_torch/configs/glint360k_mbf.py DELETED
@@ -1,26 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # make training faster
4
- # our RAM is 256G
5
- # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
-
7
- config = edict()
8
- config.loss = "cosface"
9
- config.network = "mbf"
10
- config.resume = False
11
- config.output = None
12
- config.embedding_size = 512
13
- config.sample_rate = 0.1
14
- config.fp16 = True
15
- config.momentum = 0.9
16
- config.weight_decay = 2e-4
17
- config.batch_size = 128
18
- config.lr = 0.1 # batch size is 512
19
-
20
- config.rec = "/train_tmp/glint360k"
21
- config.num_classes = 360232
22
- config.num_image = 17091657
23
- config.num_epoch = 20
24
- config.warmup_epoch = -1
25
- config.decay_epoch = [8, 12, 15, 18]
26
- config.val_targets = ["lfw", "cfp_fp", "agedb_30"]