diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..05dc67685b581fd71ce44b63652aa320666d02ad --- /dev/null +++ b/.gitignore @@ -0,0 +1,43 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*__pycache__ +**/__pycache__/ +*.py[cod] +**/*.py[cod] +*$py.class + +# Model weights +checkpoints +**/*.pth +**/*.onnx +**/*.pt +**/*.pth.tar + +.idea +.vscode +.DS_Store +*.DS_Store + +*.swp +tmp* + +*build +*.egg-info/ +*.mp4 + +log/* +*.mp4 +*.png +*.jpg +*.wav +*.pth +*.pyc +*.jpeg + +# Folders to ignore +example/ +ToDo/ + +!example/audio.wav +!example/image.png + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README_ditto-talkinghead.md b/README_ditto-talkinghead.md new file mode 100644 index 0000000000000000000000000000000000000000..a078e6a14afc9084d34976f48d796a58cce64d22 --- /dev/null +++ b/README_ditto-talkinghead.md @@ -0,0 +1,232 @@ +

Ditto: Motion-Space Diffusion for Controllable Realtime Talking Head Synthesis

+ +
+ Tianqi Li + · + Ruobing Zheng + · + Minghui Yang + · + Jingdong Chen + · + Ming Yang +
+
+Ant Group +
+
+
+ + + + + + +
+
+
+ +

+ ✨ For more results, visit our Project Page ✨ +

+
+ + +## 📌 Updates +* [2025.07.11] 🔥 The [PyTorch model](#-pytorch-model) is now available. +* [2025.07.07] 🔥 Ditto is accepted by ACM MM 2025. +* [2025.01.21] 🔥 We update the [Colab](https://colab.research.google.com/drive/19SUi1TiO32IS-Crmsu9wrkNspWE8tFbs?usp=sharing) demo, welcome to try it. +* [2025.01.10] 🔥 We release our inference [codes](https://github.com/antgroup/ditto-talkinghead) and [models](https://huggingface.co/digital-avatar/ditto-talkinghead). +* [2024.11.29] 🔥 Our [paper](https://arxiv.org/abs/2411.19509) is in public on arxiv. + + + +## 🛠️ Installation + +Tested Environment +- System: Centos 7.2 +- GPU: A100 +- Python: 3.10 +- tensorRT: 8.6.1 + + +Clone the codes from [GitHub](https://github.com/antgroup/ditto-talkinghead): +```bash +git clone https://github.com/antgroup/ditto-talkinghead +cd ditto-talkinghead +``` + +### Conda +Create `conda` environment: +```bash +conda env create -f environment.yaml +conda activate ditto +``` + +### Pip +If you have problems creating a conda environment, you can also refer to our [Colab](https://colab.research.google.com/drive/19SUi1TiO32IS-Crmsu9wrkNspWE8tFbs?usp=sharing). +After correctly installing `pytorch`, `cuda` and `cudnn`, you only need to install a few packages using pip: +```bash +pip install \ + tensorrt==8.6.1 \ + librosa \ + tqdm \ + filetype \ + imageio \ + opencv_python_headless \ + scikit-image \ + cython \ + cuda-python \ + imageio-ffmpeg \ + colored \ + polygraphy \ + numpy==2.0.1 +``` + +If you don't use `conda`, you may also need to install `ffmpeg` according to the [official website](https://www.ffmpeg.org/download.html). 
+ + +## 📥 Download Checkpoints + +Download checkpoints from [HuggingFace](https://huggingface.co/digital-avatar/ditto-talkinghead) and put them in `checkpoints` dir: +```bash +git lfs install +git clone https://huggingface.co/digital-avatar/ditto-talkinghead checkpoints +``` + +The `checkpoints` should be like: +```text +./checkpoints/ +├── ditto_cfg +│   ├── v0.4_hubert_cfg_trt.pkl +│   └── v0.4_hubert_cfg_trt_online.pkl +├── ditto_onnx +│   ├── appearance_extractor.onnx +│   ├── blaze_face.onnx +│   ├── decoder.onnx +│   ├── face_mesh.onnx +│   ├── hubert.onnx +│   ├── insightface_det.onnx +│   ├── landmark106.onnx +│   ├── landmark203.onnx +│   ├── libgrid_sample_3d_plugin.so +│   ├── lmdm_v0.4_hubert.onnx +│   ├── motion_extractor.onnx +│   ├── stitch_network.onnx +│   └── warp_network.onnx +└── ditto_trt_Ampere_Plus + ├── appearance_extractor_fp16.engine + ├── blaze_face_fp16.engine + ├── decoder_fp16.engine + ├── face_mesh_fp16.engine + ├── hubert_fp32.engine + ├── insightface_det_fp16.engine + ├── landmark106_fp16.engine + ├── landmark203_fp16.engine + ├── lmdm_v0.4_hubert_fp32.engine + ├── motion_extractor_fp32.engine + ├── stitch_network_fp16.engine + └── warp_network_fp16.engine +``` + +- The `ditto_cfg/v0.4_hubert_cfg_trt_online.pkl` is online config +- The `ditto_cfg/v0.4_hubert_cfg_trt.pkl` is offline config + + +## 🚀 Inference + +Run `inference.py`: + +```shell +python inference.py \ + --data_root "" \ + --cfg_pkl "" \ + --audio_path "" \ + --source_path "" \ + --output_path "" +``` + +For example: + +```shell +python inference.py \ + --data_root "./checkpoints/ditto_trt_Ampere_Plus" \ + --cfg_pkl "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl" \ + --audio_path "./example/audio.wav" \ + --source_path "./example/image.png" \ + --output_path "./tmp/result.mp4" +``` + +❗Note: + +We have provided the tensorRT model with `hardware-compatibility-level=Ampere_Plus` (`checkpoints/ditto_trt_Ampere_Plus/`). 
If your GPU does not support it, please execute the `cvt_onnx_to_trt.py` script to convert from the general onnx model (`checkpoints/ditto_onnx/`) to the tensorRT model. + +```bash +python scripts/cvt_onnx_to_trt.py --onnx_dir "./checkpoints/ditto_onnx" --trt_dir "./checkpoints/ditto_trt_custom" +``` + +Then run `inference.py` with `--data_root=./checkpoints/ditto_trt_custom`. + + +## ⚡ PyTorch Model +*Based on community interest and to better support further development, we are now open-sourcing the PyTorch version of the model.* + + +We have added the PyTorch model and corresponding configuration files to the [HuggingFace](https://huggingface.co/digital-avatar/ditto-talkinghead). Please refer to [Download Checkpoints](#-download-checkpoints) to prepare the model files. + +The `checkpoints` should be like: +```text +./checkpoints/ +├── ditto_cfg +│   ├── ... +│   └── v0.4_hubert_cfg_pytorch.pkl +├── ... +└── ditto_pytorch + ├── aux_models + │ ├── 2d106det.onnx + │ ├── det_10g.onnx + │ ├── face_landmarker.task + │ ├── hubert_streaming_fix_kv.onnx + │ └── landmark203.onnx + └── models + ├── appearance_extractor.pth + ├── decoder.pth + ├── lmdm_v0.4_hubert.pth + ├── motion_extractor.pth + ├── stitch_network.pth + └── warp_network.pth +``` + +To run inference, execute the following command: + +```shell +python inference.py \ + --data_root "./checkpoints/ditto_pytorch" \ + --cfg_pkl "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" \ + --audio_path "./example/audio.wav" \ + --source_path "./example/image.png" \ + --output_path "./tmp/result.mp4" +``` + + +## 📧 Acknowledgement +Our implementation is based on [S2G-MDDiffusion](https://github.com/thuhcsi/S2G-MDDiffusion) and [LivePortrait](https://github.com/KwaiVGI/LivePortrait). Thanks for their remarkable contribution and released code! If we missed any open-source projects or related articles, we would like to complement the acknowledgement of this specific work immediately. 
def _cvt_LP_motion_info(inp, mode, ignore_keys=()):
    """Convert motion info between its dict form and its flat-array form.

    Args:
        inp: For ``mode='dic2arr'`` a dict mapping motion keys to numpy
            arrays; for ``mode='arr2dic'`` a 1-D numpy array holding at
            least 265 values.
        mode: Either ``'dic2arr'`` or ``'arr2dic'``.
        ignore_keys: Keys that are skipped in both directions.

    Returns:
        ``dic2arr``: a 1-D array with the selected entries concatenated in
        layout order. ``arr2dic``: a dict whose values have the dict-side
        shapes listed in ``ks_shape_map``.

    Raises:
        TypeError: ``mode='dic2arr'`` and ``inp`` is not a dict.
        ValueError: ``mode='arr2dic'`` and ``inp`` is too short, or ``mode``
            is not one of the two supported values.
    """
    # (key, dict-side shape, number of values in the flat layout), in layout order.
    ks_shape_map = [
        ['scale', (1, 1), 1],
        ['pitch', (1, 66), 66],
        ['yaw', (1, 66), 66],
        ['roll', (1, 66), 66],
        ['t', (1, 3), 3],
        ['exp', (1, 63), 63],
        ['kp', (1, 63), 63],
    ]
    # 1 + 66 * 3 + 3 + 63: every entry of the layout except the trailing 'kp'.
    min_arr_len = 265

    def _dic2arr(_dic):
        parts = []
        for key, _, size in ks_shape_map:
            if key not in _dic or key in ignore_keys:
                continue
            v = _dic[key].reshape(size)
            if key == 'scale':
                # 'scale' is stored in the flat array as an offset from 1.
                v = v - 1
            parts.append(v)
        return np.concatenate(parts, -1)

    def _arr2dic(_arr):
        dic = {}
        s = 0
        for key, shape, size in ks_shape_map:
            if key in ignore_keys:
                continue
            v = _arr[s:s + size].reshape(shape)
            if key == 'scale':
                v = v + 1  # undo the offset applied by _dic2arr
            dic[key] = v
            s += size
            # Stop once the array is exhausted (e.g. no 'kp' part present).
            if s >= len(_arr):
                break
        return dic

    if mode == 'dic2arr':
        if not isinstance(inp, dict):
            raise TypeError(f"dic2arr expects a dict, got {type(inp).__name__}")
        return _dic2arr(inp)  # (dim)
    if mode == 'arr2dic':
        if inp.shape[0] < min_arr_len:
            raise ValueError(
                f"arr2dic expects at least {min_arr_len} values, got shape {inp.shape}"
            )
        return _arr2dic(inp)  # {k: (1, dim)}
    raise ValueError(f"Unsupported mode: {mode!r} (expected 'dic2arr' or 'arr2dic')")
def _smo(self, res_kp_seq, s, e):
    """Mean-smooth frames ``s .. e-1`` of ``res_kp_seq`` along the time axis.

    Each frame in the range is replaced, in place, by the average of a
    centered window of ``self.smo_k_d`` frames taken from an unsmoothed
    snapshot of the sequence. Only the first 202 feature channels are
    touched; the window is truncated at the sequence boundaries.

    Args:
        res_kp_seq: Array indexed as ``[batch, frame, feature]``.
        s: First frame index to smooth (clamped to ``0``).
        e: One past the last frame index to smooth (clamped to the number
            of frames).

    Returns:
        ``res_kp_seq`` itself (modified in place), or the untouched input
        when ``self.smo_k_d <= 1``.
    """
    if self.smo_k_d <= 1:
        return res_kp_seq

    n = res_kp_seq.shape[1]
    # Clamp the range to valid frame indices. Without this a negative index
    # averages an empty window (NaN) and writes it to a wrapped-around frame,
    # and an index >= n raises IndexError.
    s = max(0, s)
    e = min(n, e)

    # Read from a snapshot so already-smoothed frames do not feed later windows.
    snapshot = res_kp_seq.copy()
    half_k = self.smo_k_d // 2
    for i in range(s, e):
        lo = max(0, i - half_k)
        hi = min(n, i + half_k + 1)
        res_kp_seq[:, i, :202] = np.mean(snapshot[:, lo:hi, :202], axis=1)
    return res_kp_seq
def _mean_filter(arr, k):
    """Return a centered moving average of width ``k`` along axis 0.

    The window is truncated at both ends of ``arr``, so the output has the
    same length as the input.
    """
    n = arr.shape[0]
    half_k = k // 2
    return np.stack(
        [arr[max(0, i - half_k):min(n, i + half_k + 1)].mean(0) for i in range(n)],
        0,
    )


def smooth_x_s_info_lst(x_s_info_list, ignore_keys=(), smo_k=13):
    """Temporally smooth a list of per-frame info dicts.

    Args:
        x_s_info_list: List of dicts, one per frame; the key set is taken
            from the first element and the values of the smoothed keys must
            be stackable numpy arrays.
        ignore_keys: Keys whose per-frame values are passed through as-is.
        smo_k: Width of the moving-average window, in frames.

    Returns:
        A new list of dicts with the same length and keys as the input.
        An empty input yields an empty list.
    """
    # An empty list used to fail with IndexError on x_s_info_list[0].
    if len(x_s_info_list) == 0:
        return []

    keys = list(x_s_info_list[0].keys())
    smo_dict = {}
    for k in keys:
        values = [info[k] for info in x_s_info_list]
        if k in ignore_keys:
            smo_dict[k] = values
        else:
            smo_dict[k] = _mean_filter(np.stack(values, 0), smo_k)

    return [{k: smo_dict[k][i] for k in keys} for i in range(len(x_s_info_list))]
max_dim=max_dim, n_frames=n_frames) + source_info = { + "x_s_info_lst": [], + "f_s_lst": [], + "M_c2o_lst": [], + "eye_open_lst": [], + "eye_ball_lst": [], + } + keys = ["x_s_info", "f_s", "M_c2o", "eye_open", "eye_ball"] + last_lmk = None + for rgb in rgb_list: + info = self.source2info(rgb, last_lmk, **kwargs) + for k in keys: + source_info[f"{k}_lst"].append(info[k]) + + last_lmk = info["lmk203"] + + sc_f0 = source_info['x_s_info_lst'][0]['kp'].flatten() + + source_info["sc"] = sc_f0 + source_info["is_image_flag"] = is_image_flag + source_info["img_rgb_lst"] = rgb_list + + return source_info + + def __call__(self, *args, **kwargs): + return self.register(*args, **kwargs) + \ No newline at end of file diff --git a/core/atomic_components/cfg.py b/core/atomic_components/cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..28e28ec9e53c3bef4f7de601771be92093a9bffc --- /dev/null +++ b/core/atomic_components/cfg.py @@ -0,0 +1,111 @@ +import os +import pickle +import numpy as np + + +def load_pkl(pkl): + with open(pkl, "rb") as f: + return pickle.load(f) + + +def parse_cfg(cfg_pkl, data_root, replace_cfg=None): + + def _check_path(p): + if os.path.isfile(p): + return p + else: + return os.path.join(data_root, p) + + cfg = load_pkl(cfg_pkl) + + # --- + # replace cfg for debug + if isinstance(replace_cfg, dict): + for k, v in replace_cfg.items(): + if not isinstance(v, dict): + continue + for kk, vv in v.items(): + cfg[k][kk] = vv + # --- + + base_cfg = cfg["base_cfg"] + audio2motion_cfg = cfg["audio2motion_cfg"] + default_kwargs = cfg["default_kwargs"] + + for k in base_cfg: + if k == "landmark478_cfg": + for kk in ["task_path", "blaze_face_model_path", "face_mesh_model_path"]: + if kk in base_cfg[k] and base_cfg[k][kk]: + base_cfg[k][kk] = _check_path(base_cfg[k][kk]) + else: + base_cfg[k]["model_path"] = _check_path(base_cfg[k]["model_path"]) + + audio2motion_cfg["model_path"] = _check_path(audio2motion_cfg["model_path"]) + + avatar_registrar_cfg 
def _get_emo_avg(idx=6):
    """Build a soft 8-class emotion distribution peaked at the given label(s).

    A zero vector of length 8 gets the value 8 at every requested label and
    is then passed through ``softmax``.

    Args:
        idx: A single label or a list/tuple of labels. Label order is
            'Angry', 'Disgust', 'Fear', 'Happy', 'Neutral', 'Sad',
            'Surprise', 'Contempt'.

    Returns:
        A float32 array of shape ``(8,)`` that sums to 1.

    Raises:
        IndexError: An integer label lies outside ``[0, 8)``.
    """
    labels = idx if isinstance(idx, (list, tuple)) else (idx,)
    emo_avg = np.zeros(8, dtype=np.float32)
    for label in labels:
        # Negative integers used to wrap around silently and select a
        # different emotion; reject them like any other out-of-range label.
        if isinstance(label, (int, np.integer)) and not 0 <= label < 8:
            raise IndexError(f"emotion label {label} is out of range [0, 8)")
        emo_avg[label] = 8
    return softmax(emo_avg)
@staticmethod
def _parse_emo_seq(emo, seq_len=-1):
    """Normalize an emotion specification into an ``[m, 8]`` array.

    Accepted forms of ``emo``:
        * a 2-D numpy array with 8 columns, used as-is;
        * a single integer label in ``[0, 8)`` (Python or numpy integer);
        * a list/tuple of 1 to 7 integer labels, merged into one row;
        * a list/tuple of label groups (each a list/tuple), one row per group.

    Args:
        emo: The emotion specification, see above.
        seq_len: When positive, the result is fitted to this many rows: a
            single row is repeated, a longer sequence is truncated.

    Returns:
        A numpy array of shape ``[m, 8]``, or ``[seq_len, 8]`` when
        ``seq_len > 0``.

    Raises:
        ValueError: ``emo`` has an unsupported form, or it has more than one
            row but fewer than ``seq_len`` rows.
    """
    # numpy integer scalars (e.g. the result of argmax) are valid labels too.
    int_types = (int, np.integer)

    if isinstance(emo, np.ndarray) and emo.ndim == 2 and emo.shape[1] == 8:
        # emo arr, e.g. real
        emo_seq = emo  # [m, 8]
    elif isinstance(emo, int_types) and 0 <= emo < 8:
        # emo label, e.g. 4
        emo_seq = _get_emo_avg(emo).reshape(1, 8)  # [1, 8]
    elif isinstance(emo, (list, tuple)) and 0 < len(emo) < 8 and isinstance(emo[0], int_types):
        # emo labels, e.g. [3,4]
        emo_seq = _get_emo_avg(emo).reshape(1, 8)  # [1, 8]
    elif isinstance(emo, (list, tuple)) and emo and isinstance(emo[0], (list, tuple)):
        # emo label list, e.g. [[4], [3,4], [3],[3,4,5], ...]
        # Tuples are accepted here as well, matching the flat-label branch.
        emo_seq = np.stack([_get_emo_avg(i) for i in emo], 0)  # [m, 8]
    else:
        raise ValueError(f"Unsupported emo type: {emo}")

    if seq_len <= 0:
        return emo_seq

    if len(emo_seq) == seq_len:
        return emo_seq
    if len(emo_seq) == 1:
        return np.concatenate([emo_seq] * seq_len, 0)
    if len(emo_seq) > seq_len:
        return emo_seq[:seq_len]
    raise ValueError(f"emo len {len(emo_seq)} can not match seq len ({seq_len})")
frame_num)] + eye_ball_seq = self.eye_ball_lst[eye_idx_list] + more_cond.append(eye_ball_seq) + + if self.use_sc: + if len(self.sc_seq) == frame_num: + sc_seq = self.sc_seq + else: + sc_seq = np.stack([self.sc] * frame_num, 0) + more_cond.append(sc_seq) + + if len(more_cond) > 1: + cond_seq = np.concatenate(more_cond, -1) # [n, dim_cond] + else: + cond_seq = aud_feat + + return cond_seq diff --git a/core/atomic_components/decode_f3d.py b/core/atomic_components/decode_f3d.py new file mode 100644 index 0000000000000000000000000000000000000000..96b85d2ac855ccb4b80804416e1c4e9c818789b2 --- /dev/null +++ b/core/atomic_components/decode_f3d.py @@ -0,0 +1,22 @@ +from ..models.decoder import Decoder + + +""" +# __init__ +decoder_cfg = { + "model_path": "", + "device": "cuda", +} +""" + +class DecodeF3D: + def __init__( + self, + decoder_cfg, + ): + self.decoder = Decoder(**decoder_cfg) + + def __call__(self, f_s): + out = self.decoder(f_s) + return out + \ No newline at end of file diff --git a/core/atomic_components/loader.py b/core/atomic_components/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..fb46cec89081c04e641b6ab0f87afa6589c3e340 --- /dev/null +++ b/core/atomic_components/loader.py @@ -0,0 +1,133 @@ +import filetype +import imageio +import cv2 + + +def is_image(file_path): + return filetype.is_image(file_path) + + +def is_video(file_path): + return filetype.is_video(file_path) + + +def check_resize(h, w, max_dim=1920, division=2): + rsz_flag = False + # ajust the size of the image according to the maximum dimension + if max_dim > 0 and max(h, w) > max_dim: + rsz_flag = True + if h > w: + new_h = max_dim + new_w = int(round(w * max_dim / h)) + else: + new_w = max_dim + new_h = int(round(h * max_dim / w)) + else: + new_h = h + new_w = w + + # ensure that the image dimensions are multiples of n + if new_h % division != 0: + new_h = new_h - (new_h % division) + rsz_flag = True + if new_w % division != 0: + new_w = new_w - (new_w % 
class LoopLoader:
    """
    Iterator that cycles over a fixed list of items for a bounded number of steps.

    item_list: indexable sequence of items to serve
    max_iter_num: number of items to yield in total; <=0 means one pass
        (len(item_list))
    mirror_loop: if True, indices are mapped through _mirror_index so the
        sequence plays forward, then backward, and so on; if False the
        sequence simply wraps around
    """

    def __init__(self, item_list, max_iter_num=-1, mirror_loop=True):
        self.item_list = item_list
        self.idx = 0
        self.item_num = len(self.item_list)
        self.max_iter_num = max_iter_num if max_iter_num > 0 else self.item_num
        self.mirror_loop = mirror_loop

    def __len__(self):
        return self.max_iter_num

    def __iter__(self):
        return self

    def __next__(self):
        if self.idx >= self.max_iter_num:
            raise StopIteration

        if self.mirror_loop:
            idx = _mirror_index(self.idx, self.item_num)
        else:
            idx = self.idx % self.item_num
        item = self.item_list[idx]

        self.idx += 1
        return item

    def __call__(self):
        return self.__iter__()

    def reset(self, max_iter_num=-1):
        """Rewind to the first item and optionally change the iteration budget."""
        # The cursor read by __next__ is `idx`; resetting any other attribute
        # leaves the loader exhausted after its first pass.
        self.idx = 0
        self.max_iter_num = max_iter_num if max_iter_num > 0 else self.item_num
(v - d0[k]) * use_d_keys.get(k, 1) + for k, v in x_d_info.items() + } + else: + x_d_info = {k: x_s_info[k] + (v - d0[k]) for k, v in x_d_info.items()} + + for k, v in x_s_info.items(): + if k not in x_d_info or k not in use_d_keys: + x_d_info[k] = v + + if isinstance(use_d_keys, dict) and d0 is None: + for k, alpha in use_d_keys.items(): + x_d_info[k] *= alpha + return x_d_info + + +def _set_eye_blink_idx(N, blink_n=15, open_n=-1): + """ + open_n: + -1: no blink + 0: random open_n + >0: fix open_n + list: loop open_n + """ + OPEN_MIN = 60 + OPEN_MAX = 100 + + idx = [0] * N + if isinstance(open_n, int): + if open_n < 0: # no blink + return idx + elif open_n > 0: # fix open_n + open_ns = [open_n] + else: # open_n == 0: # random open_n, 60-100 + open_ns = [] + elif isinstance(open_n, list): + open_ns = open_n # loop open_n + else: + raise ValueError() + + blink_idx = list(range(blink_n)) + + start_n = open_ns[0] if open_ns else random.randint(OPEN_MIN, OPEN_MAX) + end_n = open_ns[-1] if open_ns else random.randint(OPEN_MIN, OPEN_MAX) + max_i = N - max(end_n, blink_n) + cur_i = start_n + cur_n_i = 1 + while cur_i < max_i: + idx[cur_i : cur_i + blink_n] = blink_idx + + if open_ns: + cur_n = open_ns[cur_n_i % len(open_ns)] + cur_n_i += 1 + else: + cur_n = random.randint(OPEN_MIN, OPEN_MAX) + + cur_i = cur_i + blink_n + cur_n + + return idx + + +def _fix_exp_for_x_d_info(x_d_info, x_s_info, delta_eye=None, drive_eye=True): + _eye = [11, 13, 15, 16, 18] + _lip = [6, 12, 14, 17, 19, 20] + alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype) + alpha[_lip] = 1 + if delta_eye is None and drive_eye: # use d eye + alpha[_eye] = 1 + alpha = alpha.reshape(1, -1) + x_d_info["exp"] = x_d_info["exp"] * alpha + x_s_info["exp"] * (1 - alpha) + + if delta_eye is not None and drive_eye: + alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype) + alpha[_eye] = 1 + alpha = alpha.reshape(1, -1) + x_d_info["exp"] = (delta_eye + x_s_info["exp"]) * alpha + x_d_info["exp"] * ( + 1 - alpha + ) 
def bin66_to_degree(pred):
    """
    Convert 66-bin classification logits to an angle in degrees.

    Arrays that are not at least 2-D with 66 entries along axis 1 are
    returned untouched (they are taken to be degrees already). Otherwise
    the logits are soft-maxed along axis 1 and the expected bin index is
    mapped to degrees as `expectation * 3 - 97.5`.
    """
    looks_like_bins = pred.ndim > 1 and pred.shape[1] == 66
    if not looks_like_bins:
        return pred

    bin_ids = np.arange(66).astype(np.float32)
    probs = softmax(pred, axis=1)
    expected_bin = np.sum(probs * bin_ids, axis=1)
    return expected_bin * 3 - 97.5
def transform_keypoint(kp_info: dict):
    """
    Apply pose, expression, scale and translation to the implicit keypoints.

    kp_info must provide 'kp', 'pitch', 'yaw', 'roll', 't', 'exp' and 'scale'.
    'kp' is either Bx(num_kp*3) or Bxnum_kpx3; the angles may be degrees or
    66-bin logits (converted via bin66_to_degree).

    Returns the transformed keypoints, shaped (bs, num_kp, 3).
    """
    canonical = kp_info['kp']  # (bs, k, 3)
    pitch = kp_info['pitch']
    yaw = kp_info['yaw']
    roll = kp_info['roll']
    shift = kp_info['t']
    expression = kp_info['exp']
    scale = kp_info['scale']

    pitch, yaw, roll = (bin66_to_degree(angle) for angle in (pitch, yaw, roll))

    batch = canonical.shape[0]
    # Flattened keypoints store x, y, z consecutively.
    n_points = canonical.shape[1] // 3 if canonical.ndim == 2 else canonical.shape[1]

    rotation = get_rotation_matrix(pitch, yaw, roll)  # (bs, 3, 3)

    # Eqn.2: s * (R * x_c,s + exp) + t
    rotated = np.matmul(canonical.reshape(batch, n_points, 3), rotation)
    out = rotated + expression.reshape(batch, n_points, 3)
    # In-place ops on purpose: they keep the dtype of `out`.
    out *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
    out[:, :, 0:2] += shift[:, None, 0:2]  # z is left alone, only tx ty applied

    return out
self.use_d_keys = ("exp", ) + else: + self.use_d_keys = use_d_keys + + if drive_eye is None: + if self.is_image_flag: + self.drive_eye = True + else: + self.drive_eye = False + else: + self.drive_eye = drive_eye + + self.N_d = N_d + self.relative_d = relative_d + self.delta_eye_arr = delta_eye_arr + self.delta_eye_open_n = delta_eye_open_n + self.fade_out_keys = fade_out_keys + self.fade_type = fade_type + self.flag_stitching = flag_stitching + + _eye = [11, 13, 15, 16, 18] + _lip = [6, 12, 14, 17, 19, 20] + _a1 = np.zeros((21, 3), dtype=np.float32) + _a1[_lip] = 1 + _a2 = 0 + if self.drive_eye: + if self.delta_eye_arr is None: + _a1[_eye] = 1 + else: + _a2 = np.zeros((21, 3), dtype=np.float32) + _a2[_eye] = 1 + _a2 = _a2.reshape(1, -1) + _a1 = _a1.reshape(1, -1) + + self.fix_exp_a1 = _a1 * (1 - _a2) + self.fix_exp_a2 = (1 - _a1) + _a1 * _a2 + self.fix_exp_a3 = _a2 + + if self.drive_eye and self.delta_eye_arr is not None: + N = 3000 if self.N_d == -1 else self.N_d + self.delta_eye_idx_list = _set_eye_blink_idx( + N, len(self.delta_eye_arr), self.delta_eye_open_n + ) + + self.pose_s = None + self.x_s = None + self.fade_dst = None + if self.is_image_flag and x_s_info is not None: + yaw_s = bin66_to_degree(x_s_info['yaw']).item() + pitch_s = bin66_to_degree(x_s_info['pitch']).item() + self.pose_s = [yaw_s, pitch_s] + self.x_s = transform_keypoint(x_s_info) + + if self.fade_type == "s": + self.fade_dst = copy.deepcopy(x_s_info) + + if ch_info is not None: + self.scale_a = ch_info['x_s_info_lst'][0]['scale'].item() + if x_s_info is not None: + self.scale_b = x_s_info['scale'].item() + self.scale_ratio = self.scale_a / self.scale_b + self._set_scale_ratio(self.scale_ratio) + else: + self.scale_ratio = None + else: + self.scale_ratio = 1 + + self.overall_ctrl_info = overall_ctrl_info + + self.d0 = d0 + self.idx = 0 + + def _set_scale_ratio(self, scale_ratio=1): + if scale_ratio == 1: + return + if isinstance(self.use_d_keys, dict): + self.use_d_keys = {k: v * (scale_ratio 
if k in {"exp", "pitch", "yaw", "roll"} else 1) for k, v in self.use_d_keys.items()} + else: + self.use_d_keys = {k: scale_ratio if k in {"exp", "pitch", "yaw", "roll"} else 1 for k in self.use_d_keys} + + @staticmethod + def _merge_kwargs(default_kwargs, run_kwargs): + if default_kwargs is None: + return run_kwargs + + for k, v in default_kwargs.items(): + if k not in run_kwargs: + run_kwargs[k] = v + return run_kwargs + + def __call__(self, x_s_info, x_d_info, **kwargs): + # return x_s, x_d + + kwargs = self._merge_kwargs(self.overall_ctrl_info, kwargs) + + if self.scale_ratio is None: + self.scale_b = x_s_info['scale'].item() + self.scale_ratio = self.scale_a / self.scale_b + self._set_scale_ratio(self.scale_ratio) + + if self.relative_d and self.d0 is None: + self.d0 = copy.deepcopy(x_d_info) + + x_d_info = _mix_s_d_info( + x_s_info, + x_d_info, + self.use_d_keys, + self.d0, + ) + + delta_eye = 0 + if self.drive_eye and self.delta_eye_arr is not None: + delta_eye = self.delta_eye_arr[ + self.delta_eye_idx_list[self.idx % len(self.delta_eye_idx_list)] + ][None] + x_d_info = _fix_exp_for_x_d_info_v2( + x_d_info, + x_s_info, + delta_eye, + self.fix_exp_a1, + self.fix_exp_a2, + self.fix_exp_a3, + ) + + if kwargs.get("vad_alpha", 1) < 1: + x_d_info = ctrl_vad(x_d_info, x_s_info, kwargs.get("vad_alpha", 1)) + + x_d_info = ctrl_motion(x_d_info, **kwargs) + + if self.fade_type == "d0" and self.fade_dst is None: + self.fade_dst = copy.deepcopy(x_d_info) + + # fade + if "fade_alpha" in kwargs and self.fade_type in ["d0", "s"]: + fade_alpha = kwargs["fade_alpha"] + fade_keys = kwargs.get("fade_out_keys", self.fade_out_keys) + if self.fade_type == "d0": + fade_dst = self.fade_dst + elif self.fade_type == "s": + if self.fade_dst is not None: + fade_dst = self.fade_dst + else: + fade_dst = copy.deepcopy(x_s_info) + if self.is_image_flag: + self.fade_dst = fade_dst + x_d_info = fade(x_d_info, fade_dst, fade_alpha, fade_keys) + + if self.drive_eye: + if self.pose_s is None: + 
class PutBackNumpy:
    """
    Paste a rendered face crop back into the original frame (pure NumPy blend).

    mask_template_path: optional path to a 3-channel mask image; when omitted
        a mask is generated with get_mask(512, 512, 0.9, 0.9) and replicated
        to 3 channels.
    """

    def __init__(
        self,
        mask_template_path=None,
    ):
        if mask_template_path is None:
            mask = get_mask(512, 512, 0.9, 0.9)
            self.mask_ori_float = np.concatenate([mask] * 3, 2)
        else:
            mask = cv2.imread(mask_template_path, cv2.IMREAD_COLOR)
            # cv2.imread signals failure by returning None instead of raising.
            if mask is None:
                raise ValueError(
                    f"Failed to read mask template image: {mask_template_path}"
                )
            self.mask_ori_float = mask.astype(np.float32) / 255.0

    def __call__(self, frame_rgb, render_image, M_c2o):
        """
        frame_rgb: original frame; its height/width define the output size
        render_image: rendered crop to paste back
        M_c2o: crop-to-original transform; only its first two rows are used

        Returns the blended frame as uint8.
        """
        h, w = frame_rgb.shape[:2]
        affine = M_c2o[:2, :]
        mask_warped = cv2.warpAffine(
            self.mask_ori_float, affine, dsize=(w, h), flags=cv2.INTER_LINEAR
        ).clip(0, 1)
        frame_warped = cv2.warpAffine(
            render_image, affine, dsize=(w, h), flags=cv2.INTER_LINEAR
        )
        # Alpha blend: rendered face inside the mask, original frame outside.
        result = mask_warped * frame_warped + (1 - mask_warped) * frame_rgb
        result = np.clip(result, 0, 255)
        return result.astype(np.uint8)
class Source2Info:
    """
    Extract everything the pipeline needs from one source RGB frame:
    face crop, motion keypoint info, appearance features and eye attributes.

    Each *_cfg argument is the keyword dict used to build the matching model.
    """

    def __init__(
        self,
        insightface_det_cfg,
        landmark106_cfg,
        landmark203_cfg,
        landmark478_cfg,
        appearance_extractor_cfg,
        motion_extractor_cfg,
    ):
        self.insightface_det = InsightFaceDet(**insightface_det_cfg)
        self.landmark106 = Landmark106(**landmark106_cfg)
        self.landmark203 = Landmark203(**landmark203_cfg)
        self.landmark478 = Landmark478(**landmark478_cfg)

        self.appearance_extractor = AppearanceExtractor(**appearance_extractor_cfg)
        self.motion_extractor = MotionExtractor(**motion_extractor_cfg)

    def _crop(self, img, last_lmk=None, **kwargs):
        """
        img_rgb -> det -> landmark106 -> landmark203 -> crop

        Returns (img_crop, M_c2o, lmk203), or None when detection finds no face.
        """
        if last_lmk is None:  # det for first frame or image
            det, _ = self.insightface_det(img)
            # Sort detections by box area, largest first.
            boxes = det[np.argsort(-(det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1]))]
            if len(boxes) == 0:
                return None
            lmk_for_track = self.landmark106(img, boxes[0])  # 106
        else:  # track for video frames
            lmk_for_track = last_lmk  # 203

        crop_dct = crop_image(
            img,
            lmk_for_track,
            dsize=self.landmark203.dsize,
            scale=1.5,
            vy_ratio=-0.1,
            pt_crop_flag=False,
        )
        lmk203 = self.landmark203(crop_dct["img_crop"], crop_dct["M_c2o"])

        ret_dct = crop_image(
            img,
            lmk203,
            dsize=512,
            scale=kwargs.get("crop_scale", 2.3),
            vx_ratio=kwargs.get("crop_vx_ratio", 0),
            vy_ratio=kwargs.get("crop_vy_ratio", -0.125),
            flag_do_rot=kwargs.get("crop_flag_do_rot", True),
            pt_crop_flag=False,
        )

        img_crop = ret_dct["img_crop"]
        M_c2o = ret_dct["M_c2o"]

        return img_crop, M_c2o, lmk203

    @staticmethod
    def _img_crop_to_bchw256(img_crop):
        """Resize the crop to 256x256 and convert it to float32 BCHW in [0, 1]."""
        rgb_256 = cv2.resize(img_crop, (256, 256), interpolation=cv2.INTER_AREA)
        rgb_256_bchw = (rgb_256.astype(np.float32) / 255.0)[None].transpose(0, 3, 1, 2)
        return rgb_256_bchw

    def _get_kp_info(self, img):
        # rgb_256_bchw_norm01
        kp_info = self.motion_extractor(img)
        return kp_info

    def _get_f3d(self, img):
        # rgb_256_bchw_norm01
        fs = self.appearance_extractor(img)
        return fs

    def _get_eye_info(self, img):
        # rgb uint8
        lmk478 = self.landmark478(img)  # [1, 478, 3]
        attr = EyeAttrUtilsByMP(lmk478)
        lr_open = attr.LR_open().reshape(-1, 2)  # [1, 2]
        lr_ball = attr.LR_ball_move().reshape(-1, 6)  # [1, 3, 2] -> [1, 6]
        return [lr_open, lr_ball]

    def __call__(self, img, last_lmk=None, **kwargs):
        """
        img: rgb, uint8
        last_lmk: last frame lmk203, for video tracking
        kwargs: optional crop cfg
            crop_scale: 2.3
            crop_vx_ratio: 0
            crop_vy_ratio: -0.125
            crop_flag_do_rot: True

        Returns a dict with keys x_s_info, f_s, M_c2o, eye_open, eye_ball, lmk203.
        Raises ValueError when no face can be found in img.
        """
        crop_result = self._crop(img, last_lmk=last_lmk, **kwargs)
        # _crop reports "no face" as None; fail with a clear message instead
        # of an unpacking TypeError.
        if crop_result is None:
            raise ValueError("No face detected in the source image")
        img_crop, M_c2o, lmk203 = crop_result

        eye_open, eye_ball = self._get_eye_info(img_crop)

        rgb_256_bchw = self._img_crop_to_bchw256(img_crop)
        kp_info = self._get_kp_info(rgb_256_bchw)
        fs = self._get_f3d(rgb_256_bchw)

        source_info = {
            "x_s_info": kp_info,
            "f_s": fs,
            "M_c2o": M_c2o,
            "eye_open": eye_open,  # [1, 2]
            "eye_ball": eye_ball,  # [1, 6]
            "lmk203": lmk203,  # for track
        }
        return source_info
class Wav2Feat:
    """
    Front-end that turns audio into features through a named backend.

    w2f_cfg: keyword dict handed to the backend as its config
    w2f_type: backend name (case-insensitive); only "hubert" can be constructed
    """

    def __init__(self, w2f_cfg, w2f_type="hubert"):
        self.w2f_type = w2f_type.lower()
        if self.w2f_type != "hubert":
            raise ValueError(f"Unsupported w2f_type: {w2f_type}")
        self.w2f = Wav2FeatHubert(hubert_cfg=w2f_cfg)
        self.feat_dim = 1024
        self.support_streaming = True

    def __call__(
        self,
        audio,
        sr=16000,
        norm_mean_std=None,  # for s2g
        chunksize=(3, 5, 2),  # for hubert
    ):
        """Feature extraction for one audio chunk; dispatches on the backend type."""
        backend = self.w2f_type
        if backend == "hubert":
            return self.w2f(audio, chunksize=chunksize)
        if backend == "s2g":
            return self.w2f(audio, sr=sr, norm_mean_std=norm_mean_std)
        raise ValueError(f"Unsupported w2f_type: {self.w2f_type}")

    def wav2feat(
        self,
        audio,
        sr=16000,
        norm_mean_std=None,  # for s2g
        chunksize=(3, 5, 2),
    ):
        """Offline feature extraction for a whole waveform."""
        backend = self.w2f_type
        if backend == "hubert":
            return self.w2f.wav2feat(audio, sr=sr, chunksize=chunksize)
        if backend == "s2g":
            return self.w2f(audio, sr=sr, norm_mean_std=norm_mean_std)
        raise ValueError(f"Unsupported w2f_type: {self.w2f_type}")
class VideoWriterByImageIO:
    """
    Thin wrapper around an imageio video writer.

    video_path: output file; its parent directory is created when missing
    fps: frames per second
    kwargs (all optional): format ("mp4"), vcodec ("libx264"), quality,
        pixelformat ("yuv420p"), macro_block_size (2), crf (18)
    """

    def __init__(self, video_path, fps=25, **kwargs):
        video_format = kwargs.get("format", "mp4")  # default is mp4 format
        codec = kwargs.get("vcodec", "libx264")  # default is libx264 encoding
        quality = kwargs.get("quality")  # video quality
        pixelformat = kwargs.get("pixelformat", "yuv420p")  # video pixel format
        macro_block_size = kwargs.get("macro_block_size", 2)
        ffmpeg_params = ["-crf", str(kwargs.get("crf", 18))]

        # A bare filename has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError -- only create a directory when there is one.
        out_dir = os.path.dirname(video_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        writer = imageio.get_writer(
            video_path,
            fps=fps,
            format=video_format,
            codec=codec,
            quality=quality,
            ffmpeg_params=ffmpeg_params,
            pixelformat=pixelformat,
            macro_block_size=macro_block_size,
        )
        self.writer = writer

    def __call__(self, img, fmt="bgr"):
        """Append one frame; with fmt == "bgr" the last axis is reversed first."""
        if fmt == "bgr":
            frame = img[..., ::-1]
        else:
            frame = img
        self.writer.append_data(frame)

    def close(self):
        """Finish the file by closing the underlying writer."""
        self.writer.close()
def intersect(box_a, box_b):
    """Pairwise intersection area between two sets of boxes.

    Boxes are rows of (x_min, y_min, x_max, y_max). Broadcasting pairs every
    box in box_a with every box in box_b:
    [A,2] -> [A,1,2] and [B,2] -> [1,B,2], combined to [A,B,2].
    Args:
      box_a: (array) bounding boxes, Shape: [A,4].
      box_b: (array) bounding boxes, Shape: [B,4].
    Return:
      (array) intersection area, Shape: [A,B].
    """
    corners_a = box_a[:, np.newaxis, :]  # [A,1,4]
    corners_b = box_b[np.newaxis, :, :]  # [1,B,4]
    upper = np.minimum(corners_a[..., 2:], corners_b[..., 2:])  # [A,B,2]
    lower = np.maximum(corners_a[..., :2], corners_b[..., :2])  # [A,B,2]
    # Non-overlapping pairs give negative extents; clamp them to zero.
    extent = np.clip(upper - lower, a_min=0, a_max=None)
    return extent[:, :, 0] * extent[:, :, 1]
def overlap_similarity(box, other_boxes):
    """IOU of one bounding box against each box in other_boxes.

    The single box is promoted to a batch of one, scored with jaccard, and
    the leading batch axis is dropped again.
    """
    single = np.asanyarray(box)[np.newaxis]  # add leading axis, like unsqueeze(0)
    pairwise_iou = jaccard(single, other_boxes)
    return np.squeeze(pairwise_iou, axis=0)  # drop it again, like squeeze(0)
def calculate_scale(self, min_scale, max_scale, stride_index, num_strides):
    """
    Linearly interpolate the anchor scale for layer stride_index.

    Layer 0 gets min_scale and layer num_strides - 1 gets max_scale. With a
    single layer there is nothing to interpolate over, so the midpoint of the
    two scales is returned (this also avoids dividing by zero).
    """
    if num_strides == 1:
        return (min_scale + max_scale) * 0.5
    return min_scale + (max_scale - min_scale) * stride_index / (num_strides - 1.0)
+ aspect_ratios.append(1.0) + aspect_ratios.append(2.0) + aspect_ratios.append(0.5) + scales.append(0.1) + scales.append(scale) + scales.append(scale) + else: + for aspect_ratio in options["aspect_ratios"]: + aspect_ratios.append(aspect_ratio) + scales.append(scale) + + if options["interpolated_scale_aspect_ratio"] > 0.0: + scale_next = ( + 1.0 + if last_same_stride_layer == strides_size - 1 + else self.calculate_scale( + options["min_scale"], + options["max_scale"], + last_same_stride_layer + 1, + strides_size, + ) + ) + scales.append(np.sqrt(scale * scale_next)) + aspect_ratios.append(options["interpolated_scale_aspect_ratio"]) + + last_same_stride_layer += 1 + + for i in range(len(aspect_ratios)): + ratio_sqrts = np.sqrt(aspect_ratios[i]) + anchor_height.append(scales[i] / ratio_sqrts) + anchor_width.append(scales[i] * ratio_sqrts) + + stride = options["strides"][layer_id] + feature_map_height = int(np.ceil(options["input_size_height"] / stride)) + feature_map_width = int(np.ceil(options["input_size_width"] / stride)) + + for y in range(feature_map_height): + for x in range(feature_map_width): + for anchor_id in range(len(anchor_height)): + x_center = (x + options["anchor_offset_x"]) / feature_map_width + y_center = (y + options["anchor_offset_y"]) / feature_map_height + + new_anchor = [x_center, y_center, 0, 0] + if options["fixed_anchor_size"]: + new_anchor[2] = 1.0 + new_anchor[3] = 1.0 + else: + new_anchor[2] = anchor_width[anchor_id] + new_anchor[3] = anchor_height[anchor_id] + anchors.append(new_anchor) + + layer_id = last_same_stride_layer + + return anchors + + def _tensors_to_detections(self, raw_box_tensor, raw_score_tensor, anchors): + """The output of the neural network is a tensor of shape (b, 896, 16) + containing the bounding box regressor predictions, as well as a tensor + of shape (b, 896, 1) with the classification confidences. + + This function converts these two "raw" tensors into proper detections. 
+ Returns a list of (num_detections, 17) tensors, one for each image in + the batch. + + This is based on the source code from: + mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc + mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto + """ + assert raw_box_tensor.ndim == 3 + assert raw_box_tensor.shape[1] == self.num_anchors + assert raw_box_tensor.shape[2] == self.num_coords + + assert raw_score_tensor.ndim == 3 + assert raw_score_tensor.shape[1] == self.num_anchors + assert raw_score_tensor.shape[2] == self.num_classes + + assert raw_box_tensor.shape[0] == raw_score_tensor.shape[0] + + detection_boxes = self._decode_boxes(raw_box_tensor, anchors) + + raw_score_tensor = np.clip(raw_score_tensor, -50, 100) + detection_scores = 1 / (1 + np.exp(-raw_score_tensor)) + mask = detection_scores >= self.min_score_thresh + mask = mask[0, :, 0] + boxes = detection_boxes[0, mask, :] + scores = detection_scores[0, mask, :] + return np.concatenate((boxes, scores), axis=-1) + + def _decode_boxes(self, raw_boxes, anchors): + """Converts the predictions into actual coordinates using + the anchor boxes. Processes the entire batch at once. 
+ """ + boxes = np.zeros_like(raw_boxes) + + x_center = raw_boxes[..., 0] / self.x_scale * anchors[:, 2] + anchors[:, 0] + y_center = raw_boxes[..., 1] / self.y_scale * anchors[:, 3] + anchors[:, 1] + + w = raw_boxes[..., 2] / self.w_scale * anchors[:, 2] + h = raw_boxes[..., 3] / self.h_scale * anchors[:, 3] + + boxes[..., 0] = self.x_scale * (x_center - w / 2.0) # xmin + boxes[..., 1] = self.y_scale * (y_center - h / 2.0) # ymin + boxes[..., 2] = self.w_scale * (x_center + w / 2.0) # xmax + boxes[..., 3] = self.h_scale * (y_center + h / 2.0) # ymax + + for k in range(6): + offset = 4 + k * 2 + keypoint_x = ( + raw_boxes[..., offset] / self.x_scale * anchors[:, 2] + anchors[:, 0] + ) + keypoint_y = ( + raw_boxes[..., offset + 1] / self.y_scale * anchors[:, 3] + + anchors[:, 1] + ) + boxes[..., offset] = keypoint_x + boxes[..., offset + 1] = keypoint_y + + return boxes + + def _weighted_non_max_suppression(self, detections): + """The alternative NMS method as mentioned in the BlazeFace paper: + + "We replace the suppression algorithm with a blending strategy that + estimates the regression parameters of a bounding box as a weighted + mean between the overlapping predictions." + + The original MediaPipe code assigns the score of the most confident + detection to the weighted detection, but we take the average score + of the overlapping detections. + + The input detections should be a NumPy array of shape (count, 17). + + Returns a list of NumPy arrays, one for each detected face. + + This is based on the source code from: + mediapipe/calculators/util/non_max_suppression_calculator.cc + mediapipe/calculators/util/non_max_suppression_calculator.proto + """ + if len(detections) == 0: + return [] + + output_detections = [] + + # Sort the detections from highest to lowest score. 
+ remaining = np.argsort(detections[:, 16])[::-1] + + while len(remaining) > 0: + detection = detections[remaining[0]] + + # Compute the overlap between the first box and the other + # remaining boxes. (Note that the other_boxes also include + # the first_box.) + first_box = detection[:4] + other_boxes = detections[remaining, :4] + ious = overlap_similarity(first_box, other_boxes) + + # If two detections don't overlap enough, they are considered + # to be from different faces. + mask = ious > self.min_suppression_threshold + overlapping = remaining[mask] + remaining = remaining[~mask] + + # Take an average of the coordinates from the overlapping + # detections, weighted by their confidence scores. + weighted_detection = detection.copy() + if len(overlapping) > 1: + coordinates = detections[overlapping, :16] + scores = detections[overlapping, 16:17] + total_score = scores.sum() + weighted = (coordinates * scores).sum(axis=0) / total_score + weighted_detection[:16] = weighted + weighted_detection[16] = total_score / len(overlapping) + + output_detections.append(weighted_detection) + + return output_detections + + def postprocess(self, raw_boxes, scores): + detections = self._tensors_to_detections(raw_boxes, scores, self.anchors) + + detections = self._weighted_non_max_suppression(detections) + detections = np.array(detections) + return detections + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="") + parser.add_argument("--image", type=str, default=None) + args = parser.parse_args() + + blaze_face = BlazeFace(args.model) + image = cv2.imread(args.image) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (128, 128)) + image = image[np.newaxis, :, :, :].astype(np.float32) + image = image / 127.5 - 1.0 + boxes = blaze_face(image) + print(boxes) diff --git a/core/aux_models/face_mesh.py b/core/aux_models/face_mesh.py new file mode 100644 index 
0000000000000000000000000000000000000000..0557f572b5a89c13f5caa887aaf14308b1451b11 --- /dev/null +++ b/core/aux_models/face_mesh.py @@ -0,0 +1,101 @@ +import cv2 +import numpy as np + +from ..utils.load_model import load_model + + +class FaceMesh: + def __init__(self, model_path, device="cuda"): + self.model, self.model_type = load_model(model_path, device=device) + self.input_size = (256, 256) # (w, h) + self.output_names = [ + "Identity", + "Identity_1", + "Identity_2", + ] # Identity is the mesh + + def project_landmarks(self, points, roi): + width, height = self.input_size + points /= (width, height, width) + sin, cos = np.sin(roi[4]), np.cos(roi[4]) + matrix = np.array([[cos, sin, 0.0], [-sin, cos, 0.0], [1.0, 1.0, 1.0]]) + points -= (0.5, 0.5, 0.0) + rotated = np.matmul(points * (1, 1, 0), matrix) + points *= (0, 0, 1) + points += rotated + points *= (roi[2], roi[3], roi[2]) + points += (roi[0], roi[1], 0.0) + return points + + def __call__(self, image, roi): + """ + image: np.ndarray, RGB, (H, W, C), [0, 255] + roi: np.ndarray, (cx, cy, w, h, rotation), rotation in radian + """ + cx, cy, w, h = roi[:4] + w_half, h_half = w / 2, h / 2 + pts = [ + (cx - w_half, cy - h_half), + (cx + w_half, cy - h_half), + (cx + w_half, cy + h_half), + (cx - w_half, cy + h_half), + ] + rotation = roi[4] + s, c = np.sin(rotation), np.cos(rotation) + t = np.array(pts) - (cx, cy) + r = np.array([[c, s], [-s, c]]) + src_pts = np.matmul(t, r) + (cx, cy) + src_pts = src_pts.astype(np.float32) + + dst_pts = np.array( + [ + [0.0, 0.0], + [self.input_size[0], 0.0], + [self.input_size[0], self.input_size[1]], + [0.0, self.input_size[1]], + ] + ).astype(np.float32) + M = cv2.getPerspectiveTransform(src_pts, dst_pts) + roi_image = cv2.warpPerspective( + image, M, self.input_size, flags=cv2.INTER_LINEAR + ) + # cv2.imwrite('test.jpg', cv2.cvtColor(roi_image, cv2.COLOR_RGB2BGR)) + roi_image = roi_image / 255.0 + roi_image = roi_image.astype(np.float32) + roi_image = roi_image[np.newaxis, :, 
:, :] + + outputs = {} + if self.model_type == "onnx": + out_list = self.model.run(None, {"input": roi_image}) + for i, name in enumerate(self.output_names): + outputs[name] = out_list[i] + elif self.model_type == "tensorrt": + self.model.setup({"input": roi_image}) + self.model.infer() + for name in self.output_names: + outputs[name] = self.model.buffer[name][0] + else: + raise ValueError(f"Unsupported model type: {self.model_type}") + points = outputs["Identity"].reshape(1434 // 3, 3) + points = self.project_landmarks(points, roi) + return points + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, help="model path") + parser.add_argument("--image", type=str, help="image path") + parser.add_argument("--device", type=str, default="cuda", help="device") + args = parser.parse_args() + + face_mesh = FaceMesh(args.model, args.device) + image = cv2.imread(args.image, cv2.IMREAD_COLOR) + image = cv2.resize(image, (256, 256)) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + face_mesh = FaceMesh(args.model, args.device) + roi = np.array([128, 128, 256, 256, np.pi / 2]) + mesh = face_mesh(image, roi) + print(mesh.shape) diff --git a/core/aux_models/hubert_stream.py b/core/aux_models/hubert_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..30387c6df754d129d6f1815e131718e68069e411 --- /dev/null +++ b/core/aux_models/hubert_stream.py @@ -0,0 +1,29 @@ +from ..utils.load_model import load_model + + +class HubertStreaming: + def __init__(self, model_path, device="cuda", **kwargs): + kwargs["model_file"] = model_path + kwargs["module_name"] = "HubertStreamingONNX" + kwargs["package_name"] = "..aux_models.modules" + + self.model, self.model_type = load_model(model_path, device=device, **kwargs) + self.device = device + + def forward_chunk(self, audio_chunk): + if self.model_type == "onnx": + output = self.model.run(None, {"input_values": audio_chunk.reshape(1, -1)})[0] + 
elif self.model_type == "tensorrt": + self.model.setup({"input_values": audio_chunk.reshape(1, -1)}) + self.model.infer() + output = self.model.buffer["encoding_out"][0] + else: + raise ValueError(f"Unsupported model type: {self.model_type}") + return output + + def __call__(self, audio_chunk): + if self.model_type == "ori": + output = self.model.forward_chunk(audio_chunk) + else: + output = self.forward_chunk(audio_chunk) + return output diff --git a/core/aux_models/insightface_det.py b/core/aux_models/insightface_det.py new file mode 100644 index 0000000000000000000000000000000000000000..10a6774b2dd69ee0dfecc4fb3b74a04093c7f5ba --- /dev/null +++ b/core/aux_models/insightface_det.py @@ -0,0 +1,245 @@ +from __future__ import division +import numpy as np +import cv2 + +from ..utils.load_model import load_model + + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. + """ + x1 = points[:, 0] - distance[:, 0] + y1 = points[:, 1] - distance[:, 1] + x2 = points[:, 0] + distance[:, 2] + y2 = points[:, 1] + distance[:, 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return np.stack([x1, y1, x2, y2], axis=-1) + + +def distance2kps(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. 
+ """ + preds = [] + for i in range(0, distance.shape[1], 2): + px = points[:, i%2] + distance[:, i] + py = points[:, i%2+1] + distance[:, i+1] + if max_shape is not None: + px = px.clamp(min=0, max=max_shape[1]) + py = py.clamp(min=0, max=max_shape[0]) + preds.append(px) + preds.append(py) + return np.stack(preds, axis=-1) + + +class InsightFaceDet: + def __init__(self, model_path, device="cuda", **kwargs): + kwargs["model_file"] = model_path + kwargs["module_name"] = "RetinaFace" + kwargs["package_name"] = "..aux_models.modules" + + self.model, self.model_type = load_model(model_path, device=device, **kwargs) + self.device = device + + if self.model_type != "ori": + self._init_vars() + + def _init_vars(self): + self.center_cache = {} + + self.nms_thresh = 0.4 + self.det_thresh = 0.5 + + self.input_size = (512, 512) + self.input_mean = 127.5 + self.input_std = 128.0 + self._anchor_ratio = 1.0 + self.fmc = 3 + self._feat_stride_fpn = [8, 16, 32] + self._num_anchors = 2 + self.use_kps = True + + self.output_names = [ + "scores1", + "scores2", + "scores3", + "boxes1", + "boxes2", + "boxes3", + "kps1", + "kps2", + "kps3", + ] + + def _run_model(self, blob): + if self.model_type == "onnx": + net_outs = self.model.run(None, {"image": blob}) + elif self.model_type == "tensorrt": + self.model.setup({"image": blob}) + self.model.infer() + net_outs = [self.model.buffer[name][0] for name in self.output_names] + else: + raise ValueError(f"Unsupported model type: {self.model_type}") + return net_outs + + def _forward(self, img, threshold): + """ + img: np.ndarray, shape (h, w, 3) + """ + scores_list = [] + bboxes_list = [] + kpss_list = [] + input_size = tuple(img.shape[0:2][::-1]) + blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + # (1, 3, 512, 512) + net_outs = self._run_model(blob) + + input_height = blob.shape[2] + input_width = blob.shape[3] + fmc = self.fmc + for idx, stride in 
enumerate(self._feat_stride_fpn): + scores = net_outs[idx] + bbox_preds = net_outs[idx+fmc] + bbox_preds = bbox_preds * stride + if self.use_kps: + kps_preds = net_outs[idx+fmc*2] * stride + height = input_height // stride + width = input_width // stride + # K = height * width + key = (height, width, stride) + if key in self.center_cache: + anchor_centers = self.center_cache[key] + else: + #solution-3: + anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) + anchor_centers = (anchor_centers * stride).reshape( (-1, 2) ) + if self._num_anchors>1: + anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) ) + if len(self.center_cache)<100: + self.center_cache[key] = anchor_centers + + pos_inds = np.where(scores>=threshold)[0] + bboxes = distance2bbox(anchor_centers, bbox_preds) + pos_scores = scores[pos_inds] + pos_bboxes = bboxes[pos_inds] + scores_list.append(pos_scores) + bboxes_list.append(pos_bboxes) + if self.use_kps: + kpss = distance2kps(anchor_centers, kps_preds) + kpss = kpss.reshape( (kpss.shape[0], -1, 2) ) + pos_kpss = kpss[pos_inds] + kpss_list.append(pos_kpss) + return scores_list, bboxes_list, kpss_list + + def detect(self, img, input_size=None, max_num=0, metric='default', det_thresh=None): + input_size = self.input_size if input_size is None else input_size + det_thresh = self.det_thresh if det_thresh is None else det_thresh + + im_ratio = float(img.shape[0]) / img.shape[1] + model_ratio = float(input_size[1]) / input_size[0] + if im_ratio>model_ratio: + new_height = input_size[1] + new_width = int(new_height / im_ratio) + else: + new_width = input_size[0] + new_height = int(new_width * im_ratio) + det_scale = float(new_height) / img.shape[0] + resized_img = cv2.resize(img, (new_width, new_height)) + det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 ) + det_img[:new_height, :new_width, :] = resized_img + + scores_list, bboxes_list, kpss_list = self._forward(det_img, 
det_thresh) + + scores = np.vstack(scores_list) + scores_ravel = scores.ravel() + order = scores_ravel.argsort()[::-1] + bboxes = np.vstack(bboxes_list) / det_scale + if self.use_kps: + kpss = np.vstack(kpss_list) / det_scale + pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) + pre_det = pre_det[order, :] + keep = self.nms(pre_det) + det = pre_det[keep, :] + if self.use_kps: + kpss = kpss[order,:,:] + kpss = kpss[keep,:,:] + else: + kpss = None + if max_num > 0 and det.shape[0] > max_num: + area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1]) + img_center = img.shape[0] // 2, img.shape[1] // 2 + offsets = np.vstack([ + (det[:, 0] + det[:, 2]) / 2 - img_center[1], + (det[:, 1] + det[:, 3]) / 2 - img_center[0] + ]) + offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) + if metric=='max': + values = area + else: + values = area - offset_dist_squared * 2.0 # some extra weight on the centering + bindex = np.argsort(values)[::-1] # some extra weight on the centering + bindex = bindex[0:max_num] + det = det[bindex, :] + if kpss is not None: + kpss = kpss[bindex, :] + return det, kpss + + def nms(self, dets): + thresh = self.nms_thresh + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + + def __call__(self, img, **kwargs): + if self.model_type == "ori": + det, kpss = self.model.detect(img, **kwargs) + else: + det, kpss = self.detect(img, **kwargs) + + return det, kpss + diff 
--git a/core/aux_models/insightface_landmark106.py b/core/aux_models/insightface_landmark106.py new file mode 100644 index 0000000000000000000000000000000000000000..e2a701a628cb848fd36218d1bc5631c627bd7f9f --- /dev/null +++ b/core/aux_models/insightface_landmark106.py @@ -0,0 +1,100 @@ +from __future__ import division +import numpy as np +import torch +import cv2 +from skimage import transform as trans + +from ..utils.load_model import load_model + + +def transform(data, center, output_size, scale, rotation): + scale_ratio = scale + rot = float(rotation) * np.pi / 180.0 + + t1 = trans.SimilarityTransform(scale=scale_ratio) + cx = center[0] * scale_ratio + cy = center[1] * scale_ratio + t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) + t3 = trans.SimilarityTransform(rotation=rot) + t4 = trans.SimilarityTransform(translation=(output_size / 2, + output_size / 2)) + t = t1 + t2 + t3 + t4 + M = t.params[0:2] + cropped = cv2.warpAffine(data, + M, (output_size, output_size), + borderValue=0.0) + return cropped, M + + +def trans_points2d(pts, M): + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) + new_pt = np.dot(M, new_pt) + new_pts[i] = new_pt[0:2] + + return new_pts + + +class Landmark106: + def __init__(self, model_path, device="cuda", **kwargs): + kwargs["model_file"] = model_path + kwargs["module_name"] = "Landmark106" + kwargs["package_name"] = "..aux_models.modules" + + self.model, self.model_type = load_model(model_path, device=device, **kwargs) + self.device = device + + if self.model_type != "ori": + self._init_vars() + + def _init_vars(self): + self.input_mean = 0.0 + self.input_std = 1.0 + self.input_size = (192, 192) + self.lmk_num = 106 + + self.output_names = ["fc1"] + + def _run_model(self, blob): + if self.model_type == "onnx": + pred = self.model.run(None, {"data": blob})[0] + elif self.model_type == "tensorrt": + 
self.model.setup({"data": blob}) + self.model.infer() + pred = self.model.buffer[self.output_names[0]][0] + else: + raise ValueError(f"Unsupported model type: {self.model_type}") + return pred + + def get(self, img, bbox): + w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1]) + center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2 + rotate = 0 + _scale = self.input_size[0] / (max(w, h)*1.5) + + aimg, M = transform(img, center, self.input_size[0], _scale, rotate) + input_size = tuple(aimg.shape[0:2][::-1]) + + blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + + pred = self._run_model(blob) + + pred = pred.reshape((-1, 2)) + if self.lmk_num < pred.shape[0]: + pred = pred[self.lmk_num*-1:,:] + pred[:, 0:2] += 1 + pred[:, 0:2] *= (self.input_size[0] // 2) + + IM = cv2.invertAffineTransform(M) + pred = trans_points2d(pred, IM) + return pred + + def __call__(self, img, bbox): + if self.model_type == "ori": + pred = self.model.get(img, bbox) + else: + pred = self.get(img, bbox) + + return pred diff --git a/core/aux_models/landmark203.py b/core/aux_models/landmark203.py new file mode 100644 index 0000000000000000000000000000000000000000..72f6ef1898e0d2ed54d9b68b71ba72a1be0a6da2 --- /dev/null +++ b/core/aux_models/landmark203.py @@ -0,0 +1,58 @@ +import numpy as np +from ..utils.load_model import load_model + + +def _transform_pts(pts, M): + """ conduct similarity or affine transformation to the pts + pts: Nx2 ndarray + M: 2x3 matrix or 3x3 matrix + return: Nx2 + """ + return pts @ M[:2, :2].T + M[:2, 2] + + +class Landmark203: + def __init__(self, model_path, device="cuda", **kwargs): + kwargs["model_file"] = model_path + kwargs["module_name"] = "Landmark203" + kwargs["package_name"] = "..aux_models.modules" + + self.model, self.model_type = load_model(model_path, device=device, **kwargs) + self.device = device + + self.output_names = ["landmarks"] + self.dsize = 224 + + def 
_run_model(self, inp): + if self.model_type == "onnx": + out_pts = self.model.run(None, {"input": inp})[0] + elif self.model_type == "tensorrt": + self.model.setup({"input": inp}) + self.model.infer() + out_pts = self.model.buffer[self.output_names[0]][0] + else: + raise ValueError(f"Unsupported model type: {self.model_type}") + return out_pts + + def run(self, img_crop_rgb, M_c2o=None): + # img_crop_rgb: 224x224 + + inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...] # HxWx3 (BGR) -> 1x3xHxW (RGB!) + + out_pts = self._run_model(inp) + + # 2d landmarks 203 points + lmk = out_pts[0].reshape(-1, 2) * self.dsize # scale to 0-224 + if M_c2o is not None: + lmk = _transform_pts(lmk, M=M_c2o) + + return lmk + + def __call__(self, img_crop_rgb, M_c2o=None): + if self.model_type == "ori": + lmk = self.model.run(img_crop_rgb, M_c2o) + else: + lmk = self.run(img_crop_rgb, M_c2o) + + return lmk + \ No newline at end of file diff --git a/core/aux_models/mediapipe_landmark478.py b/core/aux_models/mediapipe_landmark478.py new file mode 100644 index 0000000000000000000000000000000000000000..305545b25cef270b5f00e453f13ef4b29a12faff --- /dev/null +++ b/core/aux_models/mediapipe_landmark478.py @@ -0,0 +1,118 @@ +from enum import Enum +import numpy as np + +from ..utils.load_model import load_model +from .blaze_face import BlazeFace +from .face_mesh import FaceMesh + + +class SizeMode(Enum): + DEFAULT = 0 + SQUARE_LONG = 1 + SQUARE_SHORT = 2 + + +def _select_roi_size( + bbox: np.ndarray, image_size, size_mode: SizeMode # x1, y1, x2, y2 # w,h +): + """Return the size of an ROI based on bounding box, image size and mode""" + width, height = bbox[2] - bbox[0], bbox[3] - bbox[1] + image_width, image_height = image_size + if size_mode == SizeMode.SQUARE_LONG: + long_size = max(width, height) + width, height = long_size, long_size + elif size_mode == SizeMode.SQUARE_SHORT: + short_side = min(width, height) + width, height = short_side, short_side + return width, 
height + + +def bbox_to_roi( + bbox: np.ndarray, + image_size, # w,h + rotation_keypoints=None, + scale=(1.0, 1.0), # w, h + size_mode: SizeMode = SizeMode.SQUARE_LONG, +): + PI = np.pi + TWO_PI = 2 * np.pi + # select ROI dimensions + width, height = _select_roi_size(bbox, image_size, size_mode) + scale_x, scale_y = scale + # calculate ROI size and -centre + width, height = width * scale_x, height * scale_y + cx = (bbox[0] + bbox[2]) / 2 + cy = (bbox[1] + bbox[3]) / 2 + # calculate rotation of required + if rotation_keypoints is None or len(rotation_keypoints) < 2: + return np.array([cx, cy, width, height, 0]) + x0, y0 = rotation_keypoints[0] + x1, y1 = rotation_keypoints[1] + angle = -np.atan2(y0 - y1, x1 - x0) + # normalise to [0, 2*PI] + rotation = angle - TWO_PI * np.floor((angle + PI) / TWO_PI) + return np.array([cx, cy, width, height, rotation]) + + +class Landmark478: + def __init__(self, blaze_face_model_path="", face_mesh_model_path="", device="cuda", **kwargs): + if kwargs.get("force_ori_type", False): + assert "task_path" in kwargs + kwargs["module_name"] = "Landmark478" + kwargs["package_name"] = "..aux_models.modules" + self.model, self.model_type = load_model("", device=device, **kwargs) + else: + self.blaze_face = BlazeFace(blaze_face_model_path, device) + self.face_mesh = FaceMesh(face_mesh_model_path, device) + self.model_type = "" + + def get(self, image): + bboxes = self.blaze_face(image) + if len(bboxes) == 0: + return None + bbox = bboxes[0] + scale = (image.shape[1] / 128.0, image.shape[0] / 128.0) + + # The first 4 numbers describe the bounding box corners: + # + # ymin, xmin, ymax, xmax + # These are normalized coordinates (between 0 and 1). 
+ # The next 12 numbers are the x,y-coordinates of the 6 facial landmark keypoints: + # + # right_eye_x, right_eye_y + # left_eye_x, left_eye_y + # nose_x, nose_y + # mouth_x, mouth_y + # right_ear_x, right_ear_y + # left_ear_x, left_ear_y + # Tip: these labeled as seen from the perspective of the person, so their right is your left. + # The final number is the confidence score that this detection really is a face. + + bbox[0] = bbox[0] * scale[1] + bbox[1] = bbox[1] * scale[0] + bbox[2] = bbox[2] * scale[1] + bbox[3] = bbox[3] * scale[0] + left_eye = (bbox[4], bbox[5]) + right_eye = (bbox[6], bbox[7]) + + roi = bbox_to_roi( + bbox, + (image.shape[1], image.shape[0]), + rotation_keypoints=[left_eye, right_eye], + scale=(1.5, 1.5), + size_mode=SizeMode.SQUARE_LONG, + ) + + mesh = self.face_mesh(image, roi) + mesh = mesh / (image.shape[1], image.shape[0], image.shape[1]) + return mesh + + def __call__(self, image): + if self.model_type == "ori": + det = self.model.detect_from_npimage(image.copy()) + lmk = self.model.mplmk_to_nplmk(det) + return lmk + else: + lmk = self.get(image) + lmk = lmk.reshape(1, -1, 3).astype(np.float32) + return lmk diff --git a/core/aux_models/modules/__init__.py b/core/aux_models/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2542938c6fcffcc75d908fbe3c9d1c547fcadec4 --- /dev/null +++ b/core/aux_models/modules/__init__.py @@ -0,0 +1,5 @@ +from .retinaface import RetinaFace +from .landmark106 import Landmark106 +from .landmark203 import Landmark203 +from .landmark478 import Landmark478 +from .hubert_stream import HubertStreamingONNX \ No newline at end of file diff --git a/core/aux_models/modules/hubert_stream.py b/core/aux_models/modules/hubert_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..04f0060f3a6934fb273c357e3a69b2263f4d67ab --- /dev/null +++ b/core/aux_models/modules/hubert_stream.py @@ -0,0 +1,21 @@ + +import onnxruntime + + +class HubertStreamingONNX: + def 
__init__(self, model_file, device="cuda"): + if device == "cuda": + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + + self.session = onnxruntime.InferenceSession(model_file, providers=providers) + + def forward_chunk(self, input_values): + encoding_out = self.session.run( + None, + {"input_values": input_values.reshape(1, -1)} + )[0] + return encoding_out + + diff --git a/core/aux_models/modules/landmark106.py b/core/aux_models/modules/landmark106.py new file mode 100644 index 0000000000000000000000000000000000000000..cc6210dabed2c7493775c3483bb7a6e861130fb8 --- /dev/null +++ b/core/aux_models/modules/landmark106.py @@ -0,0 +1,83 @@ +# insightface +from __future__ import division +import onnxruntime +import cv2 +import numpy as np +from skimage import transform as trans + + +def transform(data, center, output_size, scale, rotation): + scale_ratio = scale + rot = float(rotation) * np.pi / 180.0 + + t1 = trans.SimilarityTransform(scale=scale_ratio) + cx = center[0] * scale_ratio + cy = center[1] * scale_ratio + t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) + t3 = trans.SimilarityTransform(rotation=rot) + t4 = trans.SimilarityTransform(translation=(output_size / 2, + output_size / 2)) + t = t1 + t2 + t3 + t4 + M = t.params[0:2] + cropped = cv2.warpAffine(data, + M, (output_size, output_size), + borderValue=0.0) + return cropped, M + + +def trans_points2d(pts, M): + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) + new_pt = np.dot(M, new_pt) + new_pts[i] = new_pt[0:2] + + return new_pts + + + +class Landmark106: + def __init__(self, model_file, device="cuda"): + if device == "cuda": + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + self.session = onnxruntime.InferenceSession(model_file, providers=providers) + + 
self.input_mean = 0.0 + self.input_std = 1.0 + self.input_size = (192, 192) + input_cfg = self.session.get_inputs()[0] + input_name = input_cfg.name + outputs = self.session.get_outputs() + output_names = [] + for out in outputs: + output_names.append(out.name) + self.input_name = input_name + self.output_names = output_names + self.lmk_num = 106 + + def get(self, img, bbox): + w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1]) + center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2 + rotate = 0 + _scale = self.input_size[0] / (max(w, h)*1.5) + + aimg, M = transform(img, center, self.input_size[0], _scale, rotate) + input_size = tuple(aimg.shape[0:2][::-1]) + + blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + + pred = self.session.run(self.output_names, {self.input_name : blob})[0][0] + + pred = pred.reshape((-1, 2)) + if self.lmk_num < pred.shape[0]: + pred = pred[self.lmk_num*-1:,:] + pred[:, 0:2] += 1 + pred[:, 0:2] *= (self.input_size[0] // 2) + + IM = cv2.invertAffineTransform(M) + pred = trans_points2d(pred, IM) + return pred + diff --git a/core/aux_models/modules/landmark203.py b/core/aux_models/modules/landmark203.py new file mode 100644 index 0000000000000000000000000000000000000000..0ff41f1f33e5b155e86113802fa8490ba30eb3eb --- /dev/null +++ b/core/aux_models/modules/landmark203.py @@ -0,0 +1,42 @@ +import onnxruntime +import numpy as np + + +def _transform_pts(pts, M): + """ conduct similarity or affine transformation to the pts + pts: Nx2 ndarray + M: 2x3 matrix or 3x3 matrix + return: Nx2 + """ + return pts @ M[:2, :2].T + M[:2, 2] + + +class Landmark203: + def __init__(self, model_file, device="cuda"): + if device == "cuda": + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + self.session = onnxruntime.InferenceSession(model_file, providers=providers) + + self.dsize = 224 + + def _run(self, inp): + out = 
class Landmark478:
    """Thin wrapper around a MediaPipe ``FaceLandmarker`` task.

    The detector is configured for a single face and asked to also output
    blendshapes and facial transformation matrices.
    """

    def __init__(self, task_path):
        """Create the landmarker from the task asset at ``task_path``."""
        options = vision.FaceLandmarkerOptions(
            base_options=BaseOptions(model_asset_path=task_path),
            output_face_blendshapes=True,
            output_facial_transformation_matrixes=True,
            num_faces=1,
        )
        self.detector = vision.FaceLandmarker.create_from_options(options)

    def detect_from_imp(self, imp):
        """Run detection on an image loaded via ``mp.Image.create_from_file``.

        ``imp`` is presumably an image path -- confirm against callers.
        """
        return self.detector.detect(mp.Image.create_from_file(imp))

    def detect_from_npimage(self, img):
        """Run detection on an in-memory image wrapped as an SRGB ``mp.Image``."""
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
        return self.detector.detect(mp_image)

    @staticmethod
    def mplmk_to_nplmk(results):
        """Convert ``results.face_landmarks`` to a float32 array.

        Each landmark contributes ``[x, y, z]``; faces are stacked along the
        first axis. With no faces the result is an empty 1-D array.
        """
        per_face = [
            [[lm.x, lm.y, lm.z] for lm in face]
            for face in results.face_landmarks
        ]
        return np.array(per_face).astype(np.float32)
def distance2bbox(points, distance, max_shape=None):
    """Decode distance predictions into bounding boxes.

    Args:
        points (ndarray): Shape (n, 2), anchor centres as [x, y].
        distance (ndarray): Shape (n, 4), distance from each point to the
            4 boundaries (left, top, right, bottom).
        max_shape (tuple): Optional (height, width) of the image. When given,
            x coordinates are clipped to [0, width] and y to [0, height].

    Returns:
        ndarray: Decoded boxes, shape (n, 4), as [x1, y1, x2, y2].
    """
    x1 = points[:, 0] - distance[:, 0]
    y1 = points[:, 1] - distance[:, 1]
    x2 = points[:, 0] + distance[:, 2]
    y2 = points[:, 1] + distance[:, 3]
    if max_shape is not None:
        # The inputs are NumPy arrays, which have no ``.clamp`` method (that
        # is a torch.Tensor API); use np.clip so this branch works.
        x1 = np.clip(x1, 0, max_shape[1])
        y1 = np.clip(y1, 0, max_shape[0])
        x2 = np.clip(x2, 0, max_shape[1])
        y2 = np.clip(y2, 0, max_shape[0])
    return np.stack([x1, y1, x2, y2], axis=-1)


def distance2kps(points, distance, max_shape=None):
    """Decode offset predictions into keypoint coordinates.

    Args:
        points (ndarray): Shape (n, 2), anchor centres as [x, y].
        distance (ndarray): Shape (n, 2 * k), alternating x and y offsets of
            each of the k keypoints relative to the anchor centre.
        max_shape (tuple): Optional (height, width) of the image. When given,
            x coordinates are clipped to [0, width] and y to [0, height].

    Returns:
        ndarray: Shape (n, 2 * k), laid out as [x0, y0, x1, y1, ...].
    """
    preds = []
    for i in range(0, distance.shape[1], 2):
        # ``i`` is always even, so the original ``i % 2`` / ``i % 2 + 1``
        # column indices were simply 0 (x) and 1 (y).
        px = points[:, 0] + distance[:, i]
        py = points[:, 1] + distance[:, i + 1]
        if max_shape is not None:
            # np.clip instead of the torch-only ``.clamp`` (see distance2bbox).
            px = np.clip(px, 0, max_shape[1])
            py = np.clip(py, 0, max_shape[0])
        preds.append(px)
        preds.append(py)
    return np.stack(preds, axis=-1)
self.center_cache[key] + else: + #solution-3: + anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) + anchor_centers = (anchor_centers * stride).reshape( (-1, 2) ) + if self._num_anchors>1: + anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) ) + if len(self.center_cache)<100: + self.center_cache[key] = anchor_centers + + pos_inds = np.where(scores>=threshold)[0] + bboxes = distance2bbox(anchor_centers, bbox_preds) + pos_scores = scores[pos_inds] + pos_bboxes = bboxes[pos_inds] + scores_list.append(pos_scores) + bboxes_list.append(pos_bboxes) + if self.use_kps: + kpss = distance2kps(anchor_centers, kps_preds) + kpss = kpss.reshape( (kpss.shape[0], -1, 2) ) + pos_kpss = kpss[pos_inds] + kpss_list.append(pos_kpss) + return scores_list, bboxes_list, kpss_list + + + def detect(self, img, input_size=None, max_num=0, metric='default', det_thresh=None): + input_size = self.input_size if input_size is None else input_size + det_thresh = self.det_thresh if det_thresh is None else det_thresh + + im_ratio = float(img.shape[0]) / img.shape[1] + model_ratio = float(input_size[1]) / input_size[0] + if im_ratio>model_ratio: + new_height = input_size[1] + new_width = int(new_height / im_ratio) + else: + new_width = input_size[0] + new_height = int(new_width * im_ratio) + det_scale = float(new_height) / img.shape[0] + resized_img = cv2.resize(img, (new_width, new_height)) + det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 ) + det_img[:new_height, :new_width, :] = resized_img + + scores_list, bboxes_list, kpss_list = self.forward(det_img, det_thresh) + + scores = np.vstack(scores_list) + scores_ravel = scores.ravel() + order = scores_ravel.argsort()[::-1] + bboxes = np.vstack(bboxes_list) / det_scale + if self.use_kps: + kpss = np.vstack(kpss_list) / det_scale + pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) + pre_det = pre_det[order, :] + keep = self.nms(pre_det) + det = 
def nms(self, dets):
    """Greedy non-maximum suppression.

    Args:
        dets: Array whose columns 0-3 are [x1, y1, x2, y2] and whose column 4
            is the score.

    Returns:
        List of row indices into ``dets`` that survive, highest score first.
        A box is dropped when its overlap ratio with an already kept box
        exceeds ``self.nms_thresh``. Widths and heights use the ``+ 1``
        inclusive-pixel convention.
    """
    iou_limit = self.nms_thresh
    left, top, right, bottom, conf = (dets[:, k] for k in range(5))
    box_area = (right - left + 1) * (bottom - top + 1)
    remaining = conf.argsort()[::-1]

    keep = []
    while remaining.size > 0:
        best, rest = remaining[0], remaining[1:]
        keep.append(best)

        # Intersection of the best box with every remaining candidate.
        inter_w = np.maximum(
            0.0,
            np.minimum(right[best], right[rest]) - np.maximum(left[best], left[rest]) + 1,
        )
        inter_h = np.maximum(
            0.0,
            np.minimum(bottom[best], bottom[rest]) - np.maximum(top[best], top[rest]) + 1,
        )
        inter = inter_w * inter_h
        iou = inter / (box_area[best] + box_area[rest] - inter)

        # Only candidates that do not overlap too much stay in the pool.
        remaining = rest[iou <= iou_limit]

    return keep
class Decoder:
    """Run the "SPADEDecoder" model on whichever backend ``load_model`` returns.

    ``load_model`` yields the model object and a backend tag; the tags handled
    here are "onnx", "tensorrt" and "pytorch".
    """

    def __init__(self, model_path, device="cuda"):
        backend_kwargs = {
            "module_name": "SPADEDecoder",
        }
        self.model, self.model_type = load_model(model_path, device=device, **backend_kwargs)
        self.device = device

    def _predict(self, feature):
        """Return the raw backend output for ``feature`` as a NumPy array."""
        backend = self.model_type
        if backend == "onnx":
            return self.model.run(None, {"feature": feature})[0]
        if backend == "tensorrt":
            self.model.setup({"feature": feature})
            self.model.infer()
            # Copy out of the engine's buffer before it can be overwritten.
            return self.model.buffer["output"][0].copy()
        if backend == 'pytorch':
            # device[:4] maps strings such as "cuda:0" to the "cuda" device type.
            with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True):
                tensor = torch.from_numpy(feature).to(self.device)
                return self.model(tensor).float().cpu().numpy()
        raise ValueError(f"Unsupported model type: {self.model_type}")

    def __call__(self, feature):
        """Decode ``feature`` into an image array.

        The first element of the backend output is moved from [c, h, w] to
        [h, w, c], clipped to [0, 1] and scaled by 255 (still floating point).
        """
        raw = self._predict(feature)
        hwc = np.transpose(raw[0], [1, 2, 0])  # [h, w, c]
        return hwc.clip(0, 1) * 255
def make_beta(n_timestep, cosine_s=8e-3):
    """Build a cosine noise schedule.

    Args:
        n_timestep: Number of diffusion steps.
        cosine_s: Small offset added to the normalised timestep before the
            cosine is taken.

    Returns:
        float64 NumPy array of length ``n_timestep`` holding the per-step
        betas, clipped to [0, 0.999].
    """
    timesteps = (
        torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
    )
    alphas = timesteps / (1 + cosine_s) * np.pi / 2
    alphas = torch.cos(alphas).pow(2)
    alphas = alphas / alphas[0]
    betas = 1 - alphas[1:] / alphas[:-1]
    # Clamp with the tensor's own API. The original passed the tensor through
    # np.clip and then called .numpy() on the result, which only works while
    # NumPy wraps the output back into a torch.Tensor.
    betas = betas.clamp(min=0, max=0.999)
    return betas.numpy()
def _one_step(self, x, cond_frame, cond, time_cond):
    """Run one denoiser evaluation on the active backend.

    The four arguments are forwarded unchanged to the model (as named inputs
    "x", "cond_frame", "cond", "time_cond" for the onnx/tensorrt backends).

    Returns:
        Tuple ``(pred_noise, x_start)`` as produced by the model.

    Raises:
        ValueError: If ``self.model_type`` is not onnx, tensorrt or pytorch.
    """
    backend = self.model_type

    if backend == "pytorch":
        with torch.no_grad():
            pred_noise, x_start = self.model(x, cond_frame, cond, time_cond)
        return pred_noise, x_start

    if backend not in ("onnx", "tensorrt"):
        raise ValueError(f"Unsupported model type: {self.model_type}")

    feeds = {"x": x, "cond_frame": cond_frame, "cond": cond, "time_cond": time_cond}
    if backend == "onnx":
        outputs = self.model.run(None, feeds)
        return outputs[0], outputs[1]

    # tensorrt: bind the inputs, execute, then read the named output buffers.
    self.model.setup(feeds)
    self.model.infer()
    return self.model.buffer["pred_noise"][0], self.model.buffer["x_start"][0]
def __call__(self, kp_cond, aud_cond, sampling_timesteps):
    """Sample a motion sequence conditioned on keypoints and audio features.

    For the "pytorch" backend the NumPy conditions are moved to
    ``self.device``, the model's ``ddim_sample`` is run, and the result is
    returned as a NumPy array. Every other backend goes through the NumPy
    sampling loop in ``_call_np``.
    """
    if self.model_type != "pytorch":
        return self._call_np(kp_cond, aud_cond, sampling_timesteps)

    kp_tensor = torch.from_numpy(kp_cond).to(self.device)
    aud_tensor = torch.from_numpy(aud_cond).to(self.device)
    sampled = self.model.ddim_sample(kp_tensor, aud_tensor, sampling_timesteps)
    return sampled.cpu().numpy()
def maybe_clip(self, x):
    """Clamp ``x`` to [-1, 1] when ``self.clip_denoised`` is set.

    With clipping disabled the input tensor is returned as is (same object).
    """
    if not self.clip_denoised:
        return x
    return torch.clamp(x, min=-1., max=1.)
@torch.no_grad()
def ddim_sample(self, kp_cond, aud_cond, sampling_timesteps):
    """Run the sampling loop and return the predicted sequence.

    ``self.setup(sampling_timesteps)`` is called first; the loop then reads
    the per-step tables it maintains (``time_pairs``, ``time_cond_list``,
    ``alpha_next_sqrt_list``, ``c_list``, ``sigma_list``, ``noise_list``).

    Args:
        kp_cond: Conditioning passed to the model as ``cond_frame``.
        aud_cond: Conditioning passed to the model as ``cond``.
        sampling_timesteps: Number of sampling steps.

    Returns:
        Tensor of shape (1, seq_frames, motion_feat_dim).
    """
    self.setup(sampling_timesteps)

    x = torch.randn((1, self.seq_frames, self.motion_feat_dim), device=self.device)

    step = 0
    for _, t_next in self.time_pairs:
        t_cond = self.time_cond_list[step]
        pred_noise, x_start = self.model_predictions(x, kp_cond, aud_cond, t_cond)
        if t_next < 0:
            # Final pair: the model's estimate of the clean sample is the result.
            x = x_start
        else:
            x = (
                x_start * self.alpha_next_sqrt_list[step]
                + self.c_list[step] * pred_noise
                + self.sigma_list[step] * self.noise_list[step]
            )
            # The coefficient tables only have entries for non-final steps.
            step += 1
    return x  # pred_kp_seq
def forward(self, source_image):
    """Map a source image to a 3-D appearance feature volume.

    The image goes through ``first``, every block in ``down_blocks`` and
    ``second``; the resulting (B, C, H, W) map is then viewed as
    (B, reshape_channel, reshape_depth, H, W) and refined by ``resblocks_3d``.
    Shape notes in the comments are those given in the original source
    (presumably for the default configuration).
    """
    feat = self.first(source_image)  # Bx3x256x256 -> Bx64x256x256

    for down_block in self.down_blocks:
        feat = down_block(feat)
    feat = self.second(feat)
    batch, _, height, width = feat.shape  # ->Bx512x64x64

    # Split the channel axis into (reshape_channel, reshape_depth).
    volume = feat.view(
        batch, self.reshape_channel, self.reshape_depth, height, width
    )  # ->Bx32x16x64x64
    return self.resblocks_3d(volume)  # ->Bx32x16x64x64
class Block(nn.Module):
    """ ConvNeXtV2 Block.

    Depthwise 7x7 convolution followed, in channels-last layout, by
    LayerNorm, a 4x pointwise expansion, GELU, GRN and a pointwise
    projection, with a residual connection around the whole branch.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.):
        super().__init__()
        hidden_dim = 4 * dim
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, hidden_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(hidden_dim)
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        # Stochastic depth is only instantiated when it has an effect.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        residual = x
        branch = self.dwconv(x)
        branch = branch.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        for layer in (self.norm, self.pwconv1, self.act, self.grn, self.pwconv2):
            branch = layer(branch)
        branch = branch.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return residual + self.drop_path(branch)
def __init__(
    self,
    in_chans=3,
    depths=None,
    dims=None,
    drop_path_rate=0.,
    **kwargs
):
    """Build the ConvNeXtV2 backbone and its prediction heads.

    Args:
        in_chans (int): Number of input image channels. Default: 3
        depths (list[int] | None): Number of blocks in each of the four
            stages. ``None`` selects ``[3, 3, 9, 3]``.
        dims (list[int] | None): Feature dimension of each of the four
            stages. ``None`` selects ``[96, 192, 384, 768]``.
        drop_path_rate (float): Maximum stochastic depth rate; per-block
            rates grow linearly from 0 to this value. Default: 0.
        **kwargs: ``num_bins`` (default 66) sets the output size of the
            pitch/yaw/roll heads; ``num_kp`` (default 24) sets the number of
            implicit keypoints.
    """
    super().__init__()
    # Resolve defaults per call. List defaults in the signature are shared by
    # every instance, and ``self.depths`` would alias that shared list.
    if depths is None:
        depths = [3, 3, 9, 3]
    if dims is None:
        dims = [96, 192, 384, 768]

    self.depths = depths
    self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
    stem = nn.Sequential(
        nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
        LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
    )
    self.downsample_layers.append(stem)
    for i in range(3):
        downsample_layer = nn.Sequential(
            LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
            nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
        )
        self.downsample_layers.append(downsample_layer)

    self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
    dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
    cur = 0  # index of the first drop-path rate belonging to the current stage
    for i in range(4):
        stage = nn.Sequential(
            *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
        )
        self.stages.append(stage)
        cur += depths[i]

    self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer

    # NOTE: the output semantic items
    num_bins = kwargs.get('num_bins', 66)
    num_kp = kwargs.get('num_kp', 24)  # the number of implicit keypoints
    self.fc_kp = nn.Linear(dims[-1], 3 * num_kp)  # implicit keypoints

    self.fc_scale = nn.Linear(dims[-1], 1)  # scale
    self.fc_pitch = nn.Linear(dims[-1], num_bins)  # pitch bins
    self.fc_yaw = nn.Linear(dims[-1], num_bins)  # yaw bins
    self.fc_roll = nn.Linear(dims[-1], num_bins)  # roll bins
    self.fc_t = nn.Linear(dims[-1], 3)  # translation
    self.fc_exp = nn.Linear(dims[-1], 3 * num_kp)  # expression / delta
def forward(self, x):
    """Compute all prediction heads from the pooled backbone features.

    Returns:
        Tuple ``(pitch, yaw, roll, t, exp, scale, kp)``: each entry is the
        output of the matching ``fc_*`` head applied to the result of
        ``forward_features(x)``.
    """
    feats = self.forward_features(x)

    # implicit keypoints
    kp = self.fc_kp(feats)

    # pose and expression deformation
    pose_heads = (
        self.fc_pitch,
        self.fc_yaw,
        self.fc_roll,
        self.fc_t,
        self.fc_exp,
        self.fc_scale,
    )
    pitch, yaw, roll, t, exp, scale = (head(feats) for head in pose_heads)

    return pitch, yaw, roll, t, exp, scale, kp
def create_deformed_feature(self, feature, sparse_motions):
    """Warp the feature volume once per sparse motion field.

    Args:
        feature: Tensor of shape (bs, c, d, h, w).
        sparse_motions: Sampling grids that can be viewed as
            (bs * (num_kp + 1), d, h, w, 3).

    Returns:
        Tensor of shape (bs, num_kp + 1, c, d, h, w): ``feature`` sampled
        with ``F.grid_sample`` under each of the num_kp + 1 grids.
    """
    bs, _, d, h, w = feature.shape
    n_motions = self.num_kp + 1
    flat_batch = bs * n_motions

    # One copy of the volume per motion field, folded into the batch axis.
    tiled = feature.unsqueeze(1).repeat(1, n_motions, 1, 1, 1, 1)  # (bs, num_kp+1, c, d, h, w)
    tiled = tiled.view(flat_batch, -1, d, h, w)  # (bs*(num_kp+1), c, d, h, w)
    grids = sparse_motions.view((flat_batch, d, h, w, -1))  # (bs*(num_kp+1), d, h, w, 3)

    warped = F.grid_sample(tiled, grids, align_corners=False)
    return warped.view((bs, n_motions, -1, d, h, w))  # (bs, num_kp+1, c, d, h, w)
def forward(self, feature, kp_driving, kp_source):
    """Predict a dense deformation field from sparse keypoint motion.

    Args:
        feature: Tensor of shape (bs, c, d, h, w).
        kp_driving: Driving keypoints, forwarded to the helper methods.
        kp_source: Source keypoints, forwarded to the helper methods.

    Returns:
        Dict with:
          * 'mask': softmax over the 1 + num_kp motion candidates,
            shape (bs, 1+num_kp, d, h, w);
          * 'deformation': mask-weighted sum of the sparse motions,
            shape (bs, d, h, w, 3);
          * 'occlusion_map': sigmoid map, present only when
            ``flag_estimate_occlusion_map`` is set.
    """
    bs, _, d, h, w = feature.shape  # (bs, 32, 16, 64, 64)

    # compress -> norm -> relu
    compressed = F.relu(self.norm(self.compress(feature)))  # (bs, 4, 16, 64, 64)

    # 1. deform 3d feature
    sparse_motion = self.create_sparse_motions(compressed, kp_driving, kp_source)  # (bs, 1+num_kp, d, h, w, 3)
    deformed_feature = self.create_deformed_feature(compressed, sparse_motion)  # (bs, 1+num_kp, c, d, h, w)

    # 2. heatmaps, one per motion candidate
    heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source)  # (bs, 1+num_kp, 1, d, h, w)

    # Stack heatmap and deformed features per candidate, then fold the
    # candidate axis into channels for the hourglass.
    hourglass_in = torch.cat([heatmap, deformed_feature], dim=2)
    hourglass_in = hourglass_in.view(bs, -1, d, h, w)
    prediction = self.hourglass(hourglass_in)

    mask = F.softmax(self.mask(prediction), dim=1)  # (bs, 1+num_kp, d, h, w)

    # Blend the candidate motions with the mask weights.
    weights = mask.unsqueeze(2)  # (bs, num_kp+1, 1, d, h, w)
    motions = sparse_motion.permute(0, 1, 5, 2, 3, 4)  # (bs, num_kp+1, 3, d, h, w)
    deformation = (motions * weights).sum(dim=1)  # (bs, 3, d, h, w)
    deformation = deformation.permute(0, 2, 3, 4, 1)  # (bs, d, h, w, 3)

    out_dict = {'mask': mask, 'deformation': deformation}

    if self.flag_estimate_occlusion_map:
        p_bs, _, _, p_h, p_w = prediction.shape
        # Fold depth into channels so a 2-D conv can predict the map.
        flattened = prediction.view(p_bs, -1, p_h, p_w)
        out_dict['occlusion_map'] = torch.sigmoid(self.occlusion(flattened))  # Bx1x64x64

    return out_dict
class TransformerEncoderLayer(nn.Module):
    """Transformer encoder layer with optional rotary position embedding.

    Self-attention followed by a two-layer feed-forward block, each wrapped
    in a residual connection. ``norm_first`` selects pre-norm (default) or
    post-norm. When ``rotary`` is given, queries and keys (not values) are
    rotated before attention. ``device`` and ``dtype`` are accepted but not
    used by this implementation.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = True,
        device=None,
        dtype=None,
        rotary=None,
    ) -> None:
        super().__init__()
        # Attribute names and creation order are kept so checkpoints and
        # parameter initialisation stay compatible.
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )

        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = activation

        self.rotary = rotary
        self.use_rotary = rotary is not None

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply self-attention and feed-forward sub-layers to ``src``."""
        if not self.norm_first:
            # Post-norm: normalise after each residual sum.
            attended = self.norm1(
                src + self._sa_block(src, src_mask, src_key_padding_mask)
            )
            return self.norm2(attended + self._ff_block(attended))

        # Pre-norm: normalise the input of each sub-layer.
        hidden = src + self._sa_block(self.norm1(src), src_mask, src_key_padding_mask)
        return hidden + self._ff_block(self.norm2(hidden))

    def _sa_block(
        self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
    ) -> Tensor:
        """Self-attention sub-layer (without the residual connection)."""
        queries_keys = x
        if self.use_rotary:
            queries_keys = self.rotary.rotate_queries_or_keys(x)
        attn_out = self.self_attn(
            queries_keys,
            queries_keys,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(attn_out)

    def _ff_block(self, x: Tensor) -> Tensor:
        """Feed-forward sub-layer (without the residual connection)."""
        expanded = self.activation(self.linear1(x))
        projected = self.linear2(self.dropout(expanded))
        return self.dropout2(projected)
class DecoderLayerStack(nn.Module):
    """Run a sequence of decoder layers, feeding each output to the next."""

    def __init__(self, stack):
        super().__init__()
        # Iterable of layers, each callable as layer(x, cond, t).
        self.stack = stack

    def forward(self, x, cond, t):
        """Pass ``x`` through every layer with the same ``cond`` and ``t``."""
        hidden = x
        for decoder_layer in self.stack:
            hidden = decoder_layer(hidden, cond, t)
        return hidden
+ self.multi_cond_frame = multi_cond_frame + + output_feats = nfeats + + # positional embeddings + self.rotary = None + self.abs_pos_encoding = nn.Identity() + # if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity) + if use_rotary: + self.rotary = RotaryEmbedding(dim=latent_dim) + else: + self.abs_pos_encoding = PositionalEncoding( + latent_dim, dropout, batch_first=True + ) + + # time embedding processing + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(latent_dim), # learned? + nn.Linear(latent_dim, latent_dim * 4), + nn.Mish(), + ) + + self.to_time_cond = nn.Sequential(nn.Linear(latent_dim * 4, latent_dim),) + + self.to_time_tokens = nn.Sequential( + nn.Linear(latent_dim * 4, latent_dim * 2), # 2 time tokens + Rearrange("b (r d) -> b r d", r=2), + ) + + # null embeddings for guidance dropout + self.null_cond_embed = nn.Parameter(torch.randn(1, seq_len, latent_dim)) + self.null_cond_hidden = nn.Parameter(torch.randn(1, latent_dim)) + + self.norm_cond = nn.LayerNorm(latent_dim) + + # input projection + if self.multi_cond_frame: + self.input_projection = nn.Linear(nfeats * 2 + 1, latent_dim) + else: + self.input_projection = nn.Linear(nfeats * 2, latent_dim) + self.cond_encoder = nn.Sequential() + for _ in range(2): + self.cond_encoder.append( + TransformerEncoderLayer( + d_model=latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation, + batch_first=True, + rotary=self.rotary, + ) + ) + # conditional projection + self.cond_projection = nn.Linear(cond_feature_dim, latent_dim) + self.non_attn_cond_projection = nn.Sequential( + nn.LayerNorm(latent_dim), + nn.Linear(latent_dim, latent_dim), + nn.SiLU(), + nn.Linear(latent_dim, latent_dim), + ) + # decoder + decoderstack = nn.ModuleList([]) + for _ in range(num_layers): + decoderstack.append( + FiLMTransformerDecoderLayer( + latent_dim, + num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation, + 
def guided_forward(self, x, cond_frame, cond_embed, times, guidance_weight):
    """Blend an unconditional and a conditional pass (classifier-free guidance).

    ``forward`` runs twice on the same inputs: first with the condition
    fully dropped (``cond_drop_prob=1``), then fully kept
    (``cond_drop_prob=0``). The result is
    ``unconditional + (conditional - unconditional) * guidance_weight``.
    """
    shared_inputs = (x, cond_frame, cond_embed, times)
    # The unconditional pass runs first, as in the original call order.
    unconditional = self.forward(*shared_inputs, cond_drop_prob=1)
    conditional = self.forward(*shared_inputs, cond_drop_prob=0)

    guidance_delta = conditional - unconditional
    return unconditional + guidance_delta * guidance_weight
def broadcat(tensors, dim=-1):
    """Concatenate tensors along ``dim`` after broadcasting every other axis.

    All tensors must have the same number of dimensions. On each axis other
    than ``dim`` they may take at most two distinct sizes; every tensor is
    expanded to the larger one before ``torch.cat``.
    """
    count = len(tensors)

    ranks = {len(t.shape) for t in tensors}
    assert len(ranks) == 1, "tensors must all have the same number of dimensions"
    rank = next(iter(ranks))

    if dim < 0:
        dim = dim + rank

    # sizes_by_axis[a] holds the size of axis `a` for every tensor.
    sizes_by_axis = list(zip(*(list(t.shape) for t in tensors)))

    broadcast_axes = [
        (axis, sizes) for axis, sizes in enumerate(sizes_by_axis) if axis != dim
    ]
    assert all(
        len(set(sizes)) <= 2 for _, sizes in broadcast_axes
    ), "invalid dimensions for broadcastable concatentation"

    # Each broadcast axis takes the largest size seen; the concatenation
    # axis keeps every tensor's own size.
    target_axes = [(axis, (max(sizes),) * count) for axis, sizes in broadcast_axes]
    target_axes.insert(dim, (dim, sizes_by_axis[dim]))

    target_shapes = list(zip(*(sizes for _, sizes in target_axes)))
    expanded = [t.expand(*shape) for t, shape in zip(tensors, target_shapes)]
    return torch.cat(expanded, dim=dim)
(n r)", r=2) + return apply_rotary_emb(rotations, t, start_index=start_index) + + +# classes + + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + learned_freq=False, + ): + super().__init__() + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + self.cache = dict() + + if learned_freq: + self.freqs = nn.Parameter(freqs) + else: + self.register_buffer("freqs", freqs) + + def rotate_queries_or_keys(self, t, seq_dim=-2): + device = t.device + seq_len = t.shape[seq_dim] + freqs = self.forward( + lambda: torch.arange(seq_len, device=device), cache_key=seq_len + ) + return apply_rotary_emb(freqs, t) + + def forward(self, t, cache_key=None): + if exists(cache_key) and cache_key in self.cache: + return self.cache[cache_key] + + if isfunction(t): + t = t() + + freqs = self.freqs + + freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs) + freqs = repeat(freqs, "... n -> ... 
# absolute positional embedding used for vanilla transformer sequential data
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then dropout.

    Args:
        d_model: feature size of the input; even and odd values both work.
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions precomputed. Inputs longer than this
            cannot be encoded (the addition fails to broadcast).
        batch_first: if True inputs are (batch, seq, d_model), otherwise
            (seq, batch, d_model).
    """

    def __init__(self, d_model, dropout=0.1, max_len=500, batch_first=False):
        super().__init__()
        self.batch_first = batch_first

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns;
        # trimming here avoids the shape mismatch the unsliced assignment raised.
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model), sequence-first layout

        # A buffer, not a parameter: it follows .to()/.cuda() and is saved in
        # the state dict, but is not trained.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``dropout(x + positional_encoding)`` with x's shape."""
        if self.batch_first:
            # (max_len, 1, d) -> (1, max_len, d), cut to the sequence length.
            x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
        else:
            x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)
def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    """Build a diffusion beta (noise variance) schedule.

    Args:
        schedule: one of ``"linear"``, ``"cosine"``, ``"sqrt_linear"``,
            ``"sqrt"``.
        n_timestep: number of diffusion steps; the result has this length.
        linear_start: first value used by the linspace-based schedules.
        linear_end: last value used by the linspace-based schedules.
        cosine_s: offset used by the cosine schedule.

    Returns:
        1-D ``numpy.ndarray`` of float64 betas.

    Raises:
        ValueError: if ``schedule`` is not a known name.
    """
    if schedule == "linear":
        # Linear in sqrt(beta), then squared.
        betas = (
            torch.linspace(
                linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch so `betas` is guaranteed to stay a tensor for the
        # final .numpy() call; np.clip on a tensor depends on NumPy's
        # wrap-back fallback.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
class SPADEDecoder(nn.Module):
    """SPADE decoder: turns a warped feature map into an image in [0, 1].

    The input feature map conditions every SPADE block and is also the
    starting activation. It passes through six constant-width SPADE ResNet
    blocks and two upsample-then-SPADE stages, then is projected to RGB.

    Args:
        upscale: if ``None`` or <= 1 the image head is a plain 3-channel
            conv; otherwise a conv + ``PixelShuffle(2)`` head is used.
            NOTE(review): the shuffle factor is fixed at 2 whatever value
            ``upscale`` has; left unchanged to keep checkpoint shapes.
        max_features: upper bound on the channel width.
        block_expansion: base channel width.
        out_channels: channels produced by the last SPADE block.
        num_down_blocks: number of encoder down blocks; sets the input width
            ``min(max_features, block_expansion * 2 ** num_down_blocks)``.

    Raises:
        ValueError: if ``num_down_blocks`` < 1 (the earlier loop-based code
            failed with an unbound-variable error in that case).
    """

    _NUM_MIDDLE_BLOCKS = 6

    def __init__(
        self,
        upscale=2,
        max_features=512,
        block_expansion=64,
        out_channels=64,
        num_down_blocks=2,
    ):
        # Initialise nn.Module before assigning any attribute.
        super().__init__()
        if num_down_blocks < 1:
            raise ValueError(
                f"num_down_blocks must be >= 1, got {num_down_blocks}"
            )
        self.upscale = upscale

        # Same value as the last iteration of the former
        # `for i in range(num_down_blocks)` loop.
        input_channels = min(max_features, block_expansion * (2 ** num_down_blocks))
        norm_G = "spadespectralinstance"
        label_num_channels = input_channels  # 256

        self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1)
        # Attribute names G_middle_0 .. G_middle_5 are kept for checkpoint
        # compatibility.
        for idx in range(self._NUM_MIDDLE_BLOCKS):
            setattr(
                self,
                f"G_middle_{idx}",
                SPADEResnetBlock(
                    2 * input_channels, 2 * input_channels, norm_G, label_num_channels
                ),
            )
        self.up_0 = SPADEResnetBlock(
            2 * input_channels, input_channels, norm_G, label_num_channels
        )
        self.up_1 = SPADEResnetBlock(
            input_channels, out_channels, norm_G, label_num_channels
        )
        self.up = nn.Upsample(scale_factor=2)

        if self.upscale is None or self.upscale <= 1:
            self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1)
        else:
            self.conv_img = nn.Sequential(
                nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
                nn.PixelShuffle(upscale_factor=2),
            )

    def forward(self, feature):
        """Decode ``feature`` into an image tensor with values in [0, 1].

        Shape comments assume the default configuration and a 64x64 input.
        """
        seg = feature  # Bx256x64x64
        x = self.fc(feature)  # Bx512x64x64
        for idx in range(self._NUM_MIDDLE_BLOCKS):
            x = getattr(self, f"G_middle_{idx}")(x, seg)

        x = self.up(x)  # Bx512x64x64 -> Bx512x128x128
        x = self.up_0(x, seg)  # Bx512x128x128 -> Bx256x128x128
        x = self.up(x)  # Bx256x128x128 -> Bx256x256x256
        x = self.up_1(x, seg)  # Bx256x256x256 -> Bx64x256x256

        x = self.conv_img(F.leaky_relu(x, 2e-1))  # Bx64x256x256 -> Bx3xHxW
        x = torch.sigmoid(x)  # Bx3xHxW

        return x

    def load_model(self, ckpt_path):
        """Load weights from ``ckpt_path`` onto CPU, switch to eval, return self."""
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints (or pass weights_only=True where the installed PyTorch
        # supports it).
        self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
        self.eval()
        return self
def remove_ddp_dumplicate_key(state_dict):
    """Return a copy of ``state_dict`` with DDP ``module`` scopes removed.

    ``DistributedDataParallel`` stores the wrapped network under a
    ``module`` attribute, which adds ``module.`` components to every key.
    Only whole dotted components equal to ``module`` are dropped (never the
    final component), so names that merely contain the text, such as
    ``submodule.weight``, are left intact. The earlier substring replace
    turned that key into ``subweight``.
    """
    from collections import OrderedDict

    state_dict_new = OrderedDict()
    for key, value in state_dict.items():
        *scopes, leaf = key.split('.')
        kept = [scope for scope in scopes if scope != 'module']
        state_dict_new['.'.join(kept + [leaf])] = value
    return state_dict_new


class StitchingNetwork(nn.Module):
    """MLP that predicts keypoint corrections from source and driving keypoints.

    The input is the two flattened keypoint sets concatenated. The first
    ``3 * num_kp`` outputs are per-keypoint offsets and the next two are a
    shared (x, y) shift. Defaults match 21 keypoints: 126 = 2 * 21 * 3
    inputs and 65 = 21 * 3 + 2 outputs.

    Args:
        input_size: width of the concatenated input.
        hidden_sizes: widths of the hidden layers; may be empty, giving a
            single linear layer.
        output_size: width of the output; must be at least 3 * num_kp + 2.
    """

    def __init__(self, input_size=126, hidden_sizes=(128, 128, 64), output_size=65):
        super(StitchingNetwork, self).__init__()
        widths = [input_size, *hidden_sizes]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(widths[-1], output_size))
        self.mlp = nn.Sequential(*layers)

    def _forward(self, x):
        return self.mlp(x)

    def load_model(self, ckpt_path):
        """Load the ``'retarget_shoulder'`` weights, switch to eval, return self."""
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints.
        checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        self.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_shoulder']))
        self.eval()
        return self

    def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        """Return corrected driving keypoints; the inputs are not modified.

        kp_source: Bxnum_kpx3
        kp_driving: Bxnum_kpx3
        """
        bs, num_kp = kp_source.shape[:2]
        # Work on a copy so the caller's tensor is untouched by the in-place adds.
        kp_driving_new = kp_driving.clone()
        delta = self._forward(torch.cat([kp_source.view(bs, -1), kp_driving_new.view(bs, -1)], dim=1))
        delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3)  # Bxnum_kpx3
        delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2)  # Bx1x2
        kp_driving_new += delta_exp
        # The shared (x, y) shift is broadcast over all keypoints.
        kp_driving_new[..., :2] += delta_tx_ty
        return kp_driving_new

    def forward(self, kp_source, kp_driving):
        out = self.stitching(kp_source, kp_driving)
        return out
def make_coordinate_grid(spatial_size, ref, **kwargs):
    """Build a (d, h, w, 3) grid of normalised (x, y, z) coordinates.

    Each axis runs linearly from -1 at index 0 to 1 at the last index. The
    last dimension is ordered (x, y, z), where x follows the width axis, y
    the height axis and z the depth axis. The grid takes its dtype and
    device from ``ref``. Extra keyword arguments are accepted and ignored.
    """
    d, h, w = spatial_size

    def _axis(length):
        # Integer positions mapped linearly onto [-1, 1].
        steps = torch.arange(length).type(ref.dtype).to(ref.device)
        return 2 * (steps / (length - 1)) - 1

    # NOTE: must be right-down-in
    xs = _axis(w)  # the x axis faces to the right
    ys = _axis(h)  # the y axis faces to the bottom
    zs = _axis(d)  # the z axis faces to the inner

    grid_x = xs.view(1, 1, -1).expand(d, h, w)
    grid_y = ys.view(1, -1, 1).expand(d, h, w)
    grid_z = zs.view(-1, 1, 1).expand(d, h, w)

    # Stacking on a new last axis gives a fresh contiguous (d, h, w, 3) tensor.
    return torch.stack((grid_x, grid_y, grid_z), dim=3)
class ResBlock3d(nn.Module):
    """
    Pre-activation 3D residual block: (BN -> ReLU -> conv) twice, plus the input.

    Channel count is unchanged; spatial size is unchanged when ``padding``
    matches ``kernel_size``.
    """

    def __init__(self, in_features, kernel_size, padding):
        super(ResBlock3d, self).__init__()
        conv_kwargs = dict(
            in_channels=in_features,
            out_channels=in_features,
            kernel_size=kernel_size,
            padding=padding,
        )
        # Creation order (conv1, conv2, norm1, norm2) is kept as in the original.
        self.conv1 = nn.Conv3d(**conv_kwargs)
        self.conv2 = nn.Conv3d(**conv_kwargs)
        self.norm1 = nn.BatchNorm3d(in_features, affine=True)
        self.norm2 = nn.BatchNorm3d(in_features, affine=True)

    def forward(self, x):
        residual = self.conv1(F.relu(self.norm1(x)))
        residual = self.conv2(F.relu(self.norm2(residual)))
        return residual + x
class DownBlock3d(nn.Module):
    """
    Downsampling block for use in encoder.

    conv -> batch norm -> ReLU -> average pool. The pool kernel is
    (1, 2, 2), so height and width are halved and depth is kept.
    Downsampling is done by pooling; the convolution has stride 1.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(DownBlock3d, self).__init__()
        self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.norm = nn.BatchNorm3d(out_features, affine=True)
        self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = F.relu(out)
        out = self.pool(out)
        return out
class Encoder(nn.Module):
    """
    Hourglass Encoder

    A chain of ``num_blocks`` DownBlock3d modules. Block ``i`` outputs
    ``min(max_features, block_expansion * 2 ** (i + 1))`` channels; the
    first block takes ``in_features`` channels.
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Encoder, self).__init__()

        # widths[i] -> widths[i + 1] are the in/out channels of block i.
        widths = [in_features]
        for level in range(num_blocks):
            widths.append(min(max_features, block_expansion * (2 ** (level + 1))))

        self.down_blocks = nn.ModuleList(
            [
                DownBlock3d(fan_in, fan_out, kernel_size=3, padding=1)
                for fan_in, fan_out in zip(widths, widths[1:])
            ]
        )

    def forward(self, x):
        """Return ``[x, out_1, ..., out_n]``: the input, then each block's output."""
        features = [x]
        current = x
        for block in self.down_blocks:
            current = block(current)
            features.append(current)
        return features
class SPADE(nn.Module):
    """Normalisation layer whose per-pixel scale and shift are predicted from ``segmap``.

    ``x`` is instance-normalised without learned affine parameters, then
    modulated as ``normalized * (1 + gamma) + beta`` where ``gamma`` and
    ``beta`` come from small convolutions applied to the resized ``segmap``.
    """

    def __init__(self, norm_nc, label_nc):
        super().__init__()

        hidden_channels = 128

        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Conv2d(hidden_channels, norm_nc, kernel_size=3, padding=1)
        self.mlp_beta = nn.Conv2d(hidden_channels, norm_nc, kernel_size=3, padding=1)

    def forward(self, x, segmap):
        """Return ``x`` normalised and modulated by the conditioning map ``segmap``."""
        normalized = self.param_free_norm(x)
        # Bring the conditioning map to the spatial size of x before the convs.
        resized = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        hidden = self.mlp_shared(resized)
        scale = self.mlp_gamma(hidden)
        shift = self.mlp_beta(hidden)
        return normalized * (1 + scale) + shift
def filter_state_dict(state_dict, remove_name='fc'):
    """Return a new dict with every entry whose key contains ``remove_name`` left out.

    The input mapping is not modified; values are shared, not copied.
    """
    return {
        key: value
        for key, value in state_dict.items()
        if remove_name not in key
    }
class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        """Normalise ``x`` over its channel axis according to ``self.data_format``."""
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            # Manual layer norm over dim 1 (channels), with a biased variance.
            mean = x.mean(1, keepdim=True)
            centered = x - mean
            variance = centered.pow(2).mean(1, keepdim=True)
            normed = centered / torch.sqrt(variance + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
def drop_path(x, drop_prob=0., training=False, scale_by_keep=True):
    """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Args:
        x: input tensor; dimension 0 is treated as the sample dimension.
        drop_prob: probability of zeroing a whole sample. ``None`` and ``0`` both mean "never drop".
        training: when false, ``x`` is returned unchanged.
        scale_by_keep: when true, surviving samples are divided by the keep probability.

    Returns:
        ``x`` itself when nothing is dropped, otherwise a new tensor shaped like ``x``.
    """
    # None is accepted because DropPath defaults to drop_prob=None; without
    # this guard `1 - drop_prob` below raises TypeError in training mode.
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
class DropPath(nn.Module):
    """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    ``drop_prob`` may be ``None`` (the default) or ``0``; in both cases the
    module is an identity in training as well as in eval mode.
    """

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # The default drop_prob=None used to reach `1 - None` inside drop_path
        # when the module was in training mode; short-circuit it here.
        if self.drop_prob is None or self.drop_prob == 0. or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+""" +import torch +from torch import nn +import torch.nn.functional as F +from .util import SameBlock2d +from .dense_motion import DenseMotionNetwork + + +class WarpingNetwork(nn.Module): + def __init__( + self, + num_kp=21, + block_expansion=64, + max_features=512, + num_down_blocks=2, + reshape_channel=32, + estimate_occlusion_map=True, + **kwargs + ): + super(WarpingNetwork, self).__init__() + + self.upscale = kwargs.get('upscale', 1) + self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True) + + dense_motion_params = { + "block_expansion": 32, + "max_features": 1024, + "num_blocks": 5, + "reshape_depth": 16, + "compress": 4, + } + + self.dense_motion_network = DenseMotionNetwork( + num_kp=num_kp, + feature_channel=reshape_channel, + estimate_occlusion_map=estimate_occlusion_map, + **dense_motion_params + ) + + self.third = SameBlock2d(max_features, block_expansion * (2 ** num_down_blocks), kernel_size=(3, 3), padding=(1, 1), lrelu=True) + self.fourth = nn.Conv2d(in_channels=block_expansion * (2 ** num_down_blocks), out_channels=block_expansion * (2 ** num_down_blocks), kernel_size=1, stride=1) + + self.estimate_occlusion_map = estimate_occlusion_map + + def deform_input(self, inp, deformation): + return F.grid_sample(inp, deformation, align_corners=False) + + def forward(self, feature_3d, kp_source, kp_driving): + # Feature warper, Transforming feature representation according to deformation and occlusion + dense_motion = self.dense_motion_network( + feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source + ) + if 'occlusion_map' in dense_motion: + occlusion_map = dense_motion['occlusion_map'] # Bx1x64x64 + else: + occlusion_map = None + + deformation = dense_motion['deformation'] # Bx16x64x64x3 + out = self.deform_input(feature_3d, deformation) # Bx32x16x64x64 + + bs, c, d, h, w = out.shape # Bx32x16x64x64 + out = out.view(bs, c * d, h, w) # -> Bx512x64x64 + out = self.third(out) # -> Bx256x64x64 + out = self.fourth(out) # -> 
Bx256x64x64 + + if self.flag_use_occlusion_map and (occlusion_map is not None): + out = out * occlusion_map + + # ret_dct = { + # 'occlusion_map': occlusion_map, + # 'deformation': deformation, + # 'out': out, + # } + + # return ret_dct + + return out + + def load_model(self, ckpt_path): + self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage)) + self.eval() + return self diff --git a/core/models/motion_extractor.py b/core/models/motion_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..d4e05ac286334e99912af134eaef373a7d8d1969 --- /dev/null +++ b/core/models/motion_extractor.py @@ -0,0 +1,49 @@ +import numpy as np +import torch +from ..utils.load_model import load_model + + +class MotionExtractor: + def __init__(self, model_path, device="cuda"): + kwargs = { + "module_name": "MotionExtractor", + } + self.model, self.model_type = load_model(model_path, device=device, **kwargs) + self.device = device + + self.output_names = [ + "pitch", + "yaw", + "roll", + "t", + "exp", + "scale", + "kp", + ] + + def __call__(self, image): + """ + image: np.ndarray, shape (1, 3, 256, 256), RGB, 0-1 + """ + outputs = {} + if self.model_type == "onnx": + out_list = self.model.run(None, {"image": image}) + for i, name in enumerate(self.output_names): + outputs[name] = out_list[i] + elif self.model_type == "tensorrt": + self.model.setup({"image": image}) + self.model.infer() + for name in self.output_names: + outputs[name] = self.model.buffer[name][0].copy() + elif self.model_type == "pytorch": + with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True): + pred = self.model(torch.from_numpy(image).to(self.device)) + for i, name in enumerate(self.output_names): + outputs[name] = pred[i].float().cpu().numpy() + else: + raise ValueError(f"Unsupported model type: {self.model_type}") + outputs["exp"] = outputs["exp"].reshape(1, -1) + outputs["kp"] = outputs["kp"].reshape(1, -1) + return 
class StitchNetwork:
    """Backend-agnostic wrapper around the stitching model returned by ``load_model``."""

    def __init__(self, model_path, device="cuda"):
        loaded = load_model(model_path, device=device, module_name="StitchingNetwork")
        self.model, self.model_type = loaded
        self.device = device

    def __call__(self, kp_source, kp_driving):
        """Run the model on the two keypoint arrays and return its prediction as an array.

        Raises:
            ValueError: if ``self.model_type`` is not "onnx", "tensorrt" or "pytorch".
        """
        backend = self.model_type

        if backend == "onnx":
            feeds = {"kp_source": kp_source, "kp_driving": kp_driving}
            return self.model.run(None, feeds)[0]

        if backend == "tensorrt":
            self.model.setup({"kp_source": kp_source, "kp_driving": kp_driving})
            self.model.infer()
            return self.model.buffer["out"][0].copy()

        if backend == 'pytorch':
            with torch.no_grad():
                source = torch.from_numpy(kp_source).to(self.device)
                driving = torch.from_numpy(kp_driving).to(self.device)
                result = self.model(source, driving)
                return result.cpu().numpy()

        raise ValueError(f"Unsupported model type: {self.model_type}")
feature_3d, "kp_source": kp_source, "kp_driving": kp_driving})[0] + elif self.model_type == "tensorrt": + self.model.setup({"feature_3d": feature_3d, "kp_source": kp_source, "kp_driving": kp_driving}) + self.model.infer() + pred = self.model.buffer["out"][0].copy() + elif self.model_type == 'pytorch': + with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True): + pred = self.model( + torch.from_numpy(feature_3d).to(self.device), + torch.from_numpy(kp_source).to(self.device), + torch.from_numpy(kp_driving).to(self.device) + ).float().cpu().numpy() + else: + raise ValueError(f"Unsupported model type: {self.model_type}") + + return pred diff --git a/core/utils/blend/__init__.py b/core/utils/blend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9e91b0dfd73ec50b041d112d91aa1893a27c600 --- /dev/null +++ b/core/utils/blend/__init__.py @@ -0,0 +1,4 @@ +import pyximport +pyximport.install() + +from .blend import blend_images_cy \ No newline at end of file diff --git a/core/utils/blend/blend.pyx b/core/utils/blend/blend.pyx new file mode 100644 index 0000000000000000000000000000000000000000..68e4695dc644ad1f6e2cc0538a13c853fbef2f7a --- /dev/null +++ b/core/utils/blend/blend.pyx @@ -0,0 +1,38 @@ +#cython: language_level=3 +import numpy as np +cimport numpy as np + +cdef extern from "blend_impl.h": + void _blend_images_cy_impl( + const float* mask_warped, + const float* frame_warped, + const unsigned char* frame_rgb, + const int height, + const int width, + unsigned char* result + ) noexcept nogil + +def blend_images_cy( + np.ndarray[np.float32_t, ndim=2] mask_warped, + np.ndarray[np.float32_t, ndim=3] frame_warped, + np.ndarray[np.uint8_t, ndim=3] frame_rgb, + np.ndarray[np.uint8_t, ndim=3] result +): + cdef int h = mask_warped.shape[0] + cdef int w = mask_warped.shape[1] + + if not mask_warped.flags['C_CONTIGUOUS']: + mask_warped = np.ascontiguousarray(mask_warped) + if not 
frame_warped.flags['C_CONTIGUOUS']: + frame_warped = np.ascontiguousarray(frame_warped) + if not frame_rgb.flags['C_CONTIGUOUS']: + frame_rgb = np.ascontiguousarray(frame_rgb) + + with nogil: + _blend_images_cy_impl( + mask_warped.data, + frame_warped.data, + frame_rgb.data, + h, w, + result.data + ) \ No newline at end of file diff --git a/core/utils/blend/blend.pyxbld b/core/utils/blend/blend.pyxbld new file mode 100644 index 0000000000000000000000000000000000000000..75e1d54ac97c8c38475250b6b441f54ed0b5c5b5 --- /dev/null +++ b/core/utils/blend/blend.pyxbld @@ -0,0 +1,11 @@ +import numpy as np +import os + +def make_ext(modname, pyxfilename): + from distutils.extension import Extension + + return Extension(name=modname, + sources=[pyxfilename, os.path.join(os.path.dirname(pyxfilename), "blend_impl.c")], + include_dirs=[np.get_include(), os.path.dirname(pyxfilename)], + extra_compile_args=["-O3", "-std=c99", "-march=native", "-ffast-math"], + ) diff --git a/core/utils/blend/blend_impl.c b/core/utils/blend/blend_impl.c new file mode 100644 index 0000000000000000000000000000000000000000..bf0661d494057650cad4f5981db298a8a05f8478 --- /dev/null +++ b/core/utils/blend/blend_impl.c @@ -0,0 +1,34 @@ +#include + +void _blend_images_cy_impl( + const float* mask_warped, + const float* frame_warped, + const unsigned char* frame_rgb, + const int height, + const int width, + unsigned char* result) { + + const float* mask_pointer = mask_warped; + const float* frame_warped_pointer = frame_warped; + const unsigned char* frame_rgb_pointer = frame_rgb; + unsigned char* result_pointer = result; + + for(int i = 0; i < height; i++) { + for(int j = 0; j < width; j++) { + float mask = *mask_pointer; + float mask_inv = 1.0f - mask; + + float blended1 = mask * (*frame_warped_pointer) + mask_inv * (*frame_rgb_pointer); + float blended2 = mask * (*(frame_warped_pointer+1)) + mask_inv * (*(frame_rgb_pointer+1)); + float blended3 = mask * (*(frame_warped_pointer+2)) + mask_inv * 
(*(frame_rgb_pointer+2)); + + *(result_pointer++) = blended1 > 255 ? 255 : (blended1 < 0) ? 0 : (unsigned char)blended1; + *(result_pointer++) = blended2 > 255 ? 255 : (blended2 < 0) ? 0 : (unsigned char)blended2; + *(result_pointer++) = blended3 > 255 ? 255 : (blended3 < 0) ? 0 : (unsigned char)blended3; + + frame_warped_pointer+=3; + frame_rgb_pointer+=3; + mask_pointer++; + } + } +} \ No newline at end of file diff --git a/core/utils/blend/blend_impl.h b/core/utils/blend/blend_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..854092525031c01a65dc7e288cc83324f59b7a1c --- /dev/null +++ b/core/utils/blend/blend_impl.h @@ -0,0 +1,13 @@ + +#ifndef __BLEND_IMAGES_CY_IMPL_H__ +#define __BLEND_IMAGES_CY_IMPL_H__ + +void _blend_images_cy_impl( + const float* mask_warped, + const float* frame_warped, + const unsigned char* frame_rgb, + const int height, + const int width, + unsigned char* result); + +#endif // __BLEND_IMAGES_CY_IMPL_H__ \ No newline at end of file diff --git a/core/utils/crop.py b/core/utils/crop.py new file mode 100644 index 0000000000000000000000000000000000000000..e01d08d53fb1d212c40da1753149b33eb5b14deb --- /dev/null +++ b/core/utils/crop.py @@ -0,0 +1,459 @@ +# coding: utf-8 + +""" +cropping function and the related preprocess functions for cropping +""" + +import numpy as np +import os.path as osp +from math import sin, cos, acos, degrees +import cv2 + +DTYPE = np.float32 +CV2_INTERP = cv2.INTER_LINEAR + + +def _transform_img(img, M, dsize, flags=CV2_INTERP, borderMode=None): + """conduct similarity or affine transformation to the image, do not do border operation! 
+ img: + M: 2x3 matrix or 3x3 matrix + dsize: target shape (width, height) + """ + if isinstance(dsize, tuple) or isinstance(dsize, list): + _dsize = tuple(dsize) + else: + _dsize = (dsize, dsize) + + if borderMode is not None: + return cv2.warpAffine( + img, + M[:2, :], + dsize=_dsize, + flags=flags, + borderMode=borderMode, + borderValue=(0, 0, 0), + ) + else: + return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags) + + +def _transform_pts(pts, M): + """conduct similarity or affine transformation to the pts + pts: Nx2 ndarray + M: 2x3 matrix or 3x3 matrix + return: Nx2 + """ + return pts @ M[:2, :2].T + M[:2, 2] + + +def parse_pt2_from_pt101(pt101, use_lip=True): + """ + parsing the 2 points according to the 101 points, which cancels the roll + """ + # the former version use the eye center, but it is not robust, now use interpolation + pt_left_eye = np.mean(pt101[[39, 42, 45, 48]], axis=0) # left eye center + pt_right_eye = np.mean(pt101[[51, 54, 57, 60]], axis=0) # right eye center + + if use_lip: + # use lip + pt_center_eye = (pt_left_eye + pt_right_eye) / 2 + pt_center_lip = (pt101[75] + pt101[81]) / 2 + pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0) + else: + pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0) + return pt2 + + +def parse_pt2_from_pt106(pt106, use_lip=True): + """ + parsing the 2 points according to the 106 points, which cancels the roll + """ + pt_left_eye = np.mean(pt106[[33, 35, 40, 39]], axis=0) # left eye center + pt_right_eye = np.mean(pt106[[87, 89, 94, 93]], axis=0) # right eye center + + if use_lip: + # use lip + pt_center_eye = (pt_left_eye + pt_right_eye) / 2 + pt_center_lip = (pt106[52] + pt106[61]) / 2 + pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0) + else: + pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0) + return pt2 + + +def parse_pt2_from_pt203(pt203, use_lip=True): + """ + parsing the 2 points according to the 203 points, which cancels the roll + """ + pt_left_eye = np.mean(pt203[[0, 6, 12, 18]], 
def parse_pt2_from_pt68(pt68, use_lip=True):
    """
    parsing the 2 points according to the 68 points, which cancels the roll

    With ``use_lip`` the result is [eye midpoint, midpoint of the two lip
    landmarks]; otherwise it is [left eye centre, right eye centre].
    """
    # 1-based landmark numbers converted to 0-based indices:
    # nose, two left-eye points, two right-eye points, two lip points.
    lm_idx = np.array([31, 37, 40, 43, 46, 49, 55], dtype=np.int32) - 1

    left_eye = np.mean(pt68[lm_idx[[1, 2]], :], 0)
    right_eye = np.mean(pt68[lm_idx[[3, 4]], :], 0)

    if not use_lip:
        return np.stack([left_eye, right_eye], axis=0)

    pt5 = np.stack(
        [
            left_eye,
            right_eye,
            pt68[lm_idx[0], :],  # nose
            pt68[lm_idx[5], :],  # lip
            pt68[lm_idx[6], :],  # lip
        ],
        axis=0,
    )
    eye_mid = (pt5[0] + pt5[1]) / 2
    lip_mid = (pt5[3] + pt5[4]) / 2
    return np.stack([eye_mid, lip_mid], axis=0)
def parse_rect_from_landmark(
    pts,
    scale=1.5,
    need_square=True,
    vx_ratio=0,
    vy_ratio=0,
    use_deg_flag=False,
    **kwargs,
):
    """parsing center, size, angle from 101/68/5/x landmarks
    vx_ratio: the offset ratio along the pupil axis x-axis, multiplied by size
    vy_ratio: the offset ratio along the pupil axis y-axis, multiplied by size, which is used to contain more forehead area

    judge with pts.shape

    Returns:
        center: length-2 array, centre of the rotated box in image coordinates.
        size: length-2 array (both entries equal when ``need_square``), already multiplied by ``scale``.
        angle: rotation of the box x-axis, in radians, or in degrees when ``use_deg_flag`` is true.
    """
    pt2 = parse_pt2_from_pt_x(pts, use_lip=kwargs.get("use_lip", True))

    uy = pt2[1] - pt2[0]
    norm = np.linalg.norm(uy)
    if norm <= 1e-3:
        # Degenerate case: the two reference points coincide; fall back to an upright box.
        uy = np.array([0, 1], dtype=DTYPE)
    else:
        # Out-of-place division: the in-place form raises for integer point arrays.
        uy = uy / norm
    ux = np.array((uy[1], -uy[0]), dtype=DTYPE)

    # the rotation degree of the x-axis, the clockwise is positive, the counterclockwise is negative (image coordinate system)
    # Clamp before acos: rounding can push a unit-vector component slightly
    # outside [-1, 1], which would raise "math domain error".
    angle = acos(min(1.0, max(-1.0, float(ux[0]))))
    if ux[1] < 0:
        angle = -angle

    # rotation matrix
    M = np.array([ux, uy])

    # calculate the size which contains the angle degree of the bbox, and the center
    center0 = np.mean(pts, axis=0)
    rpts = (pts - center0) @ M.T  # (M @ P.T).T = P @ M.T
    lt_pt = np.min(rpts, axis=0)
    rb_pt = np.max(rpts, axis=0)
    center1 = (lt_pt + rb_pt) / 2

    size = rb_pt - lt_pt
    if need_square:
        m = max(size[0], size[1])
        size[0] = m
        size[1] = m

    size *= scale  # scale size
    center = (
        center0 + ux * center1[0] + uy * center1[1]
    )  # counterclockwise rotation, equivalent to M.T @ center1.T
    center = (
        center + ux * (vx_ratio * size) + uy * (vy_ratio * size)
    )  # considering the offset in vx and vy direction

    if use_deg_flag:
        angle = degrees(angle)

    return center, size, angle
np.max(rpts, axis=0) + center1 = (lt_pt + rb_pt) / 2 + + size = rb_pt - lt_pt + if need_square: + m = max(size[0], size[1]) + size[0] = m + size[1] = m + + size *= scale # scale size + center = ( + center0 + ux * center1[0] + uy * center1[1] + ) # counterclockwise rotation, equivalent to M.T @ center1.T + center = ( + center + ux * (vx_ratio * size) + uy * (vy_ratio * size) + ) # considering the offset in vx and vy direction + + if use_deg_flag: + angle = degrees(angle) + + return center, size, angle + + +def parse_bbox_from_landmark(pts, **kwargs): + center, size, angle = parse_rect_from_landmark(pts, **kwargs) + cx, cy = center + w, h = size + + # calculate the vertex positions before rotation + bbox = np.array( + [ + [cx - w / 2, cy - h / 2], # left, top + [cx + w / 2, cy - h / 2], + [cx + w / 2, cy + h / 2], # right, bottom + [cx - w / 2, cy + h / 2], + ], + dtype=DTYPE, + ) + + # construct rotation matrix + bbox_rot = bbox.copy() + R = np.array( + [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]], dtype=DTYPE + ) + + # calculate the relative position of each vertex from the rotation center, then rotate these positions, and finally add the coordinates of the rotation center + bbox_rot = (bbox_rot - center) @ R.T + center + + return { + "center": center, # 2x1 + "size": size, # scalar + "angle": angle, # rad, counterclockwise + "bbox": bbox, # 4x2 + "bbox_rot": bbox_rot, # 4x2 + } + + +def crop_image_by_bbox( + img, bbox, lmk=None, dsize=512, angle=None, flag_rot=False, **kwargs +): + left, top, right, bot = bbox + if int(right - left) != int(bot - top): + print(f"right-left {right-left} != bot-top {bot-top}") + size = right - left + + src_center = np.array([(left + right) / 2, (top + bot) / 2], dtype=DTYPE) + tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE) + + s = dsize / size # scale + if flag_rot and angle is not None: + costheta, sintheta = cos(angle), sin(angle) + cx, cy = src_center[0], src_center[1] # ori center + tcx, tcy = 
tgt_center[0], tgt_center[1] # target center + # need to infer + M_o2c = np.array( + [ + [s * costheta, s * sintheta, tcx - s * (costheta * cx + sintheta * cy)], + [ + -s * sintheta, + s * costheta, + tcy - s * (-sintheta * cx + costheta * cy), + ], + ], + dtype=DTYPE, + ) + else: + M_o2c = np.array( + [ + [s, 0, tgt_center[0] - s * src_center[0]], + [0, s, tgt_center[1] - s * src_center[1]], + ], + dtype=DTYPE, + ) + + # if flag_rot and angle is None: + # print('angle is None, but flag_rotate is True', style="bold yellow") + + img_crop = _transform_img( + img, M_o2c, dsize=dsize, borderMode=kwargs.get("borderMode", None) + ) + lmk_crop = _transform_pts(lmk, M_o2c) if lmk is not None else None + + M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)]) + M_c2o = np.linalg.inv(M_o2c) + + # cv2.imwrite('crop.jpg', img_crop) + + return { + "img_crop": img_crop, + "lmk_crop": lmk_crop, + "M_o2c": M_o2c, + "M_c2o": M_c2o, + } + + +def _estimate_similar_transform_from_pts( + pts, dsize, scale=1.5, vx_ratio=0, vy_ratio=-0.1, flag_do_rot=True, **kwargs +): + """calculate the affine matrix of the cropped image from sparse points, the original image to the cropped image, the inverse is the cropped image to the original image + pts: landmark, 101 or 68 points or other points, Nx2 + scale: the larger scale factor, the smaller face ratio + vx_ratio: x shift + vy_ratio: y shift, the smaller the y shift, the lower the face region + rot_flag: if it is true, conduct correction + """ + center, size, angle = parse_rect_from_landmark( + pts, + scale=scale, + vx_ratio=vx_ratio, + vy_ratio=vy_ratio, + use_lip=kwargs.get("use_lip", True), + ) + + s = dsize / size[0] # scale + tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE) # center of dsize + + if flag_do_rot: + costheta, sintheta = cos(angle), sin(angle) + cx, cy = center[0], center[1] # ori center + tcx, tcy = tgt_center[0], tgt_center[1] # target center + # need to infer + M_INV = np.array( + [ + [s * costheta, s * 
def average_bbox_lst(bbox_lst):
    """Return the element-wise mean of the boxes in ``bbox_lst`` as a list, or None when it is empty."""
    if len(bbox_lst) == 0:
        return None
    stacked = np.array(bbox_lst)
    averaged = stacked.mean(axis=0)
    return averaged.tolist()
@dataclass
class EyeIdxMP:
    # Landmark indices used for the eye measurements below (presumably the
    # MediaPipe face-mesh layout, given the [n, 478, 3] input - TODO confirm).
    # L*/R* = left/right eye, O/I = outer/inner corner, D/U = lower/upper lid.
    LO = [33]
    LI = [133]
    LD = [7, 163, 144, 145, 153, 154, 155]  # O -> I
    LU = [246, 161, 160, 159, 158, 157, 173]  # O -> I
    RO = [263]
    RI = [362]
    RD = [249, 390, 373, 374, 380, 381, 382]  # O -> I
    RU = [466, 388, 387, 386, 385, 384, 398]  # O -> I

    # Pair used for the eye width: (outer, inner).
    LW = [33, 133]  # oi
    # Three lid pairs used for the eye height.
    LH0 = [145, 159]
    LH1 = [144, 160]
    LH2 = [153, 158]

    RW = [263, 362]  # oi
    RH0 = [374, 386]
    RH1 = [373, 387]
    RH2 = [380, 385]

    LB = [468]  # eye ball
    RB = [473]


class EyeAttrUtilsByMP:
    """Per-frame eye openness and eyeball-offset features computed from landmarks."""

    def __init__(self, lmks_mp):
        idx = EyeIdxMP()
        self.IDX = idx
        self.lmks = lmks_mp  # [n, 478, 3]

        # Eye widths (outer corner to inner corner), one value per frame.
        self.L_width = self._dist_idx(*idx.LW)  # [n]
        self.R_width = self._dist_idx(*idx.RW)

        # Three lid-to-lid heights per eye.
        self.L_h0 = self._dist_idx(*idx.LH0)
        self.L_h1 = self._dist_idx(*idx.LH1)
        self.L_h2 = self._dist_idx(*idx.LH2)

        self.R_h0 = self._dist_idx(*idx.RH0)
        self.R_h1 = self._dist_idx(*idx.RH1)
        self.R_h2 = self._dist_idx(*idx.RH2)

        # Openness: summed heights normalised by the eye width.
        self.L_open = (self.L_h0 + self.L_h1 + self.L_h2) / (self.L_width + 1e-8)  # [n]
        self.R_open = (self.R_h0 + self.R_h1 + self.R_h2) / (self.R_width + 1e-8)

        # Eye centre: midpoint of the two corners.
        self.L_center = self._center_idx(*idx.LW)  # [n, 3/2]
        self.R_center = self._center_idx(*idx.RW)

        self.L_ball = self.lmks[:, idx.LB[0]]  # [n, 3/2]
        self.R_ball = self.lmks[:, idx.RB[0]]

        # Eyeball offset from the eye centre, in units of eye width.
        self.L_ball_direc = (self.L_ball - self.L_center) / (self.L_width[:, None] + 1e-8)  # [n, 3/2]
        self.R_ball_direc = (self.R_ball - self.R_center) / (self.R_width[:, None] + 1e-8)

        # Unit vector along the eye axis.
        self.L_eye_direc = self._direc_idx(*idx.LW)  # I->O
        self.R_eye_direc = self._direc_idx(*idx.RW)

        self.L_ball_move_dist = self._dist(self.L_ball, self.L_center)
        self.R_ball_move_dist = self._dist(self.R_ball, self.R_center)

        self.L_ball_move_direc = self._direc(self.L_ball, self.L_center) - self.L_eye_direc
        self.R_ball_move_direc = self._direc(self.R_ball, self.R_center) - self.R_eye_direc

        self.L_ball_move = self.L_ball_move_direc * self.L_ball_move_dist[:, None]
        self.R_ball_move = self.R_ball_move_direc * self.R_ball_move_dist[:, None]

    def LR_open(self):
        """Stack left/right openness on a new last axis."""
        return np.stack([self.L_open, self.R_open], -1)  # [n, 2]

    def LR_ball_direc(self):
        """Stack left/right normalised eyeball offsets on a new last axis."""
        return np.stack([self.L_ball_direc, self.R_ball_direc], -1)  # [n, 3, 2]

    def LR_ball_move(self):
        """Stack left/right eyeball movement vectors on a new last axis."""
        return np.stack([self.L_ball_move, self.R_ball_move], -1)

    @staticmethod
    def _dist(p1, p2):
        # p1/p2: [n, 3/2] -> Euclidean distance per row, [n]
        delta = p1 - p2
        return ((delta ** 2).sum(-1)) ** 0.5

    @staticmethod
    def _center(p1, p2):
        # midpoint, [n, 3/2]
        return (p1 + p2) * 0.5

    def _direc(self, p1, p2):
        # unit vector pointing from p2 to p1 (epsilon avoids division by zero)
        length = self._dist(p1, p2)[:, None]
        return (p1 - p2) / (length + 1e-8)

    def _dist_idx(self, idx1, idx2):
        return self._dist(self.lmks[:, idx1], self.lmks[:, idx2])

    def _center_idx(self, idx1, idx2):
        return self._center(self.lmks[:, idx1], self.lmks[:, idx2])

    def _direc_idx(self, idx1, idx2):
        return self._direc(self.lmks[:, idx1], self.lmks[:, idx2])
def get_mask(W, H, ratio_w=0.9, ratio_h=0.9):
    """Build an ``(H, W, 1)`` float32 blending mask with feathered borders.

    A centred rectangle of ``int(W * ratio_w)`` x ``int(H * ratio_h)`` pixels is
    1.  The four edge strips ramp linearly from 0 at the image border to 1 at
    the rectangle.  The four corners are ``1 - clip(r, 0, 1)``, where ``r`` is
    the distance to the rectangle's nearest corner with each axis normalised
    to the margin size.
    """
    inner_w = int(W * ratio_w)
    inner_h = int(H * ratio_h)

    left = (W - inner_w) // 2
    right = left + inner_w
    top = (H - inner_h) // 2
    bottom = top + inner_h

    margin_r = W - right
    margin_b = H - bottom

    def radial(row, col):
        # (c,) row ramp and (r,) column ramp -> (r, c) Euclidean norm, float32.
        return np.sqrt(row[None, :] ** 2 + col[:, None] ** 2).astype(np.float32)

    def corner(row, col):
        # Corners fade with the distance from the inner rectangle's corner.
        return 1 - np.clip(radial(row, col), 0, 1)

    zeros_w = np.linspace(0, 0, inner_w)
    zeros_h = np.linspace(0, 0, inner_h)

    mask = np.ones((H, W), dtype=np.float32)

    # Edge strips: only one axis contributes to the ramp.
    mask[0:top, left:right] = radial(zeros_w, np.linspace(0, 1, top))  # top, (top, inner_w)
    mask[bottom:H, left:right] = radial(zeros_w, np.linspace(1, 0, margin_b))  # bottom, (margin_b, inner_w)
    mask[top:bottom, 0:left] = radial(np.linspace(0, 1, left), zeros_h)  # left, (inner_h, left)
    mask[top:bottom, right:W] = radial(np.linspace(1, 0, margin_r), zeros_h)  # right, (inner_h, margin_r)

    # Corner patches: both ramps are 0 at the rectangle corner, 1 at the border.
    mask[0:top, 0:left] = corner(np.linspace(1, 0, left), np.linspace(1, 0, top))  # top left, (top, left)
    mask[0:top, right:W] = corner(np.linspace(0, 1, margin_r), np.linspace(1, 0, top))  # top right
    mask[bottom:H, 0:left] = corner(np.linspace(1, 0, left), np.linspace(0, 1, margin_b))  # bottom left
    mask[bottom:H, right:W] = corner(np.linspace(0, 1, margin_r), np.linspace(0, 1, margin_b))  # bottom right

    return mask[:, :, None]
def load_model(model_path: str, device: str = "cuda", **kwargs):
    """Load a model and return ``(model, kind)``.

    ``kind`` is ``"ori"`` when ``force_ori_type`` is truthy in ``kwargs``;
    otherwise it is picked from the file extension: ``"onnx"`` (``.onnx``),
    ``"tensorrt"`` (``.engine`` / ``.trt``) or ``"pytorch"`` (``.pt`` / ``.pth``).

    Raises:
        ValueError: for any other extension.
    """
    if kwargs.get("force_ori_type", False):
        # for hubert, landmark, retinaface, mediapipe
        return load_force_ori_type(model_path, device, **kwargs), "ori"

    if model_path.endswith(".onnx"):
        # onnx
        import onnxruntime

        providers = ["CPUExecutionProvider"]
        if device == "cuda":
            # CUDA is preferred, CPU stays as fallback.
            providers.insert(0, "CUDAExecutionProvider")
        session = onnxruntime.InferenceSession(model_path, providers=providers)
        return session, "onnx"

    if model_path.endswith((".engine", ".trt")):
        # tensorRT
        from .tensorrt_utils import TRTWrapper

        return TRTWrapper(model_path), "tensorrt"

    if model_path.endswith((".pt", ".pth")):
        # pytorch
        return create_model(model_path, device, **kwargs), "pytorch"

    raise ValueError(f"Unsupported model file type: {model_path}")


def create_model(
    model_path: str,
    device: str = "cuda",
    module_name="",
    package_name="..models.modules",
    **kwargs,
):
    """Instantiate ``module_name`` from ``package_name`` and load its weights.

    The class is built with ``**kwargs``; ``load_model(model_path)`` is then
    called on it and the result is moved to ``device``.
    """
    import importlib

    package = importlib.import_module(package_name, __package__)
    model_cls = getattr(package, module_name)

    model = model_cls(**kwargs)
    model.load_model(model_path).to(device)
    return model


def load_force_ori_type(
    model_path: str,
    device: str = "cuda",
    module_name="",
    package_name="..aux_models.modules",
    force_ori_type=False,
    **kwargs,
):
    """Instantiate ``module_name`` from ``package_name`` with ``**kwargs``.

    ``model_path``, ``device`` and ``force_ori_type`` are accepted but not
    used here; naming ``force_ori_type`` keeps it out of the ``kwargs`` that
    are forwarded to the constructor.
    """
    import importlib

    package = importlib.import_module(package_name, __package__)
    model_cls = getattr(package, module_name)
    return model_cls(**kwargs)
def _cudaGetErrorEnum(error):
    """Return the name/message of a CUDA driver, runtime or NVRTC error code.

    Raises:
        RuntimeError: if ``error`` is none of the three known result types.
    """
    if isinstance(error, cuda.CUresult):
        status, name = cuda.cuGetErrorName(error)
        if status == cuda.CUresult.CUDA_SUCCESS:
            return name
        return ""
    if isinstance(error, cudart.cudaError_t):
        return cudart.cudaGetErrorName(error)[1]
    if isinstance(error, nvrtc.nvrtcResult):
        return nvrtc.nvrtcGetErrorString(error)[1]
    raise RuntimeError("Unknown error type: {}".format(error))


def checkCudaErrors(result):
    """Unwrap a ``(status, *values)`` result tuple, raising on a non-zero status.

    Returns ``None`` when there is no value, the single value when there is
    exactly one, and the slice ``result[1:]`` otherwise.

    Raises:
        RuntimeError: if ``result[0].value`` is truthy.
    """
    status = result[0]
    if status.value:
        raise RuntimeError(
            "CUDA error code={}({})".format(status.value, _cudaGetErrorEnum(status))
        )

    n_values = len(result) - 1
    if n_values == 0:
        return None
    if n_values == 1:
        return result[1]
    return result[1:]
class MyOutputAllocator(trt.IOutputAllocator):
    """Output allocator that (re)allocates device memory when TensorRT asks.

    ``TRTWrapper.setup`` installs one instance per output whose runtime shape
    contains ``-1``.  After execution the wrapper reads ``shape`` and
    ``address`` to size and locate the result.
    """

    def __init__(self) -> None:
        super().__init__()
        # members for outside use
        self.shape = None  # last shape passed to notify_shape
        self.n_bytes = 0  # size of the current allocation
        self.address = 0  # device pointer of the current allocation

    def reallocate_output(self, tensor_name, old_address, size, alignment) -> int:
        return self.reallocate_common(tensor_name, old_address, size, alignment)

    def reallocate_output_async(
        self, tensor_name, old_address, size, alignment, stream
    ) -> int:
        return self.reallocate_common(tensor_name, old_address, size, alignment, stream)

    def notify_shape(self, tensor_name, shape):
        self.shape = shape
        return

    def reallocate_common(
        self, tensor_name, old_address, size, alignment, stream=-1
    ):  # not necessary API
        # The existing allocation is reused whenever it is already big enough.
        if size <= self.n_bytes:
            return old_address
        if old_address != 0:
            checkCudaErrors(cudart.cudaFree(old_address))
        if stream == -1:
            address = checkCudaErrors(cudart.cudaMalloc(size))
        else:
            address = checkCudaErrors(cudart.cudaMallocAsync(size, stream))
        self.n_bytes = size
        self.address = address
        return address


class TRTWrapper:
    """Thin runner around a serialized TensorRT engine.

    ``self.buffer`` maps each tensor name to ``[host_buffer, device_buffer,
    n_bytes]``.  Typical use is ``setup(inputs)`` followed by ``infer()`` (or
    ``infer_async()``), after which results are read from
    ``self.buffer[name][0]``.
    """

    def __init__(
        self,
        trt_file: str,
        plugin_file_list: list = None,
    ) -> None:
        # Load custom plugins (None means "no plugins"; avoids a shared mutable default).
        for plugin_file in plugin_file_list or ():
            ctypes.cdll.LoadLibrary(plugin_file)

        # Load engine bytes from file
        self.model = trt_file
        with open(trt_file, "rb") as f, trt.Runtime(logger) as runtime:
            assert runtime
            self.engine = runtime.deserialize_cuda_engine(f.read())
            assert self.engine
        self.buffer = OrderedDict()
        self.output_allocator_map = OrderedDict()
        self.context = self.engine.create_execution_context()
        return

    def setup(self, input_data: dict = None) -> None:
        """Bind ``input_data`` (name -> numpy array) and allocate all I/O buffers.

        Device buffers from a previous call are freed first, except for
        outputs managed by a ``MyOutputAllocator``.
        """
        if input_data is None:
            input_data = {}

        for name, value in self.buffer.items():
            _, device_buffer, _ = value
            if (
                device_buffer is not None
                and device_buffer != 0
                and name not in self.output_allocator_map
            ):
                checkCudaErrors(cudart.cudaFree(device_buffer))
                self.buffer[name][1] = None
                self.buffer[name][2] = 0
        self.tensor_name_list = [
            self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)
        ]
        self.n_input = sum(
            [
                self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
                for name in self.tensor_name_list
            ]
        )
        self.n_output = self.engine.num_io_tensors - self.n_input

        # Device-resident inputs only need their shape now; host-resident
        # inputs are bound directly to the caller's array.
        for name, data in input_data.items():
            if self.engine.get_tensor_location(name) == trt.TensorLocation.DEVICE:
                self.context.set_input_shape(name, data.shape)
            else:
                self.context.set_tensor_address(name, data.ctypes.data)

        # Prepare work before inference
        for name in self.tensor_name_list:
            data_type = self.engine.get_tensor_dtype(name)
            runtime_shape = self.context.get_tensor_shape(name)
            if name not in self.output_allocator_map:
                if -1 in runtime_shape:
                    # for Data-Dependent-Shape (DDS) output, "else" branch for normal output
                    n_byte = 0  # self.context.get_max_output_size(name)
                    self.output_allocator_map[name] = MyOutputAllocator()
                    self.context.set_output_allocator(
                        name, self.output_allocator_map[name]
                    )
                    host_buffer = np.empty(0, dtype=trt.nptype(data_type))
                    device_buffer = None
                else:
                    n_byte = trt.volume(runtime_shape) * data_type.itemsize
                    host_buffer = np.empty(runtime_shape, dtype=trt.nptype(data_type))
                    if (
                        self.engine.get_tensor_location(name)
                        == trt.TensorLocation.DEVICE
                    ):
                        device_buffer = checkCudaErrors(cudart.cudaMalloc(n_byte))
                    else:
                        device_buffer = None
                self.buffer[name] = [host_buffer, device_buffer, n_byte]
            else:
                # for DDS output, don't need to reallocate
                pass

        for name, data in input_data.items():
            self.buffer[name][0] = np.ascontiguousarray(data)

        for name in self.tensor_name_list:
            if self.engine.get_tensor_location(name) == trt.TensorLocation.DEVICE:
                if self.buffer[name][1] is not None:
                    self.context.set_tensor_address(name, self.buffer[name][1])
            elif self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
                self.context.set_tensor_address(name, self.buffer[name][0].ctypes.data)

        return

    def _device_tensor_names(self, mode):
        """Names of the device-located tensors with the given I/O ``mode``."""
        return [
            name
            for name in self.tensor_name_list
            if self.engine.get_tensor_mode(name) == mode
            and self.engine.get_tensor_location(name) == trt.TensorLocation.DEVICE
        ]

    def _execute(self, stream, use_async: bool) -> None:
        """Shared body of ``infer`` / ``infer_async``.

        Copies inputs host -> device, enqueues the engine on ``stream``,
        refreshes the buffers of allocator-managed outputs, then copies
        outputs device -> host.  Every CUDA call is checked so that a failed
        copy raises instead of leaving stale data in the host buffers.
        """

        def copy(dst, src, n_bytes, kind):
            if use_async:
                checkCudaErrors(
                    cudart.cudaMemcpyAsync(dst, src, n_bytes, kind, stream=stream)
                )
            else:
                checkCudaErrors(cudart.cudaMemcpy(dst, src, n_bytes, kind))

        for name in self._device_tensor_names(trt.TensorIOMode.INPUT):
            host_buffer, device_buffer, n_bytes = self.buffer[name]
            copy(
                device_buffer,
                host_buffer.ctypes.data,
                n_bytes,
                cudart.cudaMemcpyKind.cudaMemcpyHostToDevice,
            )

        if not self.context.execute_async_v3(stream):
            raise RuntimeError(f"TensorRT failed to enqueue inference for {self.model}")

        # Allocator-managed outputs only know their shape/address after execution.
        for name in self.output_allocator_map:
            allocator = self.context.get_output_allocator(name)
            runtime_shape = allocator.shape
            data_type = self.engine.get_tensor_dtype(name)
            host_buffer = np.empty(runtime_shape, dtype=trt.nptype(data_type))
            n_bytes = trt.volume(runtime_shape) * data_type.itemsize
            self.buffer[name] = [host_buffer, allocator.address, n_bytes]

        for name in self._device_tensor_names(trt.TensorIOMode.OUTPUT):
            host_buffer, device_buffer, n_bytes = self.buffer[name]
            copy(
                host_buffer.ctypes.data,
                device_buffer,
                n_bytes,
                cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
            )

    def infer(self, stream=0) -> None:
        """Run inference using blocking ``cudaMemcpy`` for the I/O copies."""
        self._execute(stream, use_async=False)
        return

    def infer_async(self, stream=0) -> None:
        """Run inference using ``cudaMemcpyAsync`` on ``stream`` for the I/O copies.

        NOTE(review): nothing here synchronizes ``stream``; callers presumably
        must do so before reading the host buffers - confirm at call sites.
        """
        self._execute(stream, use_async=True)
        return

    def __del__(self):
        # Best-effort cleanup: at interpreter shutdown ``cudart`` may already be
        # torn down, hence the guards and the tolerated TypeError.
        if hasattr(self, "buffer") and self.buffer is not None:
            for _, device_buffer, _ in self.buffer.values():
                if (
                    device_buffer is not None
                    and device_buffer != 0
                    and cudart is not None
                ):
                    try:
                        checkCudaErrors(cudart.cudaFree(device_buffer))
                    except TypeError:
                        pass
        return
libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h5eee18b_3 + - libidn2=2.3.4=h5eee18b_0 + - libjpeg-turbo=2.0.0=h9bf148f_0 + - libnpp=12.0.2.50=0 + - libnvjitlink=12.1.105=0 + - libnvjpeg=12.1.1.14=0 + - libpng=1.6.39=h5eee18b_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=hffd6297_1 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.41.5=h5eee18b_0 + - libwebp=1.3.2=h11a3e52_0 + - libwebp-base=1.3.2=h5eee18b_1 + - llvm-openmp=14.0.6=h9e868ea_0 + - lz4-c=1.9.4=h6a678d5_1 + - markupsafe=2.1.3=py310h5eee18b_0 + - mkl=2023.1.0=h213fc3f_46344 + - mkl-service=2.4.0=py310h5eee18b_1 + - mkl_fft=1.3.11=py310h5eee18b_0 + - mkl_random=1.2.8=py310h1128e8f_0 + - mpc=1.1.0=h10f8cd9_1 + - mpfr=4.0.2=hb69a4c5_1 + - mpmath=1.3.0=py310h06a4308_0 + - ncurses=6.4=h6a678d5_0 + - nettle=3.7.3=hbbd107a_1 + - networkx=3.2.1=py310h06a4308_0 + - numpy=2.0.1=py310h5f9d8c6_1 + - numpy-base=2.0.1=py310hb5e798b_1 + - openh264=2.1.1=h4ff587b_0 + - openjpeg=2.5.2=he7f1fd0_0 + - openssl=3.0.15=h5eee18b_0 + - pillow=11.0.0=py310hcea889d_1 + - pip=24.2=py310h06a4308_0 + - pysocks=1.7.1=py310h06a4308_0 + - python=3.10.16=he870216_1 + - pytorch=2.5.1=py3.10_cuda12.1_cudnn9.1.0_0 + - pytorch-cuda=12.1=ha16c6d3_6 + - pytorch-mutex=1.0=cuda + - pyyaml=6.0.2=py310h5eee18b_0 + - readline=8.2=h5eee18b_0 + - requests=2.32.3=py310h06a4308_1 + - setuptools=75.1.0=py310h06a4308_0 + - sqlite=3.45.3=h5eee18b_0 + - sympy=1.13.3=py310h06a4308_0 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.14=h39e8969_0 + - torchaudio=2.5.1=py310_cu121 + - torchtriton=3.1.0=py310 + - torchvision=0.20.1=py310_cu121 + - typing_extensions=4.12.2=py310h06a4308_0 + - tzdata=2024b=h04d1e81_0 + - urllib3=2.2.3=py310h06a4308_0 + - wheel=0.44.0=py310h06a4308_0 + - xz=5.4.6=h5eee18b_1 + - yaml=0.2.5=h7b6447c_0 + - zlib=1.2.13=h5eee18b_1 + - zstd=1.5.6=hc292b87_0 + - pip: + - audioread==3.0.1 + - cffi==1.17.1 + - cuda-python==12.6.2.post1 + - cython==3.0.11 + - decorator==5.1.1 + - filetype==1.2.0 + - 
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and the hash-seed env vars."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def load_pkl(pkl):
    """Unpickle and return the object stored at path ``pkl``."""
    # NOTE(security): pickle.load executes arbitrary code from the file; only
    # load configuration pickles that come from a trusted source.
    with open(pkl, "rb") as f:
        return pickle.load(f)


def run(SDK: StreamSDK, audio_path: str, source_path: str, output_path: str, more_kwargs: str | dict | None = None):
    """Generate a talking video for ``source_path`` driven by ``audio_path``.

    Args:
        SDK: an initialised ``StreamSDK``.
        audio_path: input audio, loaded with librosa at 16 kHz.
        source_path: source image or video handed to ``SDK.setup``.
        output_path: final mp4; the silent video is first written to
            ``SDK.tmp_output_path`` and then muxed with the audio by ffmpeg.
        more_kwargs: dict (or path to a pickled dict) with optional
            ``setup_kwargs`` and ``run_kwargs`` entries.
    """
    import shlex
    import subprocess

    if more_kwargs is None:
        more_kwargs = {}
    elif isinstance(more_kwargs, str):
        more_kwargs = load_pkl(more_kwargs)
    setup_kwargs = more_kwargs.get("setup_kwargs", {})
    run_kwargs = more_kwargs.get("run_kwargs", {})

    SDK.setup(source_path, output_path, **setup_kwargs)

    audio, sr = librosa.core.load(audio_path, sr=16000)
    # Number of video frames at 25 fps for 16 kHz audio.
    num_f = math.ceil(len(audio) / 16000 * 25)

    fade_in = run_kwargs.get("fade_in", -1)
    fade_out = run_kwargs.get("fade_out", -1)
    ctrl_info = run_kwargs.get("ctrl_info", {})
    SDK.setup_Nd(N_d=num_f, fade_in=fade_in, fade_out=fade_out, ctrl_info=ctrl_info)

    online_mode = SDK.online_mode
    if online_mode:
        chunksize = run_kwargs.get("chunksize", (3, 5, 2))
        # Left-pad with silence (640 samples = one 25 fps frame at 16 kHz).
        audio = np.concatenate([np.zeros((chunksize[0] * 640,), dtype=np.float32), audio], 0)
        split_len = int(sum(chunksize) * 0.04 * 16000) + 80  # 6480
        for i in range(0, len(audio), chunksize[1] * 640):
            audio_chunk = audio[i:i + split_len]
            if len(audio_chunk) < split_len:
                audio_chunk = np.pad(audio_chunk, (0, split_len - len(audio_chunk)), mode="constant")
            SDK.run_chunk(audio_chunk, chunksize)
    else:
        aud_feat = SDK.wav2feat.wav2feat(audio)
        SDK.audio2motion_queue.put(aud_feat)
    SDK.close()

    # Mux the rendered video with the audio.  The command is passed as an
    # argument list (no shell), so paths containing quotes, spaces or ``$``
    # are handled verbatim instead of being re-interpreted by a shell.
    cmd = [
        "ffmpeg", "-loglevel", "error", "-y",
        "-i", SDK.tmp_output_path,
        "-i", audio_path,
        "-map", "0:v", "-map", "1:a",
        "-c:v", "copy", "-c:a", "aac",
        output_path,
    ]
    print(shlex.join(cmd))
    try:
        returncode = subprocess.run(cmd, check=False).returncode
    except FileNotFoundError:
        # Like the previous os.system call, a missing ffmpeg does not abort the run.
        print("ffmpeg executable not found; audio was not muxed")
    else:
        if returncode != 0:
            print(f"ffmpeg exited with code {returncode}; {output_path} may be missing or incomplete")

    print(output_path)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, default="./checkpoints/ditto_trt_Ampere_Plus", help="path to trt data_root")
    parser.add_argument("--cfg_pkl", type=str, default="./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl", help="path to cfg_pkl")

    parser.add_argument("--audio_path", type=str, help="path to input wav")
    parser.add_argument("--source_path", type=str, help="path to input image")
    parser.add_argument("--output_path", type=str, help="path to output mp4")
    args = parser.parse_args()

    # init sdk
    data_root = args.data_root   # model dir
    cfg_pkl = args.cfg_pkl    # cfg pkl
    SDK = StreamSDK(cfg_pkl, data_root)

    # input args
    audio_path = args.audio_path    # .wav
    source_path = args.source_path    # video|image
    output_path = args.output_path    # .mp4

    # run
    # seed_everything(1024)
    run(SDK, audio_path, source_path, output_path)
def _run_command(argv):
    """Print and run ``argv`` without a shell; report (but do not raise on) failure."""
    import shlex
    import subprocess

    print(shlex.join(argv))
    try:
        returncode = subprocess.run(argv, check=False).returncode
    except FileNotFoundError:
        print(f"command not found: {argv[0]}")
        return
    if returncode != 0:
        print(f"{argv[0]} exited with code {returncode}")


def onnx_to_trt(onnx_file, trt_file, fp16=False, more_cmd=None):
    """Convert ``onnx_file`` to a TensorRT engine at ``trt_file`` via ``polygraphy convert``.

    Args:
        fp16: add ``--fp16``.
        more_cmd: extra option strings; each entry may hold several
            whitespace-separated tokens (e.g. ``"--trt-min-shapes audio:[1,1000]"``).

    The command is executed as an argument list, so file paths with spaces and
    shape specs containing ``[...]`` are passed through untouched rather than
    being split or glob-expanded by a shell.
    """
    import shlex

    cmd = ["polygraphy", "convert", onnx_file, "-o", trt_file]
    # Hardware-compatible engines are only requested on compute capability >= 8.
    if torch.cuda.get_device_capability()[0] >= 8:
        cmd.append("--hardware-compatibility-level=Ampere_Plus")
    if fp16:
        cmd.append("--fp16")
    cmd.append("--builder-optimization-level=5")
    for extra in more_cmd or ():
        cmd.extend(shlex.split(extra))
    _run_command(cmd)


def onnx_to_trt_for_gridsample(onnx_file, trt_file, fp16=False, plugin_file="./libgrid_sample_3d_plugin.so"):
    """Build a TensorRT engine for a network that needs the grid-sample plugin.

    The plugin library is loaded into the builder's registry and serialized
    with the engine.  With ``fp16`` the builder may use FP16, but layers whose
    name contains ``GridSample`` are pinned to float32.  Nothing is written
    when ONNX parsing or the engine build fails.
    """
    import tensorrt as trt

    logger = trt.Logger(trt.Logger.INFO)
    trt.init_libnvinfer_plugins(logger, "")
    plugin_libs = [plugin_file]

    onnx_path = onnx_file
    engine_path = trt_file

    builder = trt.Builder(logger)
    for pluginlib in plugin_libs:
        builder.get_plugin_registry().load_library(pluginlib)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )

    parser = trt.OnnxParser(network, logger)
    res = parser.parse_from_file(onnx_path)  # parse from file
    if not res:
        print(f"Fail parsing {onnx_path}")
        for i in range(parser.num_errors):  # Get error information
            error = parser.get_error(i)
            print(error)  # Print error information
            print(
                f"{error.code() = }\n{error.file() = }\n{error.func() = }\n{error.line() = }\n{error.local_function_stack_size() = }"
            )
            print(
                f"{error.local_function_stack() = }\n{error.node_name() = }\n{error.node_operator() = }\n{error.node() = }"
            )
        parser.clear_errors()
        # The network is incomplete after a parse failure; building it would
        # at best waste time and at worst emit a broken engine.
        return

    config = builder.create_builder_config()
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)
    config.builder_optimization_level = 5
    # Set the flag of hardware compatibility, Hardware-compatible engines are only supported on Ampere and beyond
    if torch.cuda.get_device_capability()[0] >= 8:
        config.hardware_compatibility_level = (
            trt.HardwareCompatibilityLevel.AMPERE_PLUS
        )

    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)
        config.set_preview_feature(trt.PreviewFeature.PROFILE_SHARING_0806, True)
        # Layer types that are never given an explicit precision.
        exclude_list = {
            "SHAPE",
            "ASSERTION",
            "SHUFFLE",
            "IDENTITY",
            "CONSTANT",
            "CONCATENATION",
            "GATHER",
            "SLICE",
            "CONDITION",
            "CONDITIONAL_INPUT",
            "CONDITIONAL_OUTPUT",
            "FILL",
            "NON_ZERO",
            "ONE_HOT",
        }
        for i in range(0, network.num_layers):
            layer = network.get_layer(i)
            # str(layer.type) looks like "LayerType.SHAPE"; drop the prefix.
            if str(layer.type)[10:] in exclude_list:
                continue
            if "GridSample" in layer.name:
                print(f"set {layer.name} to float32")
                layer.precision = trt.float32
    config.plugins_to_serialize = plugin_libs
    engineString = builder.build_serialized_network(network, config)
    if engineString is None:
        print(f"Fail building engine for {onnx_path}")
        return
    with open(engine_path, "wb") as f:
        f.write(engineString)


def main(onnx_dir, trt_dir, grid_sample_plugin_file=""):
    """Convert every ``*.onnx`` in ``onnx_dir`` to an engine in ``trt_dir``.

    ``warp_network_ori`` is skipped, already existing engines are left alone,
    and ``warp_network`` goes through the plugin-aware builder.
    """
    fp32_names = {"motion_extractor", "hubert", "wavlm"}
    # Dynamic-shape profiles for the audio encoders.
    shape_options = {
        "wavlm": [
            "--trt-min-shapes audio:[1,1000]",
            "--trt-max-shapes audio:[1,320080]",
            "--trt-opt-shapes audio:[1,320080]",
        ],
        "hubert": [
            "--trt-min-shapes input_values:[1,3240]",
            "--trt-max-shapes input_values:[1,12960]",
            "--trt-opt-shapes input_values:[1,6480]",
        ],
    }

    names = [i[:-5] for i in os.listdir(onnx_dir) if i.endswith(".onnx")]
    for name in names:
        if name == "warp_network_ori":
            continue

        print("=" * 20, f"{name} start", "=" * 20)

        fp16 = not (name in fp32_names or name.startswith("lmdm"))
        more_cmd = shape_options.get(name)

        onnx_file = f"{onnx_dir}/{name}.onnx"
        trt_file = f"{trt_dir}/{name}_fp{16 if fp16 else 32}.engine"

        if os.path.isfile(trt_file):
            print("=" * 20, f"{name} skip", "=" * 20)
            continue

        if name == "warp_network":
            onnx_to_trt_for_gridsample(onnx_file, trt_file, fp16, plugin_file=grid_sample_plugin_file)
        else:
            onnx_to_trt(onnx_file, trt_file, fp16, more_cmd=more_cmd)

        print("=" * 20, f"{name} done", "=" * 20)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--onnx_dir", type=str, help="input onnx dir")
    parser.add_argument("--trt_dir", type=str, help="output trt dir")
    args = parser.parse_args()

    onnx_dir = args.onnx_dir
    trt_dir = args.trt_dir

    assert os.path.isdir(onnx_dir)
    os.makedirs(trt_dir, exist_ok=True)

    grid_sample_plugin_file = os.path.join(onnx_dir, "libgrid_sample_3d_plugin.so")
    main(onnx_dir, trt_dir, grid_sample_plugin_file)
class StreamSDK:
    def __init__(self, cfg_pkl, data_root, **kwargs):
        """Parse the configuration and build every pipeline component.

        ``parse_cfg(cfg_pkl, data_root, kwargs)`` yields one config per
        component plus ``default_kwargs``, which ``setup`` later uses as
        fallback options.
        """

        [
            avatar_registrar_cfg,
            condition_handler_cfg,
            lmdm_cfg,
            stitch_network_cfg,
            warp_network_cfg,
            decoder_cfg,
            wav2feat_cfg,
            default_kwargs,
        ] = parse_cfg(cfg_pkl, data_root, kwargs)

        self.default_kwargs = default_kwargs

        self.avatar_registrar = AvatarRegistrar(**avatar_registrar_cfg)
        self.condition_handler = ConditionHandler(**condition_handler_cfg)
        self.audio2motion = Audio2Motion(lmdm_cfg)
        self.motion_stitch = MotionStitch(stitch_network_cfg)
        self.warp_f3d = WarpF3D(warp_network_cfg)
        self.decode_f3d = DecodeF3D(decoder_cfg)
        self.putback = PutBack()

        self.wav2feat = Wav2Feat(**wav2feat_cfg)

    def _merge_kwargs(self, default_kwargs, run_kwargs):
        """Fill keys missing from ``run_kwargs`` with defaults (in place) and return it."""
        for k, v in default_kwargs.items():
            run_kwargs.setdefault(k, v)
        return run_kwargs

    def setup_Nd(self, N_d, fade_in=-1, fade_out=-1, ctrl_info=None):
        """Set the number of driving frames and write fade alphas into ``ctrl_info``.

        Args:
            N_d: total number of frames to generate.
            fade_in: if > 0, frames ``0 .. fade_in - 1`` get ``fade_alpha``
                rising linearly from 0.
            fade_out: if > 0, the last ``fade_out + 1`` frames get
                ``fade_alpha`` falling linearly to 0 on the final frame.
            ctrl_info: per-frame control dict ``{frame_id: ctrl_kwargs}`` to
                update in place; defaults to ``self.ctrl_info``.
        """
        # for eye open at video end
        self.motion_stitch.set_Nd(N_d)

        # for fade in/out alpha
        if ctrl_info is None:
            ctrl_info = self.ctrl_info
        if fade_in > 0:
            for i in range(fade_in):
                ctrl_info.setdefault(i, {})["fade_alpha"] = i / fade_in
        if fade_out > 0:
            ss = N_d - fade_out - 1
            ee = N_d - 1
            # When fade_out >= N_d the nominal start ``ss`` is negative; start
            # at frame 0 so no entries are created for non-existent frames.
            # The slope still uses ``ss`` so existing frames keep their alpha.
            for i in range(max(ss, 0), N_d):
                alpha = max((ee - i) / (ee - ss), 0)
                ctrl_info.setdefault(i, {})["fade_alpha"] = alpha
        self.ctrl_info = ctrl_info
    def setup(self, source_path, output_path, **kwargs):
        """Configure one generation session and start the worker threads.

        Steps, in order: merge ``kwargs`` with ``self.default_kwargs`` and
        store each option on ``self``; register the avatar from
        ``source_path``; set up the condition handler, audio2motion and
        motion-stitch components; open the video writer on
        ``output_path + ".tmp.mp4"``; initialise the audio-feature buffer;
        create the inter-stage queues and start one thread per stage.
        """

        # ======== Prepare Options ========
        kwargs = self._merge_kwargs(self.default_kwargs, kwargs)
        print("=" * 20, "setup kwargs", "=" * 20)
        print_cfg(**kwargs)
        print("=" * 50)

        # -- avatar_registrar: template cfg --
        self.max_size = kwargs.get("max_size", 1920)
        self.template_n_frames = kwargs.get("template_n_frames", -1)

        # -- avatar_registrar: crop cfg --
        self.crop_scale = kwargs.get("crop_scale", 2.3)
        self.crop_vx_ratio = kwargs.get("crop_vx_ratio", 0)
        self.crop_vy_ratio = kwargs.get("crop_vy_ratio", -0.125)
        self.crop_flag_do_rot = kwargs.get("crop_flag_do_rot", True)

        # -- avatar_registrar: smo for video --
        self.smo_k_s = kwargs.get('smo_k_s', 13)

        # -- condition_handler: ECS --
        self.emo = kwargs.get("emo", 4)    # int | [int] | [[int]] | numpy
        self.eye_f0_mode = kwargs.get("eye_f0_mode", False)    # for video
        self.ch_info = kwargs.get("ch_info", None)    # dict of np.ndarray

        # -- audio2motion: setup --
        self.overlap_v2 = kwargs.get("overlap_v2", 10)
        self.fix_kp_cond = kwargs.get("fix_kp_cond", 0)
        self.fix_kp_cond_dim = kwargs.get("fix_kp_cond_dim", None)    # [ds,de]
        self.sampling_timesteps = kwargs.get("sampling_timesteps", 50)
        self.online_mode = kwargs.get("online_mode", False)
        self.v_min_max_for_clip = kwargs.get('v_min_max_for_clip', None)
        self.smo_k_d = kwargs.get("smo_k_d", 3)

        # -- motion_stitch: setup --
        self.N_d = kwargs.get("N_d", -1)
        self.use_d_keys = kwargs.get("use_d_keys", None)
        self.relative_d = kwargs.get("relative_d", True)
        self.drive_eye = kwargs.get("drive_eye", None)    # None: true4image, false4video
        self.delta_eye_arr = kwargs.get("delta_eye_arr", None)
        self.delta_eye_open_n = kwargs.get("delta_eye_open_n", 0)
        self.fade_type = kwargs.get("fade_type", "")    # "" | "d0" | "s"
        self.fade_out_keys = kwargs.get("fade_out_keys", ("exp",))
        self.flag_stitching = kwargs.get("flag_stitching", True)

        self.ctrl_info = kwargs.get("ctrl_info", dict())
        self.overall_ctrl_info = kwargs.get("overall_ctrl_info", dict())
        """
        ctrl_info: list or dict
            {
                fid: ctrl_kwargs
            }

            ctrl_kwargs (see motion_stitch.py):
                fade_alpha
                fade_out_keys

                delta_pitch
                delta_yaw
                delta_roll
        """

        # only hubert support online mode
        assert self.wav2feat.support_streaming or not self.online_mode

        # ======== Register Avatar ========
        crop_kwargs = {
            "crop_scale": self.crop_scale,
            "crop_vx_ratio": self.crop_vx_ratio,
            "crop_vy_ratio": self.crop_vy_ratio,
            "crop_flag_do_rot": self.crop_flag_do_rot,
        }
        # An explicit template length wins; otherwise fall back to N_d.
        n_frames = self.template_n_frames if self.template_n_frames > 0 else self.N_d
        source_info = self.avatar_registrar(
            source_path,
            max_dim=self.max_size,
            n_frames=n_frames,
            **crop_kwargs,
        )

        # Smoothing only applies when the source has more than one frame.
        if len(source_info["x_s_info_lst"]) > 1 and self.smo_k_s > 1:
            source_info["x_s_info_lst"] = smooth_x_s_info_lst(source_info["x_s_info_lst"], smo_k=self.smo_k_s)

        self.source_info = source_info
        self.source_info_frames = len(source_info["x_s_info_lst"])

        # ======== Setup Condition Handler ========
        self.condition_handler.setup(source_info, self.emo, eye_f0_mode=self.eye_f0_mode, ch_info=self.ch_info)

        # ======== Setup Audio2Motion (LMDM) ========
        x_s_info_0 = self.condition_handler.x_s_info_0
        self.audio2motion.setup(
            x_s_info_0,
            overlap_v2=self.overlap_v2,
            fix_kp_cond=self.fix_kp_cond,
            fix_kp_cond_dim=self.fix_kp_cond_dim,
            sampling_timesteps=self.sampling_timesteps,
            online_mode=self.online_mode,
            v_min_max_for_clip=self.v_min_max_for_clip,
            smo_k_d=self.smo_k_d,
        )

        # ======== Setup Motion Stitch ========
        is_image_flag = source_info["is_image_flag"]
        x_s_info = source_info['x_s_info_lst'][0]
        self.motion_stitch.setup(
            N_d=self.N_d,
            use_d_keys=self.use_d_keys,
            relative_d=self.relative_d,
            drive_eye=self.drive_eye,
            delta_eye_arr=self.delta_eye_arr,
            delta_eye_open_n=self.delta_eye_open_n,
            fade_out_keys=self.fade_out_keys,
            fade_type=self.fade_type,
            flag_stitching=self.flag_stitching,
            is_image_flag=is_image_flag,
            x_s_info=x_s_info,
            d0=None,
            ch_info=self.ch_info,
            overall_ctrl_info=self.overall_ctrl_info,
        )

        # ======== Video Writer ========
        self.output_path = output_path
        # Frames go to a temporary file next to the requested output.
        self.tmp_output_path = output_path + ".tmp.mp4"
        self.writer = VideoWriterByImageIO(self.tmp_output_path)
        self.writer_pbar = tqdm(desc="writer")

        # ======== Audio Feat Buffer ========
        if self.online_mode:
            # buffer: seq_frames - valid_clip_len
            # Pre-fill with features of overlap_v2 * 640 zero samples.
            self.audio_feat = self.wav2feat.wav2feat(np.zeros((self.overlap_v2 * 640,), dtype=np.float32), sr=16000)
            assert len(self.audio_feat) == self.overlap_v2, f"{len(self.audio_feat)}"
        else:
            self.audio_feat = np.zeros((0, self.wav2feat.feat_dim), dtype=np.float32)
        # Negative by the length of the pre-filled buffer (0 when offline).
        self.cond_idx_start = 0 - len(self.audio_feat)

        # ======== Setup Worker Threads ========
        QUEUE_MAX_SIZE = 100
        # self.QUEUE_TIMEOUT = None

        # Set by whichever worker fails first; stop_event asks all workers to exit.
        self.worker_exception = None
        self.stop_event = threading.Event()

        # One bounded queue feeding each stage.
        self.audio2motion_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.motion_stitch_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.warp_f3d_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.decode_f3d_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.putback_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.writer_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)

        self.thread_list = [
            threading.Thread(target=self.audio2motion_worker),
            threading.Thread(target=self.motion_stitch_worker),
            threading.Thread(target=self.warp_f3d_worker),
            threading.Thread(target=self.decode_f3d_worker),
            threading.Thread(target=self.putback_worker),
            threading.Thread(target=self.writer_worker),
        ]

        for thread in self.thread_list:
            thread.start()
Threads ======== + QUEUE_MAX_SIZE = 100 + # self.QUEUE_TIMEOUT = None + + self.worker_exception = None + self.stop_event = threading.Event() + + self.audio2motion_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE) + self.motion_stitch_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE) + self.warp_f3d_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE) + self.decode_f3d_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE) + self.putback_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE) + self.writer_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE) + + self.thread_list = [ + threading.Thread(target=self.audio2motion_worker), + threading.Thread(target=self.motion_stitch_worker), + threading.Thread(target=self.warp_f3d_worker), + threading.Thread(target=self.decode_f3d_worker), + threading.Thread(target=self.putback_worker), + threading.Thread(target=self.writer_worker), + ] + + for thread in self.thread_list: + thread.start() + + def _get_ctrl_info(self, fid): + try: + if isinstance(self.ctrl_info, dict): + return self.ctrl_info.get(fid, {}) + elif isinstance(self.ctrl_info, list): + return self.ctrl_info[fid] + else: + return {} + except Exception as e: + traceback.print_exc() + return {} + + def writer_worker(self): + try: + self._writer_worker() + except Exception as e: + self.worker_exception = e + self.stop_event.set() + + def _writer_worker(self): + while not self.stop_event.is_set(): + try: + item = self.writer_queue.get(timeout=1) + except queue.Empty: + continue + + if item is None: + break + res_frame_rgb = item + self.writer(res_frame_rgb, fmt="rgb") + self.writer_pbar.update() + + def putback_worker(self): + try: + self._putback_worker() + except Exception as e: + self.worker_exception = e + self.stop_event.set() + + def _putback_worker(self): + while not self.stop_event.is_set(): + try: + item = self.putback_queue.get(timeout=1) + except queue.Empty: + continue + if item is None: + self.writer_queue.put(None) + break + frame_idx, render_img = item + frame_rgb = 
self.source_info["img_rgb_lst"][frame_idx] + M_c2o = self.source_info["M_c2o_lst"][frame_idx] + res_frame_rgb = self.putback(frame_rgb, render_img, M_c2o) + self.writer_queue.put(res_frame_rgb) + + def decode_f3d_worker(self): + try: + self._decode_f3d_worker() + except Exception as e: + self.worker_exception = e + self.stop_event.set() + + def _decode_f3d_worker(self): + while not self.stop_event.is_set(): + try: + item = self.decode_f3d_queue.get(timeout=1) + except queue.Empty: + continue + if item is None: + self.putback_queue.put(None) + break + frame_idx, f_3d = item + render_img = self.decode_f3d(f_3d) + self.putback_queue.put([frame_idx, render_img]) + + def warp_f3d_worker(self): + try: + self._warp_f3d_worker() + except Exception as e: + self.worker_exception = e + self.stop_event.set() + + def _warp_f3d_worker(self): + while not self.stop_event.is_set(): + try: + item = self.warp_f3d_queue.get(timeout=1) + except queue.Empty: + continue + if item is None: + self.decode_f3d_queue.put(None) + break + frame_idx, x_s, x_d = item + f_s = self.source_info["f_s_lst"][frame_idx] + f_3d = self.warp_f3d(f_s, x_s, x_d) + self.decode_f3d_queue.put([frame_idx, f_3d]) + + def motion_stitch_worker(self): + try: + self._motion_stitch_worker() + except Exception as e: + self.worker_exception = e + self.stop_event.set() + + def _motion_stitch_worker(self): + while not self.stop_event.is_set(): + try: + item = self.motion_stitch_queue.get(timeout=1) + except queue.Empty: + continue + if item is None: + self.warp_f3d_queue.put(None) + break + + frame_idx, x_d_info, ctrl_kwargs = item + x_s_info = self.source_info["x_s_info_lst"][frame_idx] + x_s, x_d = self.motion_stitch(x_s_info, x_d_info, **ctrl_kwargs) + self.warp_f3d_queue.put([frame_idx, x_s, x_d]) + + def audio2motion_worker(self): + try: + # self._audio2motion_worker() + self._audio2motion_offline() + except Exception as e: + self.worker_exception = e + self.stop_event.set() + + def _audio2motion_offline(self): + + 
while not self.stop_event.is_set(): + try: + item = self.audio2motion_queue.get(timeout=1) # audio feat + except queue.Empty: + continue + + if item is None: + break + + aud_feat = item + + aud_cond_all = self.condition_handler(aud_feat, 0) + seq_frames = self.audio2motion.seq_frames + valid_clip_len = self.audio2motion.valid_clip_len + num_frames = len(aud_cond_all) + idx = 0 + res_kp_seq = None + pbar = tqdm(desc="dit") + while idx < num_frames: + pbar.update() + aud_cond = aud_cond_all[idx:idx + seq_frames][None] + if aud_cond.shape[1] < seq_frames: + pad = np.stack([aud_cond[:, -1]] * (seq_frames - aud_cond.shape[1]), 1) + aud_cond = np.concatenate([aud_cond, pad], 1) + res_kp_seq = self.audio2motion(aud_cond, res_kp_seq) + idx += valid_clip_len + + pbar.close() + res_kp_seq = res_kp_seq[:, :num_frames] + res_kp_seq = self.audio2motion._smo(res_kp_seq, 0, res_kp_seq.shape[1]) + + x_d_info_list = self.audio2motion.cvt_fmt(res_kp_seq) + + gen_frame_idx = 0 + for x_d_info in x_d_info_list: + frame_idx = _mirror_index(gen_frame_idx, self.source_info_frames) + ctrl_kwargs = self._get_ctrl_info(gen_frame_idx) + + while not self.stop_event.is_set(): + try: + self.motion_stitch_queue.put([frame_idx, x_d_info, ctrl_kwargs], timeout=1) + break + except queue.Full: + continue + gen_frame_idx += 1 + + break + + self.motion_stitch_queue.put(None) + + + def _audio2motion_worker(self): + is_end = False + seq_frames = self.audio2motion.seq_frames + valid_clip_len = self.audio2motion.valid_clip_len + aud_feat_dim = self.wav2feat.feat_dim + item_buffer = np.zeros((0, aud_feat_dim), dtype=np.float32) + + res_kp_seq = None + res_kp_seq_valid_start = None if self.online_mode else 0 + + global_idx = 0 # frame idx, for template + local_idx = 0 # for cur audio_feat + gen_frame_idx = 0 + while not self.stop_event.is_set(): + try: + item = self.audio2motion_queue.get(timeout=1) # audio feat + except queue.Empty: + continue + if item is None: + is_end = True + else: + item_buffer = 
np.concatenate([item_buffer, item], 0) + + if not is_end and item_buffer.shape[0] < valid_clip_len: + # wait at least valid_clip_len new item + continue + else: + self.audio_feat = np.concatenate([self.audio_feat, item_buffer], 0) + item_buffer = np.zeros((0, aud_feat_dim), dtype=np.float32) + + while True: + # print("self.audio_feat.shape:", self.audio_feat.shape, "local_idx:", local_idx, "global_idx:", global_idx) + aud_feat = self.audio_feat[local_idx: local_idx+seq_frames] + real_valid_len = valid_clip_len + if len(aud_feat) == 0: + break + elif len(aud_feat) < seq_frames: + if not is_end: + # wait next chunk + break + else: + # final clip: pad to seq_frames + real_valid_len = len(aud_feat) + pad = np.stack([aud_feat[-1]] * (seq_frames - len(aud_feat)), 0) + aud_feat = np.concatenate([aud_feat, pad], 0) + + aud_cond = self.condition_handler(aud_feat, global_idx + self.cond_idx_start) + res_kp_seq = self.audio2motion(aud_cond, res_kp_seq) + if res_kp_seq_valid_start is None: + # online mode, first chunk + res_kp_seq_valid_start = res_kp_seq.shape[1] - self.audio2motion.fuse_length + d0 = self.audio2motion.cvt_fmt(res_kp_seq[0:1])[0] + self.motion_stitch.d0 = d0 + + local_idx += real_valid_len + global_idx += real_valid_len + continue + else: + valid_res_kp_seq = res_kp_seq[:, res_kp_seq_valid_start: res_kp_seq_valid_start + real_valid_len] + x_d_info_list = self.audio2motion.cvt_fmt(valid_res_kp_seq) + + for x_d_info in x_d_info_list: + frame_idx = _mirror_index(gen_frame_idx, self.source_info_frames) + ctrl_kwargs = self._get_ctrl_info(gen_frame_idx) + + while not self.stop_event.is_set(): + try: + self.motion_stitch_queue.put([frame_idx, x_d_info, ctrl_kwargs], timeout=1) + break + except queue.Full: + continue + + gen_frame_idx += 1 + + res_kp_seq_valid_start += real_valid_len + + local_idx += real_valid_len + global_idx += real_valid_len + + L = res_kp_seq.shape[1] + if L > seq_frames * 2: + cut_L = L - seq_frames * 2 + res_kp_seq = res_kp_seq[:, cut_L:] + 
res_kp_seq_valid_start -= cut_L + + if local_idx >= len(self.audio_feat): + break + + L = len(self.audio_feat) + if L > seq_frames * 2: + cut_L = L - seq_frames * 2 + self.audio_feat = self.audio_feat[cut_L:] + local_idx -= cut_L + + if is_end: + break + + self.motion_stitch_queue.put(None) + + def close(self): + # flush frames + self.audio2motion_queue.put(None) + # Wait for worker threads to finish + for thread in self.thread_list: + thread.join() + + try: + self.writer.close() + self.writer_pbar.close() + except: + traceback.print_exc() + + # Check if any worker encountered an exception + if self.worker_exception is not None: + raise self.worker_exception + + def run_chunk(self, audio_chunk, chunksize=(3, 5, 2)): + # only for hubert + aud_feat = self.wav2feat(audio_chunk, chunksize=chunksize) + while not self.stop_event.is_set(): + try: + self.audio2motion_queue.put(aud_feat, timeout=1) + break + except queue.Full: + continue + + + + + diff --git a/stream_pipeline_online.py b/stream_pipeline_online.py new file mode 100644 index 0000000000000000000000000000000000000000..0fbfd755af7c134066c98a5945031a632e7ac4da --- /dev/null +++ b/stream_pipeline_online.py @@ -0,0 +1,511 @@ +import threading +import queue +import numpy as np +import traceback +from tqdm import tqdm + +from core.atomic_components.avatar_registrar import AvatarRegistrar, smooth_x_s_info_lst +from core.atomic_components.condition_handler import ConditionHandler, _mirror_index +from core.atomic_components.audio2motion import Audio2Motion +from core.atomic_components.motion_stitch import MotionStitch +from core.atomic_components.warp_f3d import WarpF3D +from core.atomic_components.decode_f3d import DecodeF3D +from core.atomic_components.putback import PutBack +from core.atomic_components.writer import VideoWriterByImageIO +from core.atomic_components.wav2feat import Wav2Feat +from core.atomic_components.cfg import parse_cfg, print_cfg + + +""" +avatar_registrar_cfg: + insightface_det_cfg, + 
class StreamSDK:
    """Threaded audio-driven rendering pipeline.

    ``setup`` registers the source, configures every stage and starts one
    thread per stage. The stages are chained by bounded queues in this order:
    audio2motion -> motion_stitch -> warp_f3d -> decode_f3d -> putback -> writer.
    ``run_chunk`` feeds audio, ``close`` flushes and joins the threads.

    Configuration sections returned by ``parse_cfg`` (keys as listed in the
    module notes):

    avatar_registrar_cfg:
        insightface_det_cfg,
        landmark106_cfg,
        landmark203_cfg,
        landmark478_cfg,
        appearance_extractor_cfg,
        motion_extractor_cfg,

    condition_handler_cfg:
        use_emo=True,
        use_sc=True,
        use_eye_open=True,
        use_eye_ball=True,
        seq_frames=80,

    wav2feat_cfg:
        w2f_cfg,
        w2f_type
    """

    def __init__(self, cfg_pkl, data_root, **kwargs):
        """Build all pipeline components from the parsed configuration.

        Args:
            cfg_pkl: configuration handle forwarded to ``parse_cfg``.
            data_root: data root forwarded to ``parse_cfg``.
            **kwargs: overrides forwarded to ``parse_cfg`` as a dict.
        """
        [
            avatar_registrar_cfg,
            condition_handler_cfg,
            lmdm_cfg,
            stitch_network_cfg,
            warp_network_cfg,
            decoder_cfg,
            wav2feat_cfg,
            default_kwargs,
        ] = parse_cfg(cfg_pkl, data_root, kwargs)

        self.default_kwargs = default_kwargs

        self.avatar_registrar = AvatarRegistrar(**avatar_registrar_cfg)
        self.condition_handler = ConditionHandler(**condition_handler_cfg)
        self.audio2motion = Audio2Motion(lmdm_cfg)
        self.motion_stitch = MotionStitch(stitch_network_cfg)
        self.warp_f3d = WarpF3D(warp_network_cfg)
        self.decode_f3d = DecodeF3D(decoder_cfg)
        self.putback = PutBack()

        self.wav2feat = Wav2Feat(**wav2feat_cfg)

    def _merge_kwargs(self, default_kwargs, run_kwargs):
        """Fill ``run_kwargs`` in place with defaults for missing keys and return it."""
        for k, v in default_kwargs.items():
            run_kwargs.setdefault(k, v)
        return run_kwargs

    def setup_Nd(self, N_d, fade_in=-1, fade_out=-1, ctrl_info=None):
        """Set the number of driven frames and write fade alphas into ``ctrl_info``.

        Args:
            N_d: total number of generated frames; forwarded to
                ``motion_stitch.set_Nd``.
            fade_in: number of leading frames whose ``fade_alpha`` ramps up
                from 0 (``i / fade_in``); disabled when <= 0.
            fade_out: length of the trailing ramp down to 0; disabled when <= 0.
            ctrl_info: per-frame control info (dict ``{fid: ctrl_kwargs}`` or
                list indexed by fid). Defaults to the current ``self.ctrl_info``.
                A list is converted to a dict so that frames beyond its length
                can be addressed.
        """
        # for eye open at video end
        self.motion_stitch.set_Nd(N_d)

        # for fade in/out alpha
        if ctrl_info is None:
            # setup() may not have run yet; fall back to an empty mapping.
            ctrl_info = getattr(self, "ctrl_info", None)
            if ctrl_info is None:
                ctrl_info = {}
        if isinstance(ctrl_info, list):
            # Lists have no .get and cannot grow by key assignment.
            ctrl_info = dict(enumerate(ctrl_info))

        if fade_in > 0:
            for i in range(fade_in):
                item = ctrl_info.get(i, {})
                item["fade_alpha"] = i / fade_in
                ctrl_info[i] = item
        if fade_out > 0:
            ss = N_d - fade_out - 1
            ee = N_d - 1
            # ee - ss == fade_out > 0, so the division is safe. Negative frame
            # ids (fade_out >= N_d) are skipped: no frame ever has such an id.
            for i in range(max(ss, 0), N_d):
                item = ctrl_info.get(i, {})
                item["fade_alpha"] = max((ee - i) / (ee - ss), 0)
                ctrl_info[i] = item
        self.ctrl_info = ctrl_info

    def setup(self, source_path, output_path, **kwargs):
        """Configure every stage for one run and start the worker threads.

        Args:
            source_path: source passed to the avatar registrar.
            output_path: final output path; frames are written to
                ``output_path + ".tmp.mp4"``.
            **kwargs: run options; missing keys are filled from
                ``self.default_kwargs``.
        """
        # ======== Prepare Options ========
        kwargs = self._merge_kwargs(self.default_kwargs, kwargs)
        print("=" * 20, "setup kwargs", "=" * 20)
        print_cfg(**kwargs)
        print("=" * 50)

        # -- avatar_registrar: template cfg --
        self.max_size = kwargs.get("max_size", 1920)
        self.template_n_frames = kwargs.get("template_n_frames", -1)

        # -- avatar_registrar: crop cfg --
        self.crop_scale = kwargs.get("crop_scale", 2.3)
        self.crop_vx_ratio = kwargs.get("crop_vx_ratio", 0)
        self.crop_vy_ratio = kwargs.get("crop_vy_ratio", -0.125)
        self.crop_flag_do_rot = kwargs.get("crop_flag_do_rot", True)

        # -- avatar_registrar: smo for video --
        self.smo_k_s = kwargs.get('smo_k_s', 13)

        # -- condition_handler: ECS --
        self.emo = kwargs.get("emo", 4)    # int | [int] | [[int]] | numpy
        self.eye_f0_mode = kwargs.get("eye_f0_mode", False)    # for video
        self.ch_info = kwargs.get("ch_info", None)    # dict of np.ndarray

        # -- audio2motion: setup --
        self.overlap_v2 = kwargs.get("overlap_v2", 10)
        self.fix_kp_cond = kwargs.get("fix_kp_cond", 0)
        self.fix_kp_cond_dim = kwargs.get("fix_kp_cond_dim", None)    # [ds,de]
        self.sampling_timesteps = kwargs.get("sampling_timesteps", 50)
        self.online_mode = kwargs.get("online_mode", False)
        self.v_min_max_for_clip = kwargs.get('v_min_max_for_clip', None)
        self.smo_k_d = kwargs.get("smo_k_d", 3)

        # -- motion_stitch: setup --
        self.N_d = kwargs.get("N_d", -1)
        self.use_d_keys = kwargs.get("use_d_keys", None)
        self.relative_d = kwargs.get("relative_d", True)
        self.drive_eye = kwargs.get("drive_eye", None)    # None: true4image, false4video
        self.delta_eye_arr = kwargs.get("delta_eye_arr", None)
        self.delta_eye_open_n = kwargs.get("delta_eye_open_n", 0)
        self.fade_type = kwargs.get("fade_type", "")    # "" | "d0" | "s"
        self.fade_out_keys = kwargs.get("fade_out_keys", ("exp",))
        self.flag_stitching = kwargs.get("flag_stitching", True)

        # ctrl_info: list or dict {fid: ctrl_kwargs}
        # ctrl_kwargs (see motion_stitch.py):
        #     fade_alpha, fade_out_keys, delta_pitch, delta_yaw, delta_roll
        self.ctrl_info = kwargs.get("ctrl_info", dict())
        self.overall_ctrl_info = kwargs.get("overall_ctrl_info", dict())

        # only hubert support online mode
        assert self.wav2feat.support_streaming or not self.online_mode, \
            "online_mode requires a wav2feat backend with streaming support"

        # ======== Register Avatar ========
        crop_kwargs = {
            "crop_scale": self.crop_scale,
            "crop_vx_ratio": self.crop_vx_ratio,
            "crop_vy_ratio": self.crop_vy_ratio,
            "crop_flag_do_rot": self.crop_flag_do_rot,
        }
        n_frames = self.template_n_frames if self.template_n_frames > 0 else self.N_d
        source_info = self.avatar_registrar(
            source_path,
            max_dim=self.max_size,
            n_frames=n_frames,
            **crop_kwargs,
        )

        # Smoothing is only applied when there is more than one source frame.
        if len(source_info["x_s_info_lst"]) > 1 and self.smo_k_s > 1:
            source_info["x_s_info_lst"] = smooth_x_s_info_lst(source_info["x_s_info_lst"], smo_k=self.smo_k_s)

        self.source_info = source_info
        self.source_info_frames = len(source_info["x_s_info_lst"])

        # ======== Setup Condition Handler ========
        self.condition_handler.setup(source_info, self.emo, eye_f0_mode=self.eye_f0_mode, ch_info=self.ch_info)

        # ======== Setup Audio2Motion (LMDM) ========
        x_s_info_0 = self.condition_handler.x_s_info_0
        self.audio2motion.setup(
            x_s_info_0,
            overlap_v2=self.overlap_v2,
            fix_kp_cond=self.fix_kp_cond,
            fix_kp_cond_dim=self.fix_kp_cond_dim,
            sampling_timesteps=self.sampling_timesteps,
            online_mode=self.online_mode,
            v_min_max_for_clip=self.v_min_max_for_clip,
            smo_k_d=self.smo_k_d,
        )

        # ======== Setup Motion Stitch ========
        is_image_flag = source_info["is_image_flag"]
        x_s_info = source_info['x_s_info_lst'][0]
        self.motion_stitch.setup(
            N_d=self.N_d,
            use_d_keys=self.use_d_keys,
            relative_d=self.relative_d,
            drive_eye=self.drive_eye,
            delta_eye_arr=self.delta_eye_arr,
            delta_eye_open_n=self.delta_eye_open_n,
            fade_out_keys=self.fade_out_keys,
            fade_type=self.fade_type,
            flag_stitching=self.flag_stitching,
            is_image_flag=is_image_flag,
            x_s_info=x_s_info,
            d0=None,
            ch_info=self.ch_info,
            overall_ctrl_info=self.overall_ctrl_info,
        )

        # ======== Video Writer ========
        self.output_path = output_path
        self.tmp_output_path = output_path + ".tmp.mp4"
        self.writer = VideoWriterByImageIO(self.tmp_output_path)
        self.writer_pbar = tqdm(desc="writer")

        # ======== Audio Feat Buffer ========
        if self.online_mode:
            # buffer: seq_frames - valid_clip_len
            # Pre-fill with features of silence; the assert checks that
            # overlap_v2 * 640 samples yield exactly overlap_v2 feature frames.
            self.audio_feat = self.wav2feat.wav2feat(np.zeros((self.overlap_v2 * 640,), dtype=np.float32), sr=16000)
            assert len(self.audio_feat) == self.overlap_v2, f"{len(self.audio_feat)}"
        else:
            self.audio_feat = np.zeros((0, self.wav2feat.feat_dim), dtype=np.float32)
        # Offset so that condition indices ignore the pre-filled frames.
        self.cond_idx_start = 0 - len(self.audio_feat)

        # ======== Setup Worker Threads ========
        QUEUE_MAX_SIZE = 100

        self.worker_exception = None
        self.stop_event = threading.Event()

        self.audio2motion_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.motion_stitch_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.warp_f3d_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.decode_f3d_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.putback_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
        self.writer_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)

        self.thread_list = [
            threading.Thread(target=self.audio2motion_worker),
            threading.Thread(target=self.motion_stitch_worker),
            threading.Thread(target=self.warp_f3d_worker),
            threading.Thread(target=self.decode_f3d_worker),
            threading.Thread(target=self.putback_worker),
            threading.Thread(target=self.writer_worker),
        ]

        for thread in self.thread_list:
            thread.start()

    def _get_ctrl_info(self, fid):
        """Return the control kwargs for generated frame ``fid`` ({} if none)."""
        ctrl_info = self.ctrl_info
        if isinstance(ctrl_info, dict):
            return ctrl_info.get(fid, {})
        if isinstance(ctrl_info, list):
            # Frames past the end of the list simply have no control info.
            return ctrl_info[fid] if 0 <= fid < len(ctrl_info) else {}
        return {}

    def _put_until_stopped(self, q, item):
        """Put ``item`` on ``q``, retrying while the queue is full.

        Gives up as soon as ``stop_event`` is set, so a producer can never
        block forever on a full queue whose consumer has already died.

        Returns:
            True if the item was enqueued, False if the pipeline was stopped.
        """
        while not self.stop_event.is_set():
            try:
                q.put(item, timeout=1)
                return True
            except queue.Full:
                continue
        return False

    def _run_worker(self, worker_fn):
        """Run a stage loop; on failure record the error and stop all stages."""
        try:
            worker_fn()
        except Exception as e:
            # Keep the first failure: later ones are usually consequences.
            if self.worker_exception is None:
                self.worker_exception = e
            self.stop_event.set()

    def writer_worker(self):
        """Thread entry point for the writer stage."""
        self._run_worker(self._writer_worker)

    def _writer_worker(self):
        """Write finished frames until the ``None`` sentinel or a stop request."""
        while not self.stop_event.is_set():
            try:
                item = self.writer_queue.get(timeout=1)
            except queue.Empty:
                continue

            if item is None:
                break
            res_frame_rgb = item
            self.writer(res_frame_rgb, fmt="rgb")
            self.writer_pbar.update()

    def putback_worker(self):
        """Thread entry point for the putback stage."""
        self._run_worker(self._putback_worker)

    def _putback_worker(self):
        """Paste each rendered image back into its source frame."""
        while not self.stop_event.is_set():
            try:
                item = self.putback_queue.get(timeout=1)
            except queue.Empty:
                continue
            if item is None:
                # Forward the end-of-stream sentinel downstream.
                self._put_until_stopped(self.writer_queue, None)
                break
            frame_idx, render_img = item
            frame_rgb = self.source_info["img_rgb_lst"][frame_idx]
            M_c2o = self.source_info["M_c2o_lst"][frame_idx]
            res_frame_rgb = self.putback(frame_rgb, render_img, M_c2o)
            self._put_until_stopped(self.writer_queue, res_frame_rgb)

    def decode_f3d_worker(self):
        """Thread entry point for the decode stage."""
        self._run_worker(self._decode_f3d_worker)

    def _decode_f3d_worker(self):
        """Decode warped 3D features into rendered images."""
        while not self.stop_event.is_set():
            try:
                item = self.decode_f3d_queue.get(timeout=1)
            except queue.Empty:
                continue
            if item is None:
                self._put_until_stopped(self.putback_queue, None)
                break
            frame_idx, f_3d = item
            render_img = self.decode_f3d(f_3d)
            self._put_until_stopped(self.putback_queue, [frame_idx, render_img])

    def warp_f3d_worker(self):
        """Thread entry point for the warp stage."""
        self._run_worker(self._warp_f3d_worker)

    def _warp_f3d_worker(self):
        """Warp the source feature of each frame with its (x_s, x_d) pair."""
        while not self.stop_event.is_set():
            try:
                item = self.warp_f3d_queue.get(timeout=1)
            except queue.Empty:
                continue
            if item is None:
                self._put_until_stopped(self.decode_f3d_queue, None)
                break
            frame_idx, x_s, x_d = item
            f_s = self.source_info["f_s_lst"][frame_idx]
            f_3d = self.warp_f3d(f_s, x_s, x_d)
            self._put_until_stopped(self.decode_f3d_queue, [frame_idx, f_3d])

    def motion_stitch_worker(self):
        """Thread entry point for the motion-stitch stage."""
        self._run_worker(self._motion_stitch_worker)

    def _motion_stitch_worker(self):
        """Combine source and driving motion info for each frame."""
        while not self.stop_event.is_set():
            try:
                item = self.motion_stitch_queue.get(timeout=1)
            except queue.Empty:
                continue
            if item is None:
                self._put_until_stopped(self.warp_f3d_queue, None)
                break

            frame_idx, x_d_info, ctrl_kwargs = item
            x_s_info = self.source_info["x_s_info_lst"][frame_idx]
            x_s, x_d = self.motion_stitch(x_s_info, x_d_info, **ctrl_kwargs)
            self._put_until_stopped(self.warp_f3d_queue, [frame_idx, x_s, x_d])

    def audio2motion_worker(self):
        """Thread entry point for the audio-to-motion stage."""
        self._run_worker(self._audio2motion_worker)

    def _audio2motion_worker(self):
        """Turn queued audio features into per-frame driving motion.

        Features are accumulated until at least ``valid_clip_len`` new frames
        are available (or the ``None`` end sentinel arrives), then processed in
        windows of ``seq_frames`` that advance by ``valid_clip_len``. The last
        window is padded by repeating its final feature frame.
        """
        is_end = False
        seq_frames = self.audio2motion.seq_frames
        valid_clip_len = self.audio2motion.valid_clip_len
        aud_feat_dim = self.wav2feat.feat_dim
        item_buffer = np.zeros((0, aud_feat_dim), dtype=np.float32)

        res_kp_seq = None
        # In online mode the start offset is only known after the first window.
        res_kp_seq_valid_start = None if self.online_mode else 0

        global_idx = 0   # frame idx, for template
        local_idx = 0    # for cur audio_feat
        gen_frame_idx = 0
        while not self.stop_event.is_set():
            try:
                item = self.audio2motion_queue.get(timeout=1)   # audio feat
            except queue.Empty:
                continue
            if item is None:
                is_end = True
            else:
                item_buffer = np.concatenate([item_buffer, item], 0)

            if not is_end and item_buffer.shape[0] < valid_clip_len:
                # wait at least valid_clip_len new item
                continue
            else:
                self.audio_feat = np.concatenate([self.audio_feat, item_buffer], 0)
                item_buffer = np.zeros((0, aud_feat_dim), dtype=np.float32)

            while True:
                aud_feat = self.audio_feat[local_idx: local_idx + seq_frames]
                real_valid_len = valid_clip_len
                if len(aud_feat) == 0:
                    break
                elif len(aud_feat) < seq_frames:
                    if not is_end:
                        # wait next chunk
                        break
                    else:
                        # final clip: pad to seq_frames
                        real_valid_len = len(aud_feat)
                        pad = np.stack([aud_feat[-1]] * (seq_frames - len(aud_feat)), 0)
                        aud_feat = np.concatenate([aud_feat, pad], 0)

                aud_cond = self.condition_handler(aud_feat, global_idx + self.cond_idx_start)[None]
                res_kp_seq = self.audio2motion(aud_cond, res_kp_seq)
                if res_kp_seq_valid_start is None:
                    # online mode, first chunk
                    res_kp_seq_valid_start = res_kp_seq.shape[1] - self.audio2motion.fuse_length
                    d0 = self.audio2motion.cvt_fmt(res_kp_seq[0:1])[0]
                    self.motion_stitch.d0 = d0

                    local_idx += real_valid_len
                    global_idx += real_valid_len
                    continue
                else:
                    valid_res_kp_seq = res_kp_seq[:, res_kp_seq_valid_start: res_kp_seq_valid_start + real_valid_len]
                    x_d_info_list = self.audio2motion.cvt_fmt(valid_res_kp_seq)

                    for x_d_info in x_d_info_list:
                        frame_idx = _mirror_index(gen_frame_idx, self.source_info_frames)
                        ctrl_kwargs = self._get_ctrl_info(gen_frame_idx)

                        if not self._put_until_stopped(
                            self.motion_stitch_queue, [frame_idx, x_d_info, ctrl_kwargs]
                        ):
                            # Pipeline is aborting: no point generating more motion.
                            return

                        gen_frame_idx += 1

                    res_kp_seq_valid_start += real_valid_len

                    local_idx += real_valid_len
                    global_idx += real_valid_len

                # Keep at most 2 * seq_frames of motion history.
                L = res_kp_seq.shape[1]
                if L > seq_frames * 2:
                    cut_L = L - seq_frames * 2
                    res_kp_seq = res_kp_seq[:, cut_L:]
                    res_kp_seq_valid_start -= cut_L

                if local_idx >= len(self.audio_feat):
                    break

            # Keep at most 2 * seq_frames of audio features.
            L = len(self.audio_feat)
            if L > seq_frames * 2:
                cut_L = L - seq_frames * 2
                self.audio_feat = self.audio_feat[cut_L:]
                local_idx -= cut_L

            if is_end:
                break

        self._put_until_stopped(self.motion_stitch_queue, None)

    def close(self):
        """Flush the pipeline, join all workers and close the writer.

        Raises:
            Exception: re-raises the exception recorded by a failed worker.
        """
        # flush frames
        self._put_until_stopped(self.audio2motion_queue, None)
        # Wait for worker threads to finish
        for thread in self.thread_list:
            thread.join()

        # Close both resources independently so one failure cannot leak the other.
        try:
            self.writer.close()
        except Exception:
            traceback.print_exc()
        try:
            self.writer_pbar.close()
        except Exception:
            traceback.print_exc()

        # Check if any worker encountered an exception
        if self.worker_exception is not None:
            raise self.worker_exception

    def run_chunk(self, audio_chunk, chunksize=(3, 5, 2)):
        """Extract features for one audio chunk and queue them for processing.

        Args:
            audio_chunk: audio samples forwarded to ``self.wav2feat``.
            chunksize: forwarded to ``self.wav2feat`` (only for hubert).
        """
        # only for hubert
        aud_feat = self.wav2feat(audio_chunk, chunksize=chunksize)
        self._put_until_stopped(self.audio2motion_queue, aud_feat)