diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..05dc67685b581fd71ce44b63652aa320666d02ad
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,43 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*__pycache__
+**/__pycache__/
+*.py[cod]
+**/*.py[cod]
+*$py.class
+
+# Model weights
+checkpoints
+**/*.pth
+**/*.onnx
+**/*.pt
+**/*.pth.tar
+
+.idea
+.vscode
+.DS_Store
+*.DS_Store
+
+*.swp
+tmp*
+
+*build
+*.egg-info/
+*.mp4
+
+log/*
+*.png
+*.jpg
+*.wav
+*.pth
+*.pyc
+*.jpeg
+
+# Folders to ignore
+example/
+ToDo/
+
+!example/audio.wav
+!example/image.png
+
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/README_ditto-talkinghead.md b/README_ditto-talkinghead.md
new file mode 100644
index 0000000000000000000000000000000000000000..a078e6a14afc9084d34976f48d796a58cce64d22
--- /dev/null
+++ b/README_ditto-talkinghead.md
@@ -0,0 +1,232 @@
+# Ditto: Motion-Space Diffusion for Controllable Realtime Talking Head Synthesis
+
+Ant Group
+
+
+## 📌 Updates
+* [2025.07.11] 🔥 The [PyTorch model](#-pytorch-model) is now available.
+* [2025.07.07] 🔥 Ditto has been accepted to ACM MM 2025.
+* [2025.01.21] 🔥 We updated the [Colab](https://colab.research.google.com/drive/19SUi1TiO32IS-Crmsu9wrkNspWE8tFbs?usp=sharing) demo; welcome to try it.
+* [2025.01.10] 🔥 We released our inference [code](https://github.com/antgroup/ditto-talkinghead) and [models](https://huggingface.co/digital-avatar/ditto-talkinghead).
+* [2024.11.29] 🔥 Our [paper](https://arxiv.org/abs/2411.19509) is now public on arXiv.
+
+
+
+## 🛠️ Installation
+
+Tested environment:
+- System: CentOS 7.2
+- GPU: A100
+- Python: 3.10
+- TensorRT: 8.6.1
+
+
+Clone the code from [GitHub](https://github.com/antgroup/ditto-talkinghead):
+```bash
+git clone https://github.com/antgroup/ditto-talkinghead
+cd ditto-talkinghead
+```
+
+### Conda
+Create `conda` environment:
+```bash
+conda env create -f environment.yaml
+conda activate ditto
+```
+
+### Pip
+If you have problems creating the conda environment, you can also refer to our [Colab](https://colab.research.google.com/drive/19SUi1TiO32IS-Crmsu9wrkNspWE8tFbs?usp=sharing).
+After correctly installing `pytorch`, `cuda`, and `cudnn`, you only need to install a few packages using pip:
+```bash
+pip install \
+ tensorrt==8.6.1 \
+ librosa \
+ tqdm \
+ filetype \
+ imageio \
+ opencv_python_headless \
+ scikit-image \
+ cython \
+ cuda-python \
+ imageio-ffmpeg \
+ colored \
+ polygraphy \
+ numpy==2.0.1
+```
+
+If you don't use `conda`, you may also need to install `ffmpeg` following the instructions on the [official website](https://www.ffmpeg.org/download.html).
+
+
+## 📥 Download Checkpoints
+
+Download the checkpoints from [HuggingFace](https://huggingface.co/digital-avatar/ditto-talkinghead) and put them in the `checkpoints` directory:
+```bash
+git lfs install
+git clone https://huggingface.co/digital-avatar/ditto-talkinghead checkpoints
+```
+
+The `checkpoints` directory should look like this:
+```text
+./checkpoints/
+├── ditto_cfg
+│ ├── v0.4_hubert_cfg_trt.pkl
+│ └── v0.4_hubert_cfg_trt_online.pkl
+├── ditto_onnx
+│ ├── appearance_extractor.onnx
+│ ├── blaze_face.onnx
+│ ├── decoder.onnx
+│ ├── face_mesh.onnx
+│ ├── hubert.onnx
+│ ├── insightface_det.onnx
+│ ├── landmark106.onnx
+│ ├── landmark203.onnx
+│ ├── libgrid_sample_3d_plugin.so
+│ ├── lmdm_v0.4_hubert.onnx
+│ ├── motion_extractor.onnx
+│ ├── stitch_network.onnx
+│ └── warp_network.onnx
+└── ditto_trt_Ampere_Plus
+ ├── appearance_extractor_fp16.engine
+ ├── blaze_face_fp16.engine
+ ├── decoder_fp16.engine
+ ├── face_mesh_fp16.engine
+ ├── hubert_fp32.engine
+ ├── insightface_det_fp16.engine
+ ├── landmark106_fp16.engine
+ ├── landmark203_fp16.engine
+ ├── lmdm_v0.4_hubert_fp32.engine
+ ├── motion_extractor_fp32.engine
+ ├── stitch_network_fp16.engine
+ └── warp_network_fp16.engine
+```
+
+- `ditto_cfg/v0.4_hubert_cfg_trt_online.pkl` is the online (streaming) config
+- `ditto_cfg/v0.4_hubert_cfg_trt.pkl` is the offline config
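+
+For example, to run the online pipeline, pass the online config to `inference.py` (all other arguments as in the offline example in the Inference section below):
+
+```bash
+python inference.py \
+    --data_root "./checkpoints/ditto_trt_Ampere_Plus" \
+    --cfg_pkl "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt_online.pkl" \
+    --audio_path "./example/audio.wav" \
+    --source_path "./example/image.png" \
+    --output_path "./tmp/result.mp4"
+```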
+
+
+## 🚀 Inference
+
+Run `inference.py`:
+
+```shell
+python inference.py \
+ --data_root "" \
+ --cfg_pkl "" \
+ --audio_path "" \
+ --source_path "" \
+ --output_path ""
+```
+
+For example:
+
+```shell
+python inference.py \
+ --data_root "./checkpoints/ditto_trt_Ampere_Plus" \
+ --cfg_pkl "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl" \
+ --audio_path "./example/audio.wav" \
+ --source_path "./example/image.png" \
+ --output_path "./tmp/result.mp4"
+```
+
+❗Note:
+
+We provide TensorRT engines built with `hardware-compatibility-level=Ampere_Plus` (`checkpoints/ditto_trt_Ampere_Plus/`). If your GPU does not support them, run the `cvt_onnx_to_trt.py` script to convert the generic ONNX models (`checkpoints/ditto_onnx/`) into TensorRT engines for your hardware.
+
+```bash
+python scripts/cvt_onnx_to_trt.py --onnx_dir "./checkpoints/ditto_onnx" --trt_dir "./checkpoints/ditto_trt_custom"
+```
+
+Then run `inference.py` with `--data_root=./checkpoints/ditto_trt_custom`.
+
+
+## ⚡ PyTorch Model
+*Based on community interest and to better support further development, we are now open-sourcing the PyTorch version of the model.*
+
+
+We have added the PyTorch model and the corresponding configuration files to [HuggingFace](https://huggingface.co/digital-avatar/ditto-talkinghead). Please refer to [Download Checkpoints](#-download-checkpoints) to prepare the model files.
+
+The `checkpoints` directory should look like this:
+```text
+./checkpoints/
+├── ditto_cfg
+│ ├── ...
+│ └── v0.4_hubert_cfg_pytorch.pkl
+├── ...
+└── ditto_pytorch
+ ├── aux_models
+ │ ├── 2d106det.onnx
+ │ ├── det_10g.onnx
+ │ ├── face_landmarker.task
+ │ ├── hubert_streaming_fix_kv.onnx
+ │ └── landmark203.onnx
+ └── models
+ ├── appearance_extractor.pth
+ ├── decoder.pth
+ ├── lmdm_v0.4_hubert.pth
+ ├── motion_extractor.pth
+ ├── stitch_network.pth
+ └── warp_network.pth
+```
+
+To run inference, execute the following command:
+
+```shell
+python inference.py \
+ --data_root "./checkpoints/ditto_pytorch" \
+ --cfg_pkl "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" \
+ --audio_path "./example/audio.wav" \
+ --source_path "./example/image.png" \
+ --output_path "./tmp/result.mp4"
+```
+
+
+## 📧 Acknowledgement
+Our implementation is based on [S2G-MDDiffusion](https://github.com/thuhcsi/S2G-MDDiffusion) and [LivePortrait](https://github.com/KwaiVGI/LivePortrait). Thanks for their remarkable contributions and released code! If we have missed any open-source projects or related articles, we will gladly amend the acknowledgements for this work.
+
+## ⚖️ License
+This repository is released under the Apache-2.0 license as found in the [LICENSE](LICENSE) file.
+
+## 📚 Citation
+If you find this codebase useful for your research, please use the following entry.
+```BibTeX
+@article{li2024ditto,
+ title={Ditto: Motion-Space Diffusion for Controllable Realtime Talking Head Synthesis},
+ author={Li, Tianqi and Zheng, Ruobing and Yang, Minghui and Chen, Jingdong and Yang, Ming},
+ journal={arXiv preprint arXiv:2411.19509},
+ year={2024}
+}
+```
+
+
+## 🌟 Star History
+
+[](https://www.star-history.com/#antgroup/ditto-talkinghead&Date)
diff --git a/core/atomic_components/audio2motion.py b/core/atomic_components/audio2motion.py
new file mode 100644
index 0000000000000000000000000000000000000000..12d2cf6e0d5c57303b6948d56455e6508c565ece
--- /dev/null
+++ b/core/atomic_components/audio2motion.py
@@ -0,0 +1,196 @@
+import numpy as np
+from ..models.lmdm import LMDM
+
+
+"""
+lmdm_cfg = {
+ "model_path": "",
+ "device": "cuda",
+ "motion_feat_dim": 265,
+ "audio_feat_dim": 1024+35,
+ "seq_frames": 80,
+}
+"""
+
+
+def _cvt_LP_motion_info(inp, mode, ignore_keys=()):
+ ks_shape_map = [
+ ['scale', (1, 1), 1],
+ ['pitch', (1, 66), 66],
+ ['yaw', (1, 66), 66],
+ ['roll', (1, 66), 66],
+ ['t', (1, 3), 3],
+ ['exp', (1, 63), 63],
+ ['kp', (1, 63), 63],
+ ]
+
+ def _dic2arr(_dic):
+ arr = []
+ for k, _, ds in ks_shape_map:
+ if k not in _dic or k in ignore_keys:
+ continue
+ v = _dic[k].reshape(ds)
+ if k == 'scale':
+ v = v - 1
+ arr.append(v)
+        arr = np.concatenate(arr, -1)  # (dim,), 265 when 'kp' is ignored
+ return arr
+
+ def _arr2dic(_arr):
+ dic = {}
+ s = 0
+ for k, ds, ss in ks_shape_map:
+ if k in ignore_keys:
+ continue
+ v = _arr[s:s + ss].reshape(ds)
+ if k == 'scale':
+ v = v + 1
+ dic[k] = v
+ s += ss
+ if s >= len(_arr):
+ break
+ return dic
+
+ if mode == 'dic2arr':
+ assert isinstance(inp, dict)
+ return _dic2arr(inp) # (dim)
+ elif mode == 'arr2dic':
+ assert inp.shape[0] >= 265, f"{inp.shape}"
+ return _arr2dic(inp) # {k: (1, dim)}
+ else:
+        raise ValueError(f"Unsupported mode: {mode}")
+
+
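+# Minimal illustrative sketch (not used by the pipeline): round-trip a
+# LivePortrait-style motion dict through the flat 265-dim array. The dict
+# layout follows ks_shape_map above; the values here are random placeholders.
+def _demo_cvt_roundtrip():
+    rng = np.random.default_rng(0)
+    dic = {
+        'scale': rng.standard_normal((1, 1)).astype(np.float32),
+        'pitch': rng.standard_normal((1, 66)).astype(np.float32),
+        'yaw': rng.standard_normal((1, 66)).astype(np.float32),
+        'roll': rng.standard_normal((1, 66)).astype(np.float32),
+        't': rng.standard_normal((1, 3)).astype(np.float32),
+        'exp': rng.standard_normal((1, 63)).astype(np.float32),
+        'kp': rng.standard_normal((1, 63)).astype(np.float32),
+    }
+    arr = _cvt_LP_motion_info(dic, mode='dic2arr', ignore_keys={'kp'})  # (265,)
+    back = _cvt_LP_motion_info(arr, mode='arr2dic', ignore_keys={'kp'})
+    assert arr.shape == (265,) and np.allclose(back['exp'], dic['exp'])
+
+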
+class Audio2Motion:
+ def __init__(
+ self,
+ lmdm_cfg,
+ ):
+ self.lmdm = LMDM(**lmdm_cfg)
+
+ def setup(
+ self,
+ x_s_info,
+ overlap_v2=10,
+ fix_kp_cond=0,
+ fix_kp_cond_dim=None,
+ sampling_timesteps=50,
+ online_mode=False,
+ v_min_max_for_clip=None,
+ smo_k_d=3,
+ ):
+ self.smo_k_d = smo_k_d
+ self.overlap_v2 = overlap_v2
+ self.seq_frames = self.lmdm.seq_frames
+ self.valid_clip_len = self.seq_frames - self.overlap_v2
+
+ # for fuse
+ self.online_mode = online_mode
+ if self.online_mode:
+ self.fuse_length = min(self.overlap_v2, self.valid_clip_len)
+ else:
+ self.fuse_length = self.overlap_v2
+ self.fuse_alpha = np.arange(self.fuse_length, dtype=np.float32).reshape(1, -1, 1) / self.fuse_length
+
+ self.fix_kp_cond = fix_kp_cond
+ self.fix_kp_cond_dim = fix_kp_cond_dim
+ self.sampling_timesteps = sampling_timesteps
+
+ self.v_min_max_for_clip = v_min_max_for_clip
+ if self.v_min_max_for_clip is not None:
+            self.v_min = self.v_min_max_for_clip[0][None]  # [1, dim]
+ self.v_max = self.v_min_max_for_clip[1][None]
+
+ kp_source = _cvt_LP_motion_info(x_s_info, mode='dic2arr', ignore_keys={'kp'})[None]
+ self.s_kp_cond = kp_source.copy().reshape(1, -1)
+ self.kp_cond = self.s_kp_cond.copy()
+
+ self.lmdm.setup(sampling_timesteps)
+
+ self.clip_idx = 0
+
+ def _fuse(self, res_kp_seq, pred_kp_seq):
+ ## ========================
+ ## offline fuse mode
+ ## last clip: -------
+ ## fuse part: *****
+ ## curr clip: -------
+ ## output: ^^
+ #
+ ## online fuse mode
+ ## last clip: -------
+ ## fuse part: **
+ ## curr clip: -------
+ ## output: ^^
+ ## ========================
+
+ fuse_r1_s = res_kp_seq.shape[1] - self.fuse_length
+ fuse_r1_e = res_kp_seq.shape[1]
+ fuse_r2_s = self.seq_frames - self.valid_clip_len - self.fuse_length
+ fuse_r2_e = self.seq_frames - self.valid_clip_len
+
+ r1 = res_kp_seq[:, fuse_r1_s:fuse_r1_e] # [1, fuse_len, dim]
+ r2 = pred_kp_seq[:, fuse_r2_s: fuse_r2_e] # [1, fuse_len, dim]
+ r_fuse = r1 * (1 - self.fuse_alpha) + r2 * self.fuse_alpha
+
+ res_kp_seq[:, fuse_r1_s:fuse_r1_e] = r_fuse # fuse last
+ res_kp_seq = np.concatenate([res_kp_seq, pred_kp_seq[:, fuse_r2_e:]], 1) # len(res_kp_seq) + valid_clip_len
+
+ return res_kp_seq
+
+ def _update_kp_cond(self, res_kp_seq, idx):
+        if self.fix_kp_cond == 0:  # never reset
+            self.kp_cond = res_kp_seq[:, idx-1]
+        elif self.fix_kp_cond > 0:
+            if self.clip_idx % self.fix_kp_cond == 0:  # reset every fix_kp_cond clips
+                self.kp_cond = self.s_kp_cond.copy()  # reset all dims to the source
+ if self.fix_kp_cond_dim is not None:
+ ds, de = self.fix_kp_cond_dim
+ self.kp_cond[:, ds:de] = res_kp_seq[:, idx-1, ds:de]
+ else:
+ self.kp_cond = res_kp_seq[:, idx-1]
+
+ def _smo(self, res_kp_seq, s, e):
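+        # Sliding-window mean over the first 202 dims only (scale + pose bins + t,
+        # per ks_shape_map in this module), leaving the expression dims unsmoothed.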
+ if self.smo_k_d <= 1:
+ return res_kp_seq
+ new_res_kp_seq = res_kp_seq.copy()
+ n = res_kp_seq.shape[1]
+ half_k = self.smo_k_d // 2
+ for i in range(s, e):
+ ss = max(0, i - half_k)
+ ee = min(n, i + half_k + 1)
+ res_kp_seq[:, i, :202] = np.mean(new_res_kp_seq[:, ss:ee, :202], axis=1)
+ return res_kp_seq
+
+ def __call__(self, aud_cond, res_kp_seq=None):
+ """
+ aud_cond: (1, seq_frames, dim)
+ """
+
+ pred_kp_seq = self.lmdm(self.kp_cond, aud_cond, self.sampling_timesteps)
+ if res_kp_seq is None:
+ res_kp_seq = pred_kp_seq # [1, seq_frames, dim]
+ res_kp_seq = self._smo(res_kp_seq, 0, res_kp_seq.shape[1])
+ else:
+ res_kp_seq = self._fuse(res_kp_seq, pred_kp_seq) # len(res_kp_seq) + valid_clip_len
+ res_kp_seq = self._smo(res_kp_seq, res_kp_seq.shape[1] - self.valid_clip_len - self.fuse_length, res_kp_seq.shape[1] - self.valid_clip_len + 1)
+
+ self.clip_idx += 1
+
+ idx = res_kp_seq.shape[1] - self.overlap_v2
+ self._update_kp_cond(res_kp_seq, idx)
+
+ return res_kp_seq
+
+ def cvt_fmt(self, res_kp_seq):
+ # res_kp_seq: [1, n, dim]
+ if self.v_min_max_for_clip is not None:
+ tmp_res_kp_seq = np.clip(res_kp_seq[0], self.v_min, self.v_max)
+ else:
+ tmp_res_kp_seq = res_kp_seq[0]
+
+ x_d_info_list = []
+ for i in range(tmp_res_kp_seq.shape[0]):
+ x_d_info = _cvt_LP_motion_info(tmp_res_kp_seq[i], 'arr2dic') # {k: (1, dim)}
+ x_d_info_list.append(x_d_info)
+ return x_d_info_list
diff --git a/core/atomic_components/avatar_registrar.py b/core/atomic_components/avatar_registrar.py
new file mode 100644
index 0000000000000000000000000000000000000000..47c2a4c72b2197f720c5fe0464d4e44720daf6f8
--- /dev/null
+++ b/core/atomic_components/avatar_registrar.py
@@ -0,0 +1,102 @@
+import numpy as np
+
+from .loader import load_source_frames
+from .source2info import Source2Info
+
+
+def _mean_filter(arr, k):
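+    # Centered moving average along axis 0 with window k, truncated at both ends.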
+ n = arr.shape[0]
+ half_k = k // 2
+ res = []
+ for i in range(n):
+ s = max(0, i - half_k)
+ e = min(n, i + half_k + 1)
+ res.append(arr[s:e].mean(0))
+ res = np.stack(res, 0)
+ return res
+
+
+def smooth_x_s_info_lst(x_s_info_list, ignore_keys=(), smo_k=13):
+ keys = x_s_info_list[0].keys()
+ N = len(x_s_info_list)
+ smo_dict = {}
+ for k in keys:
+ _lst = [x_s_info_list[i][k] for i in range(N)]
+ if k not in ignore_keys:
+ _lst = np.stack(_lst, 0)
+ _smo_lst = _mean_filter(_lst, smo_k)
+ else:
+ _smo_lst = _lst
+ smo_dict[k] = _smo_lst
+
+ smo_res = []
+ for i in range(N):
+ x_s_info = {k: smo_dict[k][i] for k in keys}
+ smo_res.append(x_s_info)
+ return smo_res
+
+
+class AvatarRegistrar:
+ """
+ source image|video -> rgb_list -> source_info
+ """
+ def __init__(
+ self,
+ insightface_det_cfg,
+ landmark106_cfg,
+ landmark203_cfg,
+ landmark478_cfg,
+ appearance_extractor_cfg,
+ motion_extractor_cfg,
+ ):
+ self.source2info = Source2Info(
+ insightface_det_cfg,
+ landmark106_cfg,
+ landmark203_cfg,
+ landmark478_cfg,
+ appearance_extractor_cfg,
+ motion_extractor_cfg,
+ )
+
+ def register(
+ self,
+ source_path, # image | video
+ max_dim=1920,
+ n_frames=-1,
+ **kwargs,
+ ):
+ """
+ kwargs:
+ crop_scale: 2.3
+ crop_vx_ratio: 0
+ crop_vy_ratio: -0.125
+ crop_flag_do_rot: True
+ """
+ rgb_list, is_image_flag = load_source_frames(source_path, max_dim=max_dim, n_frames=n_frames)
+ source_info = {
+ "x_s_info_lst": [],
+ "f_s_lst": [],
+ "M_c2o_lst": [],
+ "eye_open_lst": [],
+ "eye_ball_lst": [],
+ }
+ keys = ["x_s_info", "f_s", "M_c2o", "eye_open", "eye_ball"]
+ last_lmk = None
+ for rgb in rgb_list:
+ info = self.source2info(rgb, last_lmk, **kwargs)
+ for k in keys:
+ source_info[f"{k}_lst"].append(info[k])
+
+ last_lmk = info["lmk203"]
+
+ sc_f0 = source_info['x_s_info_lst'][0]['kp'].flatten()
+
+ source_info["sc"] = sc_f0
+ source_info["is_image_flag"] = is_image_flag
+ source_info["img_rgb_lst"] = rgb_list
+
+ return source_info
+
+ def __call__(self, *args, **kwargs):
+ return self.register(*args, **kwargs)
+
\ No newline at end of file
diff --git a/core/atomic_components/cfg.py b/core/atomic_components/cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..28e28ec9e53c3bef4f7de601771be92093a9bffc
--- /dev/null
+++ b/core/atomic_components/cfg.py
@@ -0,0 +1,111 @@
+import os
+import pickle
+import numpy as np
+
+
+def load_pkl(pkl):
+ with open(pkl, "rb") as f:
+ return pickle.load(f)
+
+
+def parse_cfg(cfg_pkl, data_root, replace_cfg=None):
+
+ def _check_path(p):
+ if os.path.isfile(p):
+ return p
+ else:
+ return os.path.join(data_root, p)
+
+ cfg = load_pkl(cfg_pkl)
+
+ # ---
+ # replace cfg for debug
+ if isinstance(replace_cfg, dict):
+ for k, v in replace_cfg.items():
+ if not isinstance(v, dict):
+ continue
+ for kk, vv in v.items():
+ cfg[k][kk] = vv
+ # ---
+
+ base_cfg = cfg["base_cfg"]
+ audio2motion_cfg = cfg["audio2motion_cfg"]
+ default_kwargs = cfg["default_kwargs"]
+
+ for k in base_cfg:
+ if k == "landmark478_cfg":
+ for kk in ["task_path", "blaze_face_model_path", "face_mesh_model_path"]:
+ if kk in base_cfg[k] and base_cfg[k][kk]:
+ base_cfg[k][kk] = _check_path(base_cfg[k][kk])
+ else:
+ base_cfg[k]["model_path"] = _check_path(base_cfg[k]["model_path"])
+
+ audio2motion_cfg["model_path"] = _check_path(audio2motion_cfg["model_path"])
+
+ avatar_registrar_cfg = {
+ k: base_cfg[k]
+ for k in [
+ "insightface_det_cfg",
+ "landmark106_cfg",
+ "landmark203_cfg",
+ "landmark478_cfg",
+ "appearance_extractor_cfg",
+ "motion_extractor_cfg",
+ ]
+ }
+
+ stitch_network_cfg = base_cfg["stitch_network_cfg"]
+ warp_network_cfg = base_cfg["warp_network_cfg"]
+ decoder_cfg = base_cfg["decoder_cfg"]
+
+ condition_handler_cfg = {
+ k: audio2motion_cfg[k]
+ for k in [
+ "use_emo",
+ "use_sc",
+ "use_eye_open",
+ "use_eye_ball",
+ "seq_frames",
+ ]
+ }
+
+ lmdm_cfg = {
+ k: audio2motion_cfg[k]
+ for k in [
+ "model_path",
+ "device",
+ "motion_feat_dim",
+ "audio_feat_dim",
+ "seq_frames",
+ ]
+ }
+
+ w2f_type = audio2motion_cfg["w2f_type"]
+ wav2feat_cfg = {
+ "w2f_cfg": base_cfg["hubert_cfg"] if w2f_type == "hubert" else base_cfg["wavlm_cfg"],
+ "w2f_type": w2f_type,
+ }
+
+ return [
+ avatar_registrar_cfg,
+ condition_handler_cfg,
+ lmdm_cfg,
+ stitch_network_cfg,
+ warp_network_cfg,
+ decoder_cfg,
+ wav2feat_cfg,
+ default_kwargs,
+ ]
+
+
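+# The list returned by parse_cfg is positional; a caller is expected to unpack:
+#   (avatar_registrar_cfg, condition_handler_cfg, lmdm_cfg, stitch_network_cfg,
+#    warp_network_cfg, decoder_cfg, wav2feat_cfg, default_kwargs) = parse_cfg(...)
+
+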
+def print_cfg(**kwargs):
+ for k, v in kwargs.items():
+ if k == "ch_info":
+ print(k, type(v))
+ elif k == "ctrl_info":
+ print(k, type(v), len(v))
+ else:
+ if isinstance(v, np.ndarray):
+ print(k, type(v), v.shape)
+ else:
+ print(k, type(v), v)
diff --git a/core/atomic_components/condition_handler.py b/core/atomic_components/condition_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba1af63eecb9646eb85a91d864a6d912b286e5d7
--- /dev/null
+++ b/core/atomic_components/condition_handler.py
@@ -0,0 +1,168 @@
+import numpy as np
+from scipy.special import softmax
+import copy
+
+
+def _get_emo_avg(idx=6):
+ emo_avg = np.zeros(8, dtype=np.float32)
+ if isinstance(idx, (list, tuple)):
+ for i in idx:
+ emo_avg[i] = 8
+ else:
+ emo_avg[idx] = 8
+ emo_avg = softmax(emo_avg)
+ # 'Angry', 'Disgust', 'Fear', 'Happy', 'Neutral', 'Sad', 'Surprise', 'Contempt'
+ return emo_avg
+
+
+def _mirror_index(index, size):
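+    # Ping-pong indexing: e.g. size=3 yields 0, 1, 2, 2, 1, 0, 0, 1, 2, ...
+    # so looped sequences reverse direction instead of jumping back to frame 0.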
+ turn = index // size
+ res = index % size
+ if turn % 2 == 0:
+ return res
+ else:
+ return size - res - 1
+
+
+class ConditionHandler:
+ """
+ aud_feat, emo_seq, eye_seq, sc_seq -> cond_seq
+ """
+ def __init__(
+ self,
+ use_emo=True,
+ use_sc=True,
+ use_eye_open=True,
+ use_eye_ball=True,
+ seq_frames=80,
+ ):
+ self.use_emo = use_emo
+ self.use_sc = use_sc
+ self.use_eye_open = use_eye_open
+ self.use_eye_ball = use_eye_ball
+
+ self.seq_frames = seq_frames
+
+ def setup(self, setup_info, emo, eye_f0_mode=False, ch_info=None):
+ """
+ emo: int | [int] | [[int]] | numpy
+ """
+ if ch_info is None:
+ source_info = copy.deepcopy(setup_info)
+ else:
+ source_info = ch_info
+
+ self.eye_f0_mode = eye_f0_mode
+ self.x_s_info_0 = source_info['x_s_info_lst'][0]
+
+ if self.use_sc:
+ self.sc = source_info["sc"] # 63
+ self.sc_seq = np.stack([self.sc] * self.seq_frames, 0)
+
+ if self.use_eye_open:
+ self.eye_open_lst = np.concatenate(source_info["eye_open_lst"], 0) # [n, 2]
+ self.num_eye_open = len(self.eye_open_lst)
+ if self.num_eye_open == 1 or self.eye_f0_mode:
+ self.eye_open_seq = np.stack([self.eye_open_lst[0]] * self.seq_frames, 0)
+ else:
+ self.eye_open_seq = None
+
+ if self.use_eye_ball:
+ self.eye_ball_lst = np.concatenate(source_info["eye_ball_lst"], 0) # [n, 6]
+ self.num_eye_ball = len(self.eye_ball_lst)
+ if self.num_eye_ball == 1 or self.eye_f0_mode:
+ self.eye_ball_seq = np.stack([self.eye_ball_lst[0]] * self.seq_frames, 0)
+ else:
+ self.eye_ball_seq = None
+
+ if self.use_emo:
+ self.emo_lst = self._parse_emo_seq(emo)
+ self.num_emo = len(self.emo_lst)
+ if self.num_emo == 1:
+ self.emo_seq = np.concatenate([self.emo_lst] * self.seq_frames, 0)
+ else:
+ self.emo_seq = None
+
+ @staticmethod
+ def _parse_emo_seq(emo, seq_len=-1):
+ if isinstance(emo, np.ndarray) and emo.ndim == 2 and emo.shape[1] == 8:
+ # emo arr, e.g. real
+ emo_seq = emo # [m, 8]
+ elif isinstance(emo, int) and 0 <= emo < 8:
+ # emo label, e.g. 4
+ emo_seq = _get_emo_avg(emo).reshape(1, 8) # [1, 8]
+ elif isinstance(emo, (list, tuple)) and 0 < len(emo) < 8 and isinstance(emo[0], int):
+ # emo labels, e.g. [3,4]
+ emo_seq = _get_emo_avg(emo).reshape(1, 8) # [1, 8]
+ elif isinstance(emo, list) and emo and isinstance(emo[0], (list, tuple)):
+ # emo label list, e.g. [[4], [3,4], [3],[3,4,5], ...]
+ emo_seq = np.stack([_get_emo_avg(i) for i in emo], 0) # [m, 8]
+ else:
+ raise ValueError(f"Unsupported emo type: {emo}")
+
+ if seq_len > 0:
+ if len(emo_seq) == seq_len:
+ return emo_seq
+ elif len(emo_seq) == 1:
+ return np.concatenate([emo_seq] * seq_len, 0)
+ elif len(emo_seq) > seq_len:
+ return emo_seq[:seq_len]
+ else:
+ raise ValueError(f"emo len {len(emo_seq)} can not match seq len ({seq_len})")
+ else:
+ return emo_seq
+
+ def __call__(self, aud_feat, idx, emo=None):
+ """
+ aud_feat: [n, 1024]
+ idx: int, <0 means pad (first clip buffer)
+ """
+
+ frame_num = len(aud_feat)
+ more_cond = [aud_feat]
+ if self.use_emo:
+ if emo is not None:
+ emo_seq = self._parse_emo_seq(emo, frame_num)
+ elif self.emo_seq is not None and len(self.emo_seq) == frame_num:
+ emo_seq = self.emo_seq
+ else:
+ emo_idx_list = [max(i, 0) % self.num_emo for i in range(idx, idx + frame_num)]
+ emo_seq = self.emo_lst[emo_idx_list]
+ more_cond.append(emo_seq)
+
+ if self.use_eye_open:
+ if self.eye_open_seq is not None and len(self.eye_open_seq) == frame_num:
+ eye_open_seq = self.eye_open_seq
+ else:
+ if self.eye_f0_mode:
+ eye_idx_list = [0] * frame_num
+ else:
+ eye_idx_list = [_mirror_index(max(i, 0), self.num_eye_open) for i in range(idx, idx + frame_num)]
+ eye_open_seq = self.eye_open_lst[eye_idx_list]
+ more_cond.append(eye_open_seq)
+
+ if self.use_eye_ball:
+ if self.eye_ball_seq is not None and len(self.eye_ball_seq) == frame_num:
+ eye_ball_seq = self.eye_ball_seq
+ else:
+ if self.eye_f0_mode:
+ eye_idx_list = [0] * frame_num
+ else:
+ eye_idx_list = [_mirror_index(max(i, 0), self.num_eye_ball) for i in range(idx, idx + frame_num)]
+ eye_ball_seq = self.eye_ball_lst[eye_idx_list]
+ more_cond.append(eye_ball_seq)
+
+ if self.use_sc:
+ if len(self.sc_seq) == frame_num:
+ sc_seq = self.sc_seq
+ else:
+ sc_seq = np.stack([self.sc] * frame_num, 0)
+ more_cond.append(sc_seq)
+
+ if len(more_cond) > 1:
+ cond_seq = np.concatenate(more_cond, -1) # [n, dim_cond]
+ else:
+ cond_seq = aud_feat
+
+ return cond_seq
diff --git a/core/atomic_components/decode_f3d.py b/core/atomic_components/decode_f3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..96b85d2ac855ccb4b80804416e1c4e9c818789b2
--- /dev/null
+++ b/core/atomic_components/decode_f3d.py
@@ -0,0 +1,22 @@
+from ..models.decoder import Decoder
+
+
+"""
+# __init__
+decoder_cfg = {
+ "model_path": "",
+ "device": "cuda",
+}
+"""
+
+class DecodeF3D:
+ def __init__(
+ self,
+ decoder_cfg,
+ ):
+ self.decoder = Decoder(**decoder_cfg)
+
+ def __call__(self, f_s):
+ out = self.decoder(f_s)
+ return out
+
\ No newline at end of file
diff --git a/core/atomic_components/loader.py b/core/atomic_components/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb46cec89081c04e641b6ab0f87afa6589c3e340
--- /dev/null
+++ b/core/atomic_components/loader.py
@@ -0,0 +1,133 @@
+import filetype
+import imageio
+import cv2
+
+
+def is_image(file_path):
+ return filetype.is_image(file_path)
+
+
+def is_video(file_path):
+ return filetype.is_video(file_path)
+
+
+def check_resize(h, w, max_dim=1920, division=2):
+ rsz_flag = False
+    # adjust the size of the image according to the maximum dimension
+ if max_dim > 0 and max(h, w) > max_dim:
+ rsz_flag = True
+ if h > w:
+ new_h = max_dim
+ new_w = int(round(w * max_dim / h))
+ else:
+ new_w = max_dim
+ new_h = int(round(h * max_dim / w))
+ else:
+ new_h = h
+ new_w = w
+
+ # ensure that the image dimensions are multiples of n
+ if new_h % division != 0:
+ new_h = new_h - (new_h % division)
+ rsz_flag = True
+ if new_w % division != 0:
+ new_w = new_w - (new_w % division)
+ rsz_flag = True
+
+ return new_h, new_w, rsz_flag
+
+
+def load_image(image_path, max_dim=-1):
+ img = cv2.imread(image_path, cv2.IMREAD_COLOR)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ h, w = img.shape[:2]
+ new_h, new_w, rsz_flag = check_resize(h, w, max_dim)
+ if rsz_flag:
+ img = cv2.resize(img, (new_w, new_h))
+ return img
+
+
+def load_video(video_path, n_frames=-1, max_dim=-1):
+ reader = imageio.get_reader(video_path, "ffmpeg")
+
+ new_h, new_w, rsz_flag = None, None, None
+
+ ret = []
+ for idx, frame_rgb in enumerate(reader):
+ if n_frames > 0 and idx >= n_frames:
+ break
+
+ if rsz_flag is None:
+ h, w = frame_rgb.shape[:2]
+ new_h, new_w, rsz_flag = check_resize(h, w, max_dim)
+
+ if rsz_flag:
+ frame_rgb = cv2.resize(frame_rgb, (new_w, new_h))
+
+ ret.append(frame_rgb)
+
+ reader.close()
+ return ret
+
+
+def load_source_frames(source_path, max_dim=-1, n_frames=-1):
+ if is_image(source_path):
+ rgb = load_image(source_path, max_dim)
+ rgb_list = [rgb]
+ is_image_flag = True
+ elif is_video(source_path):
+ rgb_list = load_video(source_path, n_frames, max_dim)
+ is_image_flag = False
+ else:
+ raise ValueError(f"Unsupported source type: {source_path}")
+ return rgb_list, is_image_flag
+
+
+def _mirror_index(index, size):
+ turn = index // size
+ res = index % size
+ if turn % 2 == 0:
+ return res
+ else:
+ return size - res - 1
+
+
+class LoopLoader:
+ def __init__(self, item_list, max_iter_num=-1, mirror_loop=True):
+ self.item_list = item_list
+ self.idx = 0
+ self.item_num = len(self.item_list)
+ self.max_iter_num = max_iter_num if max_iter_num > 0 else self.item_num
+ self.mirror_loop = mirror_loop
+
+ def __len__(self):
+ return self.max_iter_num
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.idx >= self.max_iter_num:
+ raise StopIteration
+
+ if self.mirror_loop:
+ idx = _mirror_index(self.idx, self.item_num)
+ else:
+ idx = self.idx % self.item_num
+ item = self.item_list[idx]
+
+ self.idx += 1
+ return item
+
+ def __call__(self):
+ return self.__iter__()
+
+ def reset(self, max_iter_num=-1):
+        self.idx = 0
+ self.max_iter_num = max_iter_num if max_iter_num > 0 else self.item_num
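+
+
+# Illustrative usage (hypothetical names): iterate 100 frames over a short
+# source clip, ping-ponging to avoid a visible jump at the loop point.
+#
+#   loader = LoopLoader(rgb_list, max_iter_num=100, mirror_loop=True)
+#   for frame_rgb in loader:
+#       process(frame_rgb)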
diff --git a/core/atomic_components/motion_stitch.py b/core/atomic_components/motion_stitch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7967984b0abd39bcf31fc5784756c360ce40dbbc
--- /dev/null
+++ b/core/atomic_components/motion_stitch.py
@@ -0,0 +1,491 @@
+import copy
+import random
+import numpy as np
+from scipy.special import softmax
+
+from ..models.stitch_network import StitchNetwork
+
+
+"""
+# __init__
+stitch_network_cfg = {
+ "model_path": "",
+ "device": "cuda",
+}
+
+# __call__
+kwargs:
+ fade_alpha
+ fade_out_keys
+
+ delta_pitch
+ delta_yaw
+ delta_roll
+
+"""
+
+
+def ctrl_motion(x_d_info, **kwargs):
+ # pose + offset
+ for kk in ["delta_pitch", "delta_yaw", "delta_roll"]:
+ if kk in kwargs:
+ k = kk[6:]
+ x_d_info[k] = bin66_to_degree(x_d_info[k]) + kwargs[kk]
+
+ # pose * alpha
+ for kk in ["alpha_pitch", "alpha_yaw", "alpha_roll"]:
+ if kk in kwargs:
+ k = kk[6:]
+ x_d_info[k] = x_d_info[k] * kwargs[kk]
+
+ # exp + offset
+ if "delta_exp" in kwargs:
+ k = "exp"
+ x_d_info[k] = x_d_info[k] + kwargs["delta_exp"]
+
+ return x_d_info
+
+
+def fade(x_d_info, dst, alpha, keys=None):
+ if keys is None:
+ keys = x_d_info.keys()
+ for k in keys:
+ if k == 'kp':
+ continue
+ x_d_info[k] = x_d_info[k] * alpha + dst[k] * (1 - alpha)
+ return x_d_info
+
+
+def ctrl_vad(x_d_info, dst, alpha):
+    # fade the lip-related exp dims toward dst by alpha; other dims keep the
+    # driven values
+    exp = x_d_info["exp"]
+    exp_dst = dst["exp"]
+
+    _lip = [6, 12, 14, 17, 19, 20]
+    _a1 = np.ones((21, 3), dtype=np.float32)
+    _a1[_lip] = alpha
+    _a1 = _a1.reshape(1, -1)
+    x_d_info["exp"] = exp * _a1 + exp_dst * (1 - _a1)
+
+    return x_d_info
+
+
+
+def _mix_s_d_info(
+ x_s_info,
+ x_d_info,
+ use_d_keys=("exp", "pitch", "yaw", "roll", "t"),
+ d0=None,
+):
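+    # Relative drive: when d0 (the first driving frame) is given, the driving
+    # delta is applied on top of the source, x_d[k] = x_s[k] + (x_d[k] - d0[k]);
+    # keys outside use_d_keys keep their source values.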
+ if d0 is not None:
+ if isinstance(use_d_keys, dict):
+ x_d_info = {
+ k: x_s_info[k] + (v - d0[k]) * use_d_keys.get(k, 1)
+ for k, v in x_d_info.items()
+ }
+ else:
+ x_d_info = {k: x_s_info[k] + (v - d0[k]) for k, v in x_d_info.items()}
+
+ for k, v in x_s_info.items():
+ if k not in x_d_info or k not in use_d_keys:
+ x_d_info[k] = v
+
+ if isinstance(use_d_keys, dict) and d0 is None:
+ for k, alpha in use_d_keys.items():
+ x_d_info[k] *= alpha
+ return x_d_info
+
+
+def _set_eye_blink_idx(N, blink_n=15, open_n=-1):
+ """
+ open_n:
+ -1: no blink
+ 0: random open_n
+ >0: fix open_n
+ list: loop open_n
+ """
+ OPEN_MIN = 60
+ OPEN_MAX = 100
+
+ idx = [0] * N
+ if isinstance(open_n, int):
+ if open_n < 0: # no blink
+ return idx
+ elif open_n > 0: # fix open_n
+ open_ns = [open_n]
+ else: # open_n == 0: # random open_n, 60-100
+ open_ns = []
+ elif isinstance(open_n, list):
+ open_ns = open_n # loop open_n
+ else:
+ raise ValueError()
+
+ blink_idx = list(range(blink_n))
+
+ start_n = open_ns[0] if open_ns else random.randint(OPEN_MIN, OPEN_MAX)
+ end_n = open_ns[-1] if open_ns else random.randint(OPEN_MIN, OPEN_MAX)
+ max_i = N - max(end_n, blink_n)
+ cur_i = start_n
+ cur_n_i = 1
+ while cur_i < max_i:
+ idx[cur_i : cur_i + blink_n] = blink_idx
+
+ if open_ns:
+ cur_n = open_ns[cur_n_i % len(open_ns)]
+ cur_n_i += 1
+ else:
+ cur_n = random.randint(OPEN_MIN, OPEN_MAX)
+
+ cur_i = cur_i + blink_n + cur_n
+
+ return idx
+
+
+def _fix_exp_for_x_d_info(x_d_info, x_s_info, delta_eye=None, drive_eye=True):
+ _eye = [11, 13, 15, 16, 18]
+ _lip = [6, 12, 14, 17, 19, 20]
+ alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype)
+ alpha[_lip] = 1
+ if delta_eye is None and drive_eye: # use d eye
+ alpha[_eye] = 1
+ alpha = alpha.reshape(1, -1)
+ x_d_info["exp"] = x_d_info["exp"] * alpha + x_s_info["exp"] * (1 - alpha)
+
+ if delta_eye is not None and drive_eye:
+ alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype)
+ alpha[_eye] = 1
+ alpha = alpha.reshape(1, -1)
+ x_d_info["exp"] = (delta_eye + x_s_info["exp"]) * alpha + x_d_info["exp"] * (
+ 1 - alpha
+ )
+
+ return x_d_info
+
+
+def _fix_exp_for_x_d_info_v2(x_d_info, x_s_info, delta_eye, a1, a2, a3):
+ x_d_info["exp"] = x_d_info["exp"] * a1 + x_s_info["exp"] * a2 + delta_eye * a3
+ return x_d_info
+
+
+def bin66_to_degree(pred):
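+    # Soft-argmax over 66 pose bins: each bin spans 3 degrees, so the decoded
+    # angle lies in roughly [-97.5, 97.5] degrees; non-binned inputs pass through.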
+ if pred.ndim > 1 and pred.shape[1] == 66:
+ idx = np.arange(66).astype(np.float32)
+ pred = softmax(pred, axis=1)
+ degree = np.sum(pred * idx, axis=1) * 3 - 97.5
+ return degree
+ return pred
+
+
+def _eye_delta(exp, dx=0, dy=0):
+ if dx > 0:
+ exp[0, 33] += dx * 0.0007
+ exp[0, 45] += dx * 0.001
+ else:
+ exp[0, 33] += dx * 0.001
+ exp[0, 45] += dx * 0.0007
+
+ exp[0, 34] += dy * -0.001
+ exp[0, 46] += dy * -0.001
+ return exp
+
+def _fix_gaze(pose_s, x_d_info):
+ x_ratio = 0.26
+ y_ratio = 0.28
+
+ yaw_s, pitch_s = pose_s
+ yaw_d = bin66_to_degree(x_d_info['yaw']).item()
+ pitch_d = bin66_to_degree(x_d_info['pitch']).item()
+
+ delta_yaw = yaw_d - yaw_s
+ delta_pitch = pitch_d - pitch_s
+
+ dx = delta_yaw * x_ratio
+ dy = delta_pitch * y_ratio
+
+ x_d_info['exp'] = _eye_delta(x_d_info['exp'], dx, dy)
+ return x_d_info
+
+
+def get_rotation_matrix(pitch_, yaw_, roll_):
+ """ the input is in degree
+ """
+ # transform to radian
+ pitch = pitch_ / 180 * np.pi
+ yaw = yaw_ / 180 * np.pi
+ roll = roll_ / 180 * np.pi
+
+ if pitch.ndim == 1:
+ pitch = pitch[:, None]
+ if yaw.ndim == 1:
+ yaw = yaw[:, None]
+ if roll.ndim == 1:
+ roll = roll[:, None]
+
+ # calculate the euler matrix
+ bs = pitch.shape[0]
+ ones = np.ones((bs, 1), dtype=np.float32)
+ zeros = np.zeros((bs, 1), dtype=np.float32)
+ x, y, z = pitch, yaw, roll
+
+ rot_x = np.concatenate([
+ ones, zeros, zeros,
+ zeros, np.cos(x), -np.sin(x),
+ zeros, np.sin(x), np.cos(x)
+ ], axis=1).reshape(bs, 3, 3)
+
+ rot_y = np.concatenate([
+ np.cos(y), zeros, np.sin(y),
+ zeros, ones, zeros,
+ -np.sin(y), zeros, np.cos(y)
+ ], axis=1).reshape(bs, 3, 3)
+
+ rot_z = np.concatenate([
+ np.cos(z), -np.sin(z), zeros,
+ np.sin(z), np.cos(z), zeros,
+ zeros, zeros, ones
+ ], axis=1).reshape(bs, 3, 3)
+
+ rot = np.matmul(np.matmul(rot_z, rot_y), rot_x)
+ return np.transpose(rot, (0, 2, 1))
+
+
+def transform_keypoint(kp_info: dict):
+ """
+ transform the implicit keypoints with the pose, shift, and expression deformation
+ kp: BxNx3
+ """
+ kp = kp_info['kp'] # (bs, k, 3)
+ pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']
+
+ t, exp = kp_info['t'], kp_info['exp']
+ scale = kp_info['scale']
+
+ pitch = bin66_to_degree(pitch)
+ yaw = bin66_to_degree(yaw)
+ roll = bin66_to_degree(roll)
+
+ bs = kp.shape[0]
+ if kp.ndim == 2:
+ num_kp = kp.shape[1] // 3 # Bx(num_kpx3)
+ else:
+ num_kp = kp.shape[1] # Bxnum_kpx3
+
+ rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3)
+
+ # Eqn.2: s * (R * x_c,s + exp) + t
+ kp_transformed = np.matmul(kp.reshape(bs, num_kp, 3), rot_mat) + exp.reshape(bs, num_kp, 3)
+ kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
+ kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty
+
+ return kp_transformed
+
+
+class MotionStitch:
+ def __init__(
+ self,
+ stitch_network_cfg,
+ ):
+ self.stitch_net = StitchNetwork(**stitch_network_cfg)
+
+ def set_Nd(self, N_d=-1):
+ # only for offline (make start|end eye open)
+ if N_d == self.N_d:
+ return
+
+ self.N_d = N_d
+ if self.drive_eye and self.delta_eye_arr is not None:
+ N = 3000 if self.N_d == -1 else self.N_d
+ self.delta_eye_idx_list = _set_eye_blink_idx(
+ N, len(self.delta_eye_arr), self.delta_eye_open_n
+ )
+
+ def setup(
+ self,
+ N_d=-1,
+ use_d_keys=None,
+ relative_d=True,
+ drive_eye=None, # use d eye or s eye
+ delta_eye_arr=None, # fix eye
+ delta_eye_open_n=-1, # int|list
+ fade_out_keys=("exp",),
+ fade_type="", # "" | "d0" | "s"
+ flag_stitching=True,
+ is_image_flag=True,
+ x_s_info=None,
+ d0=None,
+ ch_info=None,
+ overall_ctrl_info=None,
+ ):
+ self.is_image_flag = is_image_flag
+ if use_d_keys is None:
+ if self.is_image_flag:
+ self.use_d_keys = ("exp", "pitch", "yaw", "roll", "t")
+ else:
+ self.use_d_keys = ("exp", )
+ else:
+ self.use_d_keys = use_d_keys
+
+ if drive_eye is None:
+ if self.is_image_flag:
+ self.drive_eye = True
+ else:
+ self.drive_eye = False
+ else:
+ self.drive_eye = drive_eye
+
+ self.N_d = N_d
+ self.relative_d = relative_d
+ self.delta_eye_arr = delta_eye_arr
+ self.delta_eye_open_n = delta_eye_open_n
+ self.fade_out_keys = fade_out_keys
+ self.fade_type = fade_type
+ self.flag_stitching = flag_stitching
+
+ _eye = [11, 13, 15, 16, 18]
+ _lip = [6, 12, 14, 17, 19, 20]
+ _a1 = np.zeros((21, 3), dtype=np.float32)
+ _a1[_lip] = 1
+ _a2 = 0
+ if self.drive_eye:
+ if self.delta_eye_arr is None:
+ _a1[_eye] = 1
+ else:
+ _a2 = np.zeros((21, 3), dtype=np.float32)
+ _a2[_eye] = 1
+ _a2 = _a2.reshape(1, -1)
+ _a1 = _a1.reshape(1, -1)
+
+ self.fix_exp_a1 = _a1 * (1 - _a2)
+ self.fix_exp_a2 = (1 - _a1) + _a1 * _a2
+ self.fix_exp_a3 = _a2
+
+ if self.drive_eye and self.delta_eye_arr is not None:
+ N = 3000 if self.N_d == -1 else self.N_d
+ self.delta_eye_idx_list = _set_eye_blink_idx(
+ N, len(self.delta_eye_arr), self.delta_eye_open_n
+ )
+
+ self.pose_s = None
+ self.x_s = None
+ self.fade_dst = None
+ if self.is_image_flag and x_s_info is not None:
+ yaw_s = bin66_to_degree(x_s_info['yaw']).item()
+ pitch_s = bin66_to_degree(x_s_info['pitch']).item()
+ self.pose_s = [yaw_s, pitch_s]
+ self.x_s = transform_keypoint(x_s_info)
+
+ if self.fade_type == "s":
+ self.fade_dst = copy.deepcopy(x_s_info)
+
+ if ch_info is not None:
+ self.scale_a = ch_info['x_s_info_lst'][0]['scale'].item()
+ if x_s_info is not None:
+ self.scale_b = x_s_info['scale'].item()
+ self.scale_ratio = self.scale_a / self.scale_b
+ self._set_scale_ratio(self.scale_ratio)
+ else:
+ self.scale_ratio = None
+ else:
+ self.scale_ratio = 1
+
+ self.overall_ctrl_info = overall_ctrl_info
+
+ self.d0 = d0
+ self.idx = 0
+
+ def _set_scale_ratio(self, scale_ratio=1):
+ if scale_ratio == 1:
+ return
+ if isinstance(self.use_d_keys, dict):
+ self.use_d_keys = {k: v * (scale_ratio if k in {"exp", "pitch", "yaw", "roll"} else 1) for k, v in self.use_d_keys.items()}
+ else:
+ self.use_d_keys = {k: scale_ratio if k in {"exp", "pitch", "yaw", "roll"} else 1 for k in self.use_d_keys}
+
+ @staticmethod
+ def _merge_kwargs(default_kwargs, run_kwargs):
+ if default_kwargs is None:
+ return run_kwargs
+
+ for k, v in default_kwargs.items():
+ if k not in run_kwargs:
+ run_kwargs[k] = v
+ return run_kwargs
+
+ def __call__(self, x_s_info, x_d_info, **kwargs):
+ # return x_s, x_d
+
+ kwargs = self._merge_kwargs(self.overall_ctrl_info, kwargs)
+
+ if self.scale_ratio is None:
+ self.scale_b = x_s_info['scale'].item()
+ self.scale_ratio = self.scale_a / self.scale_b
+ self._set_scale_ratio(self.scale_ratio)
+
+ if self.relative_d and self.d0 is None:
+ self.d0 = copy.deepcopy(x_d_info)
+
+ x_d_info = _mix_s_d_info(
+ x_s_info,
+ x_d_info,
+ self.use_d_keys,
+ self.d0,
+ )
+
+ delta_eye = 0
+ if self.drive_eye and self.delta_eye_arr is not None:
+ delta_eye = self.delta_eye_arr[
+ self.delta_eye_idx_list[self.idx % len(self.delta_eye_idx_list)]
+ ][None]
+ x_d_info = _fix_exp_for_x_d_info_v2(
+ x_d_info,
+ x_s_info,
+ delta_eye,
+ self.fix_exp_a1,
+ self.fix_exp_a2,
+ self.fix_exp_a3,
+ )
+
+ if kwargs.get("vad_alpha", 1) < 1:
+ x_d_info = ctrl_vad(x_d_info, x_s_info, kwargs.get("vad_alpha", 1))
+
+ x_d_info = ctrl_motion(x_d_info, **kwargs)
+
+ if self.fade_type == "d0" and self.fade_dst is None:
+ self.fade_dst = copy.deepcopy(x_d_info)
+
+ # fade
+ if "fade_alpha" in kwargs and self.fade_type in ["d0", "s"]:
+ fade_alpha = kwargs["fade_alpha"]
+ fade_keys = kwargs.get("fade_out_keys", self.fade_out_keys)
+ if self.fade_type == "d0":
+ fade_dst = self.fade_dst
+ elif self.fade_type == "s":
+ if self.fade_dst is not None:
+ fade_dst = self.fade_dst
+ else:
+ fade_dst = copy.deepcopy(x_s_info)
+ if self.is_image_flag:
+ self.fade_dst = fade_dst
+ x_d_info = fade(x_d_info, fade_dst, fade_alpha, fade_keys)
+
+ if self.drive_eye:
+ if self.pose_s is None:
+ yaw_s = bin66_to_degree(x_s_info['yaw']).item()
+ pitch_s = bin66_to_degree(x_s_info['pitch']).item()
+ self.pose_s = [yaw_s, pitch_s]
+ x_d_info = _fix_gaze(self.pose_s, x_d_info)
+
+ if self.x_s is not None:
+ x_s = self.x_s
+ else:
+ x_s = transform_keypoint(x_s_info)
+ if self.is_image_flag:
+ self.x_s = x_s
+
+ x_d = transform_keypoint(x_d_info)
+
+ if self.flag_stitching:
+ x_d = self.stitch_net(x_s, x_d)
+
+ self.idx += 1
+
+ return x_s, x_d
diff --git a/core/atomic_components/putback.py b/core/atomic_components/putback.py
new file mode 100644
index 0000000000000000000000000000000000000000..39a5a6dc1581697774de0d5b62b214a09c8dfb72
--- /dev/null
+++ b/core/atomic_components/putback.py
@@ -0,0 +1,60 @@
+import cv2
+import numpy as np
+from ..utils.blend import blend_images_cy
+from ..utils.get_mask import get_mask
+
+
+class PutBackNumpy:
+ def __init__(
+ self,
+ mask_template_path=None,
+ ):
+ if mask_template_path is None:
+ mask = get_mask(512, 512, 0.9, 0.9)
+ self.mask_ori_float = np.concatenate([mask] * 3, 2)
+ else:
+ mask = cv2.imread(mask_template_path, cv2.IMREAD_COLOR)
+ self.mask_ori_float = mask.astype(np.float32) / 255.0
+
+ def __call__(self, frame_rgb, render_image, M_c2o):
+ h, w = frame_rgb.shape[:2]
+ mask_warped = cv2.warpAffine(
+ self.mask_ori_float, M_c2o[:2, :], dsize=(w, h), flags=cv2.INTER_LINEAR
+ ).clip(0, 1)
+ frame_warped = cv2.warpAffine(
+ render_image, M_c2o[:2, :], dsize=(w, h), flags=cv2.INTER_LINEAR
+ )
+ result = mask_warped * frame_warped + (1 - mask_warped) * frame_rgb
+ result = np.clip(result, 0, 255)
+ result = result.astype(np.uint8)
+ return result
+
+
+class PutBack:
+ def __init__(
+ self,
+ mask_template_path=None,
+ ):
+ if mask_template_path is None:
+ mask = get_mask(512, 512, 0.9, 0.9)
+ mask = np.concatenate([mask] * 3, 2)
+ else:
+ mask = cv2.imread(mask_template_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0
+
+ self.mask_ori_float = np.ascontiguousarray(mask)[:,:,0]
+ self.result_buffer = None
+
+ def __call__(self, frame_rgb, render_image, M_c2o):
+ h, w = frame_rgb.shape[:2]
+ mask_warped = cv2.warpAffine(
+ self.mask_ori_float, M_c2o[:2, :], dsize=(w, h), flags=cv2.INTER_LINEAR
+ ).clip(0, 1)
+ frame_warped = cv2.warpAffine(
+ render_image, M_c2o[:2, :], dsize=(w, h), flags=cv2.INTER_LINEAR
+ )
+ self.result_buffer = np.empty((h, w, 3), dtype=np.uint8)
+
+ # Use Cython implementation for blending
+ blend_images_cy(mask_warped, frame_warped, frame_rgb, self.result_buffer)
+
+ return self.result_buffer
\ No newline at end of file
diff --git a/core/atomic_components/source2info.py b/core/atomic_components/source2info.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a64fb6fbe16928ea60f4281930525f643fc6d7e
--- /dev/null
+++ b/core/atomic_components/source2info.py
@@ -0,0 +1,155 @@
+import numpy as np
+import cv2
+
+from ..aux_models.insightface_det import InsightFaceDet
+from ..aux_models.insightface_landmark106 import Landmark106
+from ..aux_models.landmark203 import Landmark203
+from ..aux_models.mediapipe_landmark478 import Landmark478
+from ..models.appearance_extractor import AppearanceExtractor
+from ..models.motion_extractor import MotionExtractor
+
+from ..utils.crop import crop_image
+from ..utils.eye_info import EyeAttrUtilsByMP
+
+
+"""
+insightface_det_cfg = {
+ "model_path": "",
+ "device": "cuda",
+ "force_ori_type": False,
+}
+landmark106_cfg = {
+ "model_path": "",
+ "device": "cuda",
+ "force_ori_type": False,
+}
+landmark203_cfg = {
+ "model_path": "",
+ "device": "cuda",
+ "force_ori_type": False,
+}
+landmark478_cfg = {
+ "blaze_face_model_path": "",
+ "face_mesh_model_path": "",
+ "device": "cuda",
+ "force_ori_type": False,
+ "task_path": "",
+}
+appearance_extractor_cfg = {
+ "model_path": "",
+ "device": "cuda",
+}
+motion_extractor_cfg = {
+ "model_path": "",
+ "device": "cuda",
+}
+"""
+
+
+class Source2Info:
+ def __init__(
+ self,
+ insightface_det_cfg,
+ landmark106_cfg,
+ landmark203_cfg,
+ landmark478_cfg,
+ appearance_extractor_cfg,
+ motion_extractor_cfg,
+ ):
+ self.insightface_det = InsightFaceDet(**insightface_det_cfg)
+ self.landmark106 = Landmark106(**landmark106_cfg)
+ self.landmark203 = Landmark203(**landmark203_cfg)
+ self.landmark478 = Landmark478(**landmark478_cfg)
+
+ self.appearance_extractor = AppearanceExtractor(**appearance_extractor_cfg)
+ self.motion_extractor = MotionExtractor(**motion_extractor_cfg)
+
+ def _crop(self, img, last_lmk=None, **kwargs):
+ # img_rgb -> det->landmark106->landmark203->crop
+
+        if last_lmk is None:  # det for first frame or image
+            det, _ = self.insightface_det(img)
+            if len(det) == 0:
+                raise ValueError("no face detected in the source image/frame")
+            # sort detections by box area, largest face first
+            boxes = det[np.argsort(-(det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1]))]
+            lmk_for_track = self.landmark106(img, boxes[0])  # 106
+ else: # track for video frames
+ lmk_for_track = last_lmk # 203
+
+ crop_dct = crop_image(
+ img,
+ lmk_for_track,
+ dsize=self.landmark203.dsize,
+ scale=1.5,
+ vy_ratio=-0.1,
+ pt_crop_flag=False,
+ )
+ lmk203 = self.landmark203(crop_dct["img_crop"], crop_dct["M_c2o"])
+
+ ret_dct = crop_image(
+ img,
+ lmk203,
+ dsize=512,
+ scale=kwargs.get("crop_scale", 2.3),
+ vx_ratio=kwargs.get("crop_vx_ratio", 0),
+ vy_ratio=kwargs.get("crop_vy_ratio", -0.125),
+ flag_do_rot=kwargs.get("crop_flag_do_rot", True),
+ pt_crop_flag=False,
+ )
+
+ img_crop = ret_dct["img_crop"]
+ M_c2o = ret_dct["M_c2o"]
+
+ return img_crop, M_c2o, lmk203
+
+ @staticmethod
+ def _img_crop_to_bchw256(img_crop):
+ rgb_256 = cv2.resize(img_crop, (256, 256), interpolation=cv2.INTER_AREA)
+ rgb_256_bchw = (rgb_256.astype(np.float32) / 255.0)[None].transpose(0, 3, 1, 2)
+ return rgb_256_bchw
+
+ def _get_kp_info(self, img):
+ # rgb_256_bchw_norm01
+ kp_info = self.motion_extractor(img)
+ return kp_info
+
+ def _get_f3d(self, img):
+ # rgb_256_bchw_norm01
+ fs = self.appearance_extractor(img)
+ return fs
+
+ def _get_eye_info(self, img):
+ # rgb uint8
+ lmk478 = self.landmark478(img) # [1, 478, 3]
+ attr = EyeAttrUtilsByMP(lmk478)
+ lr_open = attr.LR_open().reshape(-1, 2) # [1, 2]
+ lr_ball = attr.LR_ball_move().reshape(-1, 6) # [1, 3, 2] -> [1, 6]
+ return [lr_open, lr_ball]
+
+ def __call__(self, img, last_lmk=None, **kwargs):
+ """
+ img: rgb, uint8
+ last_lmk: last frame lmk203, for video tracking
+ kwargs: optional crop cfg
+ crop_scale: 2.3
+ crop_vx_ratio: 0
+ crop_vy_ratio: -0.125
+ crop_flag_do_rot: True
+ """
+ img_crop, M_c2o, lmk203 = self._crop(img, last_lmk=last_lmk, **kwargs)
+
+ eye_open, eye_ball = self._get_eye_info(img_crop)
+
+ rgb_256_bchw = self._img_crop_to_bchw256(img_crop)
+ kp_info = self._get_kp_info(rgb_256_bchw)
+ fs = self._get_f3d(rgb_256_bchw)
+
+ source_info = {
+ "x_s_info": kp_info,
+ "f_s": fs,
+ "M_c2o": M_c2o,
+ "eye_open": eye_open, # [1, 2]
+ "eye_ball": eye_ball, # [1, 6]
+ "lmk203": lmk203, # for track
+ }
+ return source_info
diff --git a/core/atomic_components/warp_f3d.py b/core/atomic_components/warp_f3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..be4262f6d329f75b7b3ca661ebc8ad8d3503c6e3
--- /dev/null
+++ b/core/atomic_components/warp_f3d.py
@@ -0,0 +1,22 @@
+from ..models.warp_network import WarpNetwork
+
+
+"""
+# __init__
+warp_network_cfg = {
+ "model_path": "",
+ "device": "cuda",
+}
+"""
+
+class WarpF3D:
+ def __init__(
+ self,
+ warp_network_cfg,
+ ):
+ self.warp_net = WarpNetwork(**warp_network_cfg)
+
+ def __call__(self, f_s, x_s, x_d):
+ out = self.warp_net(f_s, x_s, x_d)
+ return out
+
\ No newline at end of file
diff --git a/core/atomic_components/wav2feat.py b/core/atomic_components/wav2feat.py
new file mode 100644
index 0000000000000000000000000000000000000000..10f360edb26b86c0ade9e59de4b71bc64983de82
--- /dev/null
+++ b/core/atomic_components/wav2feat.py
@@ -0,0 +1,110 @@
+import librosa
+import numpy as np
+import math
+
+from ..aux_models.hubert_stream import HubertStreaming
+
+"""
+wavlm_cfg = {
+ "model_path": "",
+ "device": "cuda",
+ "force_ori_type": False,
+}
+hubert_cfg = {
+ "model_path": "",
+ "device": "cuda",
+ "force_ori_type": False,
+}
+"""
+
+
+class Wav2Feat:
+ def __init__(self, w2f_cfg, w2f_type="hubert"):
+ self.w2f_type = w2f_type.lower()
+ if self.w2f_type == "hubert":
+ self.w2f = Wav2FeatHubert(hubert_cfg=w2f_cfg)
+ self.feat_dim = 1024
+ self.support_streaming = True
+ else:
+ raise ValueError(f"Unsupported w2f_type: {w2f_type}")
+
+ def __call__(
+ self,
+ audio,
+ sr=16000,
+ norm_mean_std=None, # for s2g
+ chunksize=(3, 5, 2), # for hubert
+ ):
+ if self.w2f_type == "hubert":
+ feat = self.w2f(audio, chunksize=chunksize)
+ elif self.w2f_type == "s2g":
+ feat = self.w2f(audio, sr=sr, norm_mean_std=norm_mean_std)
+ else:
+ raise ValueError(f"Unsupported w2f_type: {self.w2f_type}")
+ return feat
+
+ def wav2feat(
+ self,
+ audio,
+ sr=16000,
+ norm_mean_std=None, # for s2g
+ chunksize=(3, 5, 2),
+ ):
+ # for offline
+ if self.w2f_type == "hubert":
+ feat = self.w2f.wav2feat(audio, sr=sr, chunksize=chunksize)
+ elif self.w2f_type == "s2g":
+ feat = self.w2f(audio, sr=sr, norm_mean_std=norm_mean_std)
+ else:
+ raise ValueError(f"Unsupported w2f_type: {self.w2f_type}")
+ return feat
+
+
+class Wav2FeatHubert:
+ def __init__(
+ self,
+ hubert_cfg,
+ ):
+ self.hubert = HubertStreaming(**hubert_cfg)
+
+ def __call__(self, audio_chunk, chunksize=(3, 5, 2)):
+ """
+ audio_chunk: int(sum(chunksize) * 0.04 * 16000) + 80 # 6480
+ """
+        valid_feat_s = - sum(chunksize[1:]) * 2  # -(5 + 2) * 2 = -14
+        valid_feat_e = - chunksize[2] * 2        # -2 * 2 = -4
+
+ encoding_chunk = self.hubert(audio_chunk)
+ valid_encoding = encoding_chunk[valid_feat_s:valid_feat_e]
+ valid_feat = valid_encoding.reshape(chunksize[1], 2, 1024).mean(1) # [5, 1024]
+ return valid_feat
+
+ def wav2feat(self, audio, sr, chunksize=(3, 5, 2)):
+ # for offline
+ if sr != 16000:
+ audio_16k = librosa.resample(audio, orig_sr=sr, target_sr=16000)
+ else:
+ audio_16k = audio
+
+ num_f = math.ceil(len(audio_16k) / 16000 * 25)
+ split_len = int(sum(chunksize) * 0.04 * 16000) + 80 # 6480
+
+ speech_pad = np.concatenate([
+ np.zeros((split_len - int(sum(chunksize[1:]) * 0.04 * 16000),), dtype=audio_16k.dtype),
+ audio_16k,
+ np.zeros((split_len,), dtype=audio_16k.dtype),
+ ], 0)
+
+ i = 0
+ res_lst = []
+ while i < num_f:
+ sss = int(i * 0.04 * 16000)
+ eee = sss + split_len
+ audio_chunk = speech_pad[sss:eee]
+ valid_feat = self.__call__(audio_chunk, chunksize)
+ res_lst.append(valid_feat)
+ i += chunksize[1]
+
+ ret = np.concatenate(res_lst, 0)
+ ret = ret[:num_f]
+ return ret
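+
+
+# Offline usage sketch (the model path is a placeholder):
+#   w2f = Wav2Feat({"model_path": "hubert.onnx", "device": "cuda"}, w2f_type="hubert")
+#   feat = w2f.wav2feat(audio_16k, sr=16000)  # -> (ceil(len / 640), 1024) at 25 fps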
diff --git a/core/atomic_components/writer.py b/core/atomic_components/writer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c075b5a65fc75b992ebd1addf31cce81387eb5c
--- /dev/null
+++ b/core/atomic_components/writer.py
@@ -0,0 +1,36 @@
+import imageio
+import os
+
+
+class VideoWriterByImageIO:
+ def __init__(self, video_path, fps=25, **kwargs):
+ video_format = kwargs.get("format", "mp4") # default is mp4 format
+ codec = kwargs.get("vcodec", "libx264") # default is libx264 encoding
+ quality = kwargs.get("quality") # video quality
+ pixelformat = kwargs.get("pixelformat", "yuv420p") # video pixel format
+ macro_block_size = kwargs.get("macro_block_size", 2)
+ ffmpeg_params = ["-crf", str(kwargs.get("crf", 18))]
+
+        out_dir = os.path.dirname(video_path)
+        if out_dir:  # avoid makedirs("") when the path has no directory part
+            os.makedirs(out_dir, exist_ok=True)
+
+ writer = imageio.get_writer(
+ video_path,
+ fps=fps,
+ format=video_format,
+ codec=codec,
+ quality=quality,
+ ffmpeg_params=ffmpeg_params,
+ pixelformat=pixelformat,
+ macro_block_size=macro_block_size,
+ )
+ self.writer = writer
+
+ def __call__(self, img, fmt="bgr"):
+ if fmt == "bgr":
+ frame = img[..., ::-1]
+ else:
+ frame = img
+ self.writer.append_data(frame)
+
+ def close(self):
+ self.writer.close()
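+
+
+if __name__ == "__main__":
+    # Minimal smoke test: write 25 gray frames to an mp4 (output path is arbitrary).
+    import numpy as np
+
+    writer = VideoWriterByImageIO("./tmp/test_writer.mp4", fps=25)
+    for _ in range(25):
+        frame = np.full((256, 256, 3), 128, dtype=np.uint8)
+        writer(frame, fmt="rgb")
+    writer.close()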
diff --git a/core/aux_models/blaze_face.py b/core/aux_models/blaze_face.py
new file mode 100644
index 0000000000000000000000000000000000000000..508a5907d01f0ca30d51a7fc6c3a8a4a5ee93606
--- /dev/null
+++ b/core/aux_models/blaze_face.py
@@ -0,0 +1,351 @@
+import numpy as np
+import cv2
+from ..utils.load_model import load_model
+
+
+def intersect(box_a, box_b):
+ """We resize both arrays to [A,B,2] without new malloc:
+ [A,2] -> [A,1,2] -> [A,B,2]
+ [B,2] -> [1,B,2] -> [A,B,2]
+ Then we compute the area of intersect between box_a and box_b.
+ Args:
+ box_a: (array) bounding boxes, Shape: [A,4].
+ box_b: (array) bounding boxes, Shape: [B,4].
+ Return:
+ (array) intersection area, Shape: [A,B].
+ """
+ A = box_a.shape[0]
+ B = box_b.shape[0]
+ max_xy = np.minimum(
+ np.expand_dims(box_a[:, 2:], axis=1).repeat(B, axis=1),
+ np.expand_dims(box_b[:, 2:], axis=0).repeat(A, axis=0),
+ )
+ min_xy = np.maximum(
+ np.expand_dims(box_a[:, :2], axis=1).repeat(B, axis=1),
+ np.expand_dims(box_b[:, :2], axis=0).repeat(A, axis=0),
+ )
+ inter = np.clip((max_xy - min_xy), a_min=0, a_max=None)
+ return inter[:, :, 0] * inter[:, :, 1]
+
+
+def jaccard(box_a, box_b):
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
+ is simply the intersection over union of two boxes. Here we operate on
+ ground truth boxes and default boxes.
+ E.g.:
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
+ Args:
+ box_a: (array) Ground truth bounding boxes, Shape: [num_objects,4]
+ box_b: (array) Prior boxes from priorbox layers, Shape: [num_priors,4]
+ Return:
+ jaccard overlap: (array) Shape: [box_a.size(0), box_b.size(0)]
+ """
+ inter = intersect(box_a, box_b)
+ area_a = (
+ ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1]))
+ .reshape(-1, 1)
+ .repeat(box_b.shape[0], axis=1)
+ ) # [A,B]
+ area_b = (
+ ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1]))
+ .reshape(1, -1)
+ .repeat(box_a.shape[0], axis=0)
+ ) # [A,B]
+ union = area_a + area_b - inter
+ return inter / union # [A,B]
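+
+# Worked example: jaccard(np.array([[0, 0, 2, 2]]), np.array([[1, 1, 3, 3]]))
+# -> intersection = 1, union = 4 + 4 - 1 = 7, IoU = 1/7 ≈ 0.143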
+
+
+def overlap_similarity(box, other_boxes):
+ """Computes the IOU between a bounding box and set of other boxes."""
+ box = np.expand_dims(box, axis=0) # Equivalent to unsqueeze(0) in PyTorch
+ iou = jaccard(box, other_boxes)
+ return np.squeeze(iou, axis=0) # Equivalent to squeeze(0) in PyTorch
+
+
+class BlazeFace:
+ def __init__(self, model_path, device="cuda"):
+ self.anchor_options = {
+ "num_layers": 4,
+ "min_scale": 0.1484375,
+ "max_scale": 0.75,
+ "input_size_height": 128,
+ "input_size_width": 128,
+ "anchor_offset_x": 0.5,
+ "anchor_offset_y": 0.5,
+ "strides": [8, 16, 16, 16],
+ "aspect_ratios": [1.0],
+ "reduce_boxes_in_lowest_layer": False,
+ "interpolated_scale_aspect_ratio": 1.0,
+ "fixed_anchor_size": True,
+ }
+ self.num_classes = 1
+ self.num_anchors = 896
+ self.num_coords = 16
+ self.x_scale = 128.0
+ self.y_scale = 128.0
+ self.h_scale = 128.0
+ self.w_scale = 128.0
+ self.min_score_thresh = 0.5
+ self.min_suppression_threshold = 0.3
+ self.anchors = self.generate_anchors(self.anchor_options)
+ self.anchors = np.array(self.anchors)
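+        # 896 anchors = 16*16*2 (the stride-8 layer) + 8*8*6 (three merged
+        # stride-16 layers, 2 anchors per cell each) on a 128x128 input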
+ assert len(self.anchors) == 896
+ self.model, self.model_type = load_model(model_path, device=device)
+ self.output_names = ["regressors", "classificators"]
+
+ def __call__(self, image: np.ndarray):
+ """
+ image: RGB image
+ """
+ image = cv2.resize(image, (128, 128))
+ image = image[np.newaxis, :, :, :].astype(np.float32)
+ image = image / 127.5 - 1.0
+ outputs = {}
+ if self.model_type == "onnx":
+ out_list = self.model.run(None, {"input": image})
+ for i, name in enumerate(self.output_names):
+ outputs[name] = out_list[i]
+ elif self.model_type == "tensorrt":
+ self.model.setup({"input": image})
+ self.model.infer()
+ for name in self.output_names:
+ outputs[name] = self.model.buffer[name][0]
+ else:
+ raise ValueError(f"Unsupported model type: {self.model_type}")
+ boxes = self.postprocess(outputs["regressors"], outputs["classificators"])
+ return boxes
+
+ def calculate_scale(self, min_scale, max_scale, stride_index, num_strides):
+ return min_scale + (max_scale - min_scale) * stride_index / (num_strides - 1.0)
+
+ def generate_anchors(self, options):
+ strides_size = len(options["strides"])
+ assert options["num_layers"] == strides_size
+
+ anchors = []
+ layer_id = 0
+ while layer_id < strides_size:
+ anchor_height = []
+ anchor_width = []
+ aspect_ratios = []
+ scales = []
+
+ # For same strides, we merge the anchors in the same order.
+ last_same_stride_layer = layer_id
+ while (last_same_stride_layer < strides_size) and (
+ options["strides"][last_same_stride_layer]
+ == options["strides"][layer_id]
+ ):
+ scale = self.calculate_scale(
+ options["min_scale"],
+ options["max_scale"],
+ last_same_stride_layer,
+ strides_size,
+ )
+
+ if (
+ last_same_stride_layer == 0
+ and options["reduce_boxes_in_lowest_layer"]
+ ):
+ # For first layer, it can be specified to use predefined anchors.
+ aspect_ratios.append(1.0)
+ aspect_ratios.append(2.0)
+ aspect_ratios.append(0.5)
+ scales.append(0.1)
+ scales.append(scale)
+ scales.append(scale)
+ else:
+ for aspect_ratio in options["aspect_ratios"]:
+ aspect_ratios.append(aspect_ratio)
+ scales.append(scale)
+
+ if options["interpolated_scale_aspect_ratio"] > 0.0:
+ scale_next = (
+ 1.0
+ if last_same_stride_layer == strides_size - 1
+ else self.calculate_scale(
+ options["min_scale"],
+ options["max_scale"],
+ last_same_stride_layer + 1,
+ strides_size,
+ )
+ )
+ scales.append(np.sqrt(scale * scale_next))
+ aspect_ratios.append(options["interpolated_scale_aspect_ratio"])
+
+ last_same_stride_layer += 1
+
+ for i in range(len(aspect_ratios)):
+ ratio_sqrts = np.sqrt(aspect_ratios[i])
+ anchor_height.append(scales[i] / ratio_sqrts)
+ anchor_width.append(scales[i] * ratio_sqrts)
+
+ stride = options["strides"][layer_id]
+ feature_map_height = int(np.ceil(options["input_size_height"] / stride))
+ feature_map_width = int(np.ceil(options["input_size_width"] / stride))
+
+ for y in range(feature_map_height):
+ for x in range(feature_map_width):
+ for anchor_id in range(len(anchor_height)):
+ x_center = (x + options["anchor_offset_x"]) / feature_map_width
+ y_center = (y + options["anchor_offset_y"]) / feature_map_height
+
+ new_anchor = [x_center, y_center, 0, 0]
+ if options["fixed_anchor_size"]:
+ new_anchor[2] = 1.0
+ new_anchor[3] = 1.0
+ else:
+ new_anchor[2] = anchor_width[anchor_id]
+ new_anchor[3] = anchor_height[anchor_id]
+ anchors.append(new_anchor)
+
+ layer_id = last_same_stride_layer
+
+ return anchors
+
+ def _tensors_to_detections(self, raw_box_tensor, raw_score_tensor, anchors):
+ """The output of the neural network is a tensor of shape (b, 896, 16)
+ containing the bounding box regressor predictions, as well as a tensor
+ of shape (b, 896, 1) with the classification confidences.
+
+ This function converts these two "raw" tensors into proper detections.
+ Returns a list of (num_detections, 17) tensors, one for each image in
+ the batch.
+
+ This is based on the source code from:
+ mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc
+ mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto
+ """
+ assert raw_box_tensor.ndim == 3
+ assert raw_box_tensor.shape[1] == self.num_anchors
+ assert raw_box_tensor.shape[2] == self.num_coords
+
+ assert raw_score_tensor.ndim == 3
+ assert raw_score_tensor.shape[1] == self.num_anchors
+ assert raw_score_tensor.shape[2] == self.num_classes
+
+ assert raw_box_tensor.shape[0] == raw_score_tensor.shape[0]
+
+ detection_boxes = self._decode_boxes(raw_box_tensor, anchors)
+
+ raw_score_tensor = np.clip(raw_score_tensor, -50, 100)
+ detection_scores = 1 / (1 + np.exp(-raw_score_tensor))
+ mask = detection_scores >= self.min_score_thresh
+ mask = mask[0, :, 0]
+ boxes = detection_boxes[0, mask, :]
+ scores = detection_scores[0, mask, :]
+ return np.concatenate((boxes, scores), axis=-1)
+
+ def _decode_boxes(self, raw_boxes, anchors):
+ """Converts the predictions into actual coordinates using
+ the anchor boxes. Processes the entire batch at once.
+ """
+ boxes = np.zeros_like(raw_boxes)
+
+ x_center = raw_boxes[..., 0] / self.x_scale * anchors[:, 2] + anchors[:, 0]
+ y_center = raw_boxes[..., 1] / self.y_scale * anchors[:, 3] + anchors[:, 1]
+
+ w = raw_boxes[..., 2] / self.w_scale * anchors[:, 2]
+ h = raw_boxes[..., 3] / self.h_scale * anchors[:, 3]
+
+ boxes[..., 0] = self.x_scale * (x_center - w / 2.0) # xmin
+ boxes[..., 1] = self.y_scale * (y_center - h / 2.0) # ymin
+ boxes[..., 2] = self.w_scale * (x_center + w / 2.0) # xmax
+ boxes[..., 3] = self.h_scale * (y_center + h / 2.0) # ymax
+
+ for k in range(6):
+ offset = 4 + k * 2
+ keypoint_x = (
+ raw_boxes[..., offset] / self.x_scale * anchors[:, 2] + anchors[:, 0]
+ )
+ keypoint_y = (
+ raw_boxes[..., offset + 1] / self.y_scale * anchors[:, 3]
+ + anchors[:, 1]
+ )
+ boxes[..., offset] = keypoint_x
+ boxes[..., offset + 1] = keypoint_y
+
+ return boxes
+
+ def _weighted_non_max_suppression(self, detections):
+ """The alternative NMS method as mentioned in the BlazeFace paper:
+
+ "We replace the suppression algorithm with a blending strategy that
+ estimates the regression parameters of a bounding box as a weighted
+ mean between the overlapping predictions."
+
+ The original MediaPipe code assigns the score of the most confident
+ detection to the weighted detection, but we take the average score
+ of the overlapping detections.
+
+ The input detections should be a NumPy array of shape (count, 17).
+
+ Returns a list of NumPy arrays, one for each detected face.
+
+ This is based on the source code from:
+ mediapipe/calculators/util/non_max_suppression_calculator.cc
+ mediapipe/calculators/util/non_max_suppression_calculator.proto
+ """
+ if len(detections) == 0:
+ return []
+
+ output_detections = []
+
+ # Sort the detections from highest to lowest score.
+ remaining = np.argsort(detections[:, 16])[::-1]
+
+ while len(remaining) > 0:
+ detection = detections[remaining[0]]
+
+ # Compute the overlap between the first box and the other
+ # remaining boxes. (Note that the other_boxes also include
+ # the first_box.)
+ first_box = detection[:4]
+ other_boxes = detections[remaining, :4]
+ ious = overlap_similarity(first_box, other_boxes)
+
+ # If two detections don't overlap enough, they are considered
+ # to be from different faces.
+ mask = ious > self.min_suppression_threshold
+ overlapping = remaining[mask]
+ remaining = remaining[~mask]
+
+ # Take an average of the coordinates from the overlapping
+ # detections, weighted by their confidence scores.
+ weighted_detection = detection.copy()
+ if len(overlapping) > 1:
+ coordinates = detections[overlapping, :16]
+ scores = detections[overlapping, 16:17]
+ total_score = scores.sum()
+ weighted = (coordinates * scores).sum(axis=0) / total_score
+ weighted_detection[:16] = weighted
+ weighted_detection[16] = total_score / len(overlapping)
+
+ output_detections.append(weighted_detection)
+
+ return output_detections
+
+ def postprocess(self, raw_boxes, scores):
+ detections = self._tensors_to_detections(raw_boxes, scores, self.anchors)
+
+ detections = self._weighted_non_max_suppression(detections)
+ detections = np.array(detections)
+ return detections
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", type=str, default="")
+ parser.add_argument("--image", type=str, default=None)
+ args = parser.parse_args()
+
+ blaze_face = BlazeFace(args.model)
+ image = cv2.imread(args.image)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    # __call__ resizes to 128x128 and normalizes internally, so the raw
+    # RGB image is passed in as-is
+    boxes = blaze_face(image)
+ print(boxes)
diff --git a/core/aux_models/face_mesh.py b/core/aux_models/face_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..0557f572b5a89c13f5caa887aaf14308b1451b11
--- /dev/null
+++ b/core/aux_models/face_mesh.py
@@ -0,0 +1,101 @@
+import cv2
+import numpy as np
+
+from ..utils.load_model import load_model
+
+
+class FaceMesh:
+ def __init__(self, model_path, device="cuda"):
+ self.model, self.model_type = load_model(model_path, device=device)
+ self.input_size = (256, 256) # (w, h)
+ self.output_names = [
+ "Identity",
+ "Identity_1",
+ "Identity_2",
+ ] # Identity is the mesh
+
+ def project_landmarks(self, points, roi):
+ width, height = self.input_size
+ points /= (width, height, width)
+ sin, cos = np.sin(roi[4]), np.cos(roi[4])
+ matrix = np.array([[cos, sin, 0.0], [-sin, cos, 0.0], [1.0, 1.0, 1.0]])
+ points -= (0.5, 0.5, 0.0)
+ rotated = np.matmul(points * (1, 1, 0), matrix)
+ points *= (0, 0, 1)
+ points += rotated
+ points *= (roi[2], roi[3], roi[2])
+ points += (roi[0], roi[1], 0.0)
+ return points
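+    # project_landmarks maps mesh points from the 256x256 crop back to the image:
+    # normalize by the crop size, subtract the 0.5 crop center, rotate (x, y) by
+    # the ROI angle, then scale by the ROI size and translate to the ROI center
+    # (z is scaled but not rotated).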
+
+ def __call__(self, image, roi):
+ """
+ image: np.ndarray, RGB, (H, W, C), [0, 255]
+        roi: np.ndarray, (cx, cy, w, h, rotation), rotation in radians
+ """
+ cx, cy, w, h = roi[:4]
+ w_half, h_half = w / 2, h / 2
+ pts = [
+ (cx - w_half, cy - h_half),
+ (cx + w_half, cy - h_half),
+ (cx + w_half, cy + h_half),
+ (cx - w_half, cy + h_half),
+ ]
+ rotation = roi[4]
+ s, c = np.sin(rotation), np.cos(rotation)
+ t = np.array(pts) - (cx, cy)
+ r = np.array([[c, s], [-s, c]])
+ src_pts = np.matmul(t, r) + (cx, cy)
+ src_pts = src_pts.astype(np.float32)
+
+ dst_pts = np.array(
+ [
+ [0.0, 0.0],
+ [self.input_size[0], 0.0],
+ [self.input_size[0], self.input_size[1]],
+ [0.0, self.input_size[1]],
+ ]
+ ).astype(np.float32)
+ M = cv2.getPerspectiveTransform(src_pts, dst_pts)
+ roi_image = cv2.warpPerspective(
+ image, M, self.input_size, flags=cv2.INTER_LINEAR
+ )
+ # cv2.imwrite('test.jpg', cv2.cvtColor(roi_image, cv2.COLOR_RGB2BGR))
+ roi_image = roi_image / 255.0
+ roi_image = roi_image.astype(np.float32)
+ roi_image = roi_image[np.newaxis, :, :, :]
+
+ outputs = {}
+ if self.model_type == "onnx":
+ out_list = self.model.run(None, {"input": roi_image})
+ for i, name in enumerate(self.output_names):
+ outputs[name] = out_list[i]
+ elif self.model_type == "tensorrt":
+ self.model.setup({"input": roi_image})
+ self.model.infer()
+ for name in self.output_names:
+ outputs[name] = self.model.buffer[name][0]
+ else:
+ raise ValueError(f"Unsupported model type: {self.model_type}")
+ points = outputs["Identity"].reshape(1434 // 3, 3)
+ points = self.project_landmarks(points, roi)
+ return points
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", type=str, help="model path")
+ parser.add_argument("--image", type=str, help="image path")
+ parser.add_argument("--device", type=str, default="cuda", help="device")
+ args = parser.parse_args()
+
+ face_mesh = FaceMesh(args.model, args.device)
+ image = cv2.imread(args.image, cv2.IMREAD_COLOR)
+ image = cv2.resize(image, (256, 256))
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ roi = np.array([128, 128, 256, 256, np.pi / 2])
+ mesh = face_mesh(image, roi)
+ print(mesh.shape)
diff --git a/core/aux_models/hubert_stream.py b/core/aux_models/hubert_stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..30387c6df754d129d6f1815e131718e68069e411
--- /dev/null
+++ b/core/aux_models/hubert_stream.py
@@ -0,0 +1,29 @@
+from ..utils.load_model import load_model
+
+
+class HubertStreaming:
+ def __init__(self, model_path, device="cuda", **kwargs):
+ kwargs["model_file"] = model_path
+ kwargs["module_name"] = "HubertStreamingONNX"
+ kwargs["package_name"] = "..aux_models.modules"
+
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
+ self.device = device
+
+ def forward_chunk(self, audio_chunk):
+ if self.model_type == "onnx":
+ output = self.model.run(None, {"input_values": audio_chunk.reshape(1, -1)})[0]
+ elif self.model_type == "tensorrt":
+ self.model.setup({"input_values": audio_chunk.reshape(1, -1)})
+ self.model.infer()
+ output = self.model.buffer["encoding_out"][0]
+ else:
+ raise ValueError(f"Unsupported model type: {self.model_type}")
+ return output
+
+ def __call__(self, audio_chunk):
+ if self.model_type == "ori":
+ output = self.model.forward_chunk(audio_chunk)
+ else:
+ output = self.forward_chunk(audio_chunk)
+ return output
diff --git a/core/aux_models/insightface_det.py b/core/aux_models/insightface_det.py
new file mode 100644
index 0000000000000000000000000000000000000000..10a6774b2dd69ee0dfecc4fb3b74a04093c7f5ba
--- /dev/null
+++ b/core/aux_models/insightface_det.py
@@ -0,0 +1,245 @@
+from __future__ import division
+import numpy as np
+import cv2
+
+from ..utils.load_model import load_model
+
+
+def distance2bbox(points, distance, max_shape=None):
+    """Decode distance predictions to bounding boxes.
+
+    Args:
+        points (ndarray): Shape (n, 2), [x, y].
+        distance (ndarray): Distance from the given point to the 4
+            boundaries (left, top, right, bottom).
+        max_shape (tuple): Shape of the image.
+
+    Returns:
+        ndarray: Decoded bboxes.
+    """
+    x1 = points[:, 0] - distance[:, 0]
+    y1 = points[:, 1] - distance[:, 1]
+    x2 = points[:, 0] + distance[:, 2]
+    y2 = points[:, 1] + distance[:, 3]
+    if max_shape is not None:
+        x1 = np.clip(x1, 0, max_shape[1])
+        y1 = np.clip(y1, 0, max_shape[0])
+        x2 = np.clip(x2, 0, max_shape[1])
+        y2 = np.clip(y2, 0, max_shape[0])
+    return np.stack([x1, y1, x2, y2], axis=-1)
+
+
+def distance2kps(points, distance, max_shape=None):
+    """Decode distance predictions to keypoints.
+
+    Args:
+        points (ndarray): Shape (n, 2), [x, y].
+        distance (ndarray): Distances from the given point to the
+            keypoints, interleaved as (dx, dy) pairs.
+        max_shape (tuple): Shape of the image.
+
+    Returns:
+        ndarray: Decoded keypoints.
+    """
+    preds = []
+    for i in range(0, distance.shape[1], 2):
+        px = points[:, i % 2] + distance[:, i]
+        py = points[:, i % 2 + 1] + distance[:, i + 1]
+        if max_shape is not None:
+            px = np.clip(px, 0, max_shape[1])
+            py = np.clip(py, 0, max_shape[0])
+        preds.append(px)
+        preds.append(py)
+    return np.stack(preds, axis=-1)
+
+
+class InsightFaceDet:
+ def __init__(self, model_path, device="cuda", **kwargs):
+ kwargs["model_file"] = model_path
+ kwargs["module_name"] = "RetinaFace"
+ kwargs["package_name"] = "..aux_models.modules"
+
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
+ self.device = device
+
+ if self.model_type != "ori":
+ self._init_vars()
+
+ def _init_vars(self):
+ self.center_cache = {}
+
+ self.nms_thresh = 0.4
+ self.det_thresh = 0.5
+
+ self.input_size = (512, 512)
+ self.input_mean = 127.5
+ self.input_std = 128.0
+ self._anchor_ratio = 1.0
+ self.fmc = 3
+ self._feat_stride_fpn = [8, 16, 32]
+ self._num_anchors = 2
+ self.use_kps = True
+
+ self.output_names = [
+ "scores1",
+ "scores2",
+ "scores3",
+ "boxes1",
+ "boxes2",
+ "boxes3",
+ "kps1",
+ "kps2",
+ "kps3",
+ ]
+
+ def _run_model(self, blob):
+ if self.model_type == "onnx":
+ net_outs = self.model.run(None, {"image": blob})
+ elif self.model_type == "tensorrt":
+ self.model.setup({"image": blob})
+ self.model.infer()
+ net_outs = [self.model.buffer[name][0] for name in self.output_names]
+ else:
+ raise ValueError(f"Unsupported model type: {self.model_type}")
+ return net_outs
+
+ def _forward(self, img, threshold):
+ """
+ img: np.ndarray, shape (h, w, 3)
+ """
+ scores_list = []
+ bboxes_list = []
+ kpss_list = []
+ input_size = tuple(img.shape[0:2][::-1])
+ blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ # (1, 3, 512, 512)
+ net_outs = self._run_model(blob)
+
+ input_height = blob.shape[2]
+ input_width = blob.shape[3]
+ fmc = self.fmc
+ for idx, stride in enumerate(self._feat_stride_fpn):
+ scores = net_outs[idx]
+ bbox_preds = net_outs[idx+fmc]
+ bbox_preds = bbox_preds * stride
+ if self.use_kps:
+ kps_preds = net_outs[idx+fmc*2] * stride
+ height = input_height // stride
+ width = input_width // stride
+ # K = height * width
+ key = (height, width, stride)
+ if key in self.center_cache:
+ anchor_centers = self.center_cache[key]
+ else:
+ #solution-3:
+ anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
+ anchor_centers = (anchor_centers * stride).reshape( (-1, 2) )
+ if self._num_anchors>1:
+ anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) )
+ if len(self.center_cache)<100:
+ self.center_cache[key] = anchor_centers
+
+ pos_inds = np.where(scores>=threshold)[0]
+ bboxes = distance2bbox(anchor_centers, bbox_preds)
+ pos_scores = scores[pos_inds]
+ pos_bboxes = bboxes[pos_inds]
+ scores_list.append(pos_scores)
+ bboxes_list.append(pos_bboxes)
+ if self.use_kps:
+ kpss = distance2kps(anchor_centers, kps_preds)
+ kpss = kpss.reshape( (kpss.shape[0], -1, 2) )
+ pos_kpss = kpss[pos_inds]
+ kpss_list.append(pos_kpss)
+ return scores_list, bboxes_list, kpss_list
+
+ def detect(self, img, input_size=None, max_num=0, metric='default', det_thresh=None):
+ input_size = self.input_size if input_size is None else input_size
+ det_thresh = self.det_thresh if det_thresh is None else det_thresh
+
+ im_ratio = float(img.shape[0]) / img.shape[1]
+ model_ratio = float(input_size[1]) / input_size[0]
+ if im_ratio>model_ratio:
+ new_height = input_size[1]
+ new_width = int(new_height / im_ratio)
+ else:
+ new_width = input_size[0]
+ new_height = int(new_width * im_ratio)
+ det_scale = float(new_height) / img.shape[0]
+ resized_img = cv2.resize(img, (new_width, new_height))
+ det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 )
+ det_img[:new_height, :new_width, :] = resized_img
+
+ scores_list, bboxes_list, kpss_list = self._forward(det_img, det_thresh)
+
+ scores = np.vstack(scores_list)
+ scores_ravel = scores.ravel()
+ order = scores_ravel.argsort()[::-1]
+ bboxes = np.vstack(bboxes_list) / det_scale
+ if self.use_kps:
+ kpss = np.vstack(kpss_list) / det_scale
+ pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
+ pre_det = pre_det[order, :]
+ keep = self.nms(pre_det)
+ det = pre_det[keep, :]
+ if self.use_kps:
+ kpss = kpss[order,:,:]
+ kpss = kpss[keep,:,:]
+ else:
+ kpss = None
+ if max_num > 0 and det.shape[0] > max_num:
+ area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
+ img_center = img.shape[0] // 2, img.shape[1] // 2
+ offsets = np.vstack([
+ (det[:, 0] + det[:, 2]) / 2 - img_center[1],
+ (det[:, 1] + det[:, 3]) / 2 - img_center[0]
+ ])
+ offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
+ if metric=='max':
+ values = area
+ else:
+                values = area - offset_dist_squared * 2.0  # extra weight on centering
+            bindex = np.argsort(values)[::-1]  # rank faces, largest / most central first
+ bindex = bindex[0:max_num]
+ det = det[bindex, :]
+ if kpss is not None:
+ kpss = kpss[bindex, :]
+ return det, kpss
+
+ def nms(self, dets):
+ thresh = self.nms_thresh
+ x1 = dets[:, 0]
+ y1 = dets[:, 1]
+ x2 = dets[:, 2]
+ y2 = dets[:, 3]
+ scores = dets[:, 4]
+
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+ inds = np.where(ovr <= thresh)[0]
+ order = order[inds + 1]
+
+ return keep
+
+ def __call__(self, img, **kwargs):
+ if self.model_type == "ori":
+ det, kpss = self.model.detect(img, **kwargs)
+ else:
+ det, kpss = self.detect(img, **kwargs)
+
+ return det, kpss
+
diff --git a/core/aux_models/insightface_landmark106.py b/core/aux_models/insightface_landmark106.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2a701a628cb848fd36218d1bc5631c627bd7f9f
--- /dev/null
+++ b/core/aux_models/insightface_landmark106.py
@@ -0,0 +1,100 @@
+from __future__ import division
+import numpy as np
+import cv2
+from skimage import transform as trans
+
+from ..utils.load_model import load_model
+
+
+def transform(data, center, output_size, scale, rotation):
+ scale_ratio = scale
+ rot = float(rotation) * np.pi / 180.0
+
+ t1 = trans.SimilarityTransform(scale=scale_ratio)
+ cx = center[0] * scale_ratio
+ cy = center[1] * scale_ratio
+ t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
+ t3 = trans.SimilarityTransform(rotation=rot)
+ t4 = trans.SimilarityTransform(translation=(output_size / 2,
+ output_size / 2))
+ t = t1 + t2 + t3 + t4
+ M = t.params[0:2]
+ cropped = cv2.warpAffine(data,
+ M, (output_size, output_size),
+ borderValue=0.0)
+ return cropped, M
+
+
+def trans_points2d(pts, M):
+ new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
+ for i in range(pts.shape[0]):
+ pt = pts[i]
+ new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
+ new_pt = np.dot(M, new_pt)
+ new_pts[i] = new_pt[0:2]
+
+ return new_pts
+
+
+class Landmark106:
+ def __init__(self, model_path, device="cuda", **kwargs):
+ kwargs["model_file"] = model_path
+ kwargs["module_name"] = "Landmark106"
+ kwargs["package_name"] = "..aux_models.modules"
+
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
+ self.device = device
+
+ if self.model_type != "ori":
+ self._init_vars()
+
+ def _init_vars(self):
+ self.input_mean = 0.0
+ self.input_std = 1.0
+ self.input_size = (192, 192)
+ self.lmk_num = 106
+
+ self.output_names = ["fc1"]
+
+ def _run_model(self, blob):
+ if self.model_type == "onnx":
+ pred = self.model.run(None, {"data": blob})[0]
+ elif self.model_type == "tensorrt":
+ self.model.setup({"data": blob})
+ self.model.infer()
+ pred = self.model.buffer[self.output_names[0]][0]
+ else:
+ raise ValueError(f"Unsupported model type: {self.model_type}")
+ return pred
+
+ def get(self, img, bbox):
+ w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
+ center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
+ rotate = 0
+ _scale = self.input_size[0] / (max(w, h)*1.5)
+
+ aimg, M = transform(img, center, self.input_size[0], _scale, rotate)
+ input_size = tuple(aimg.shape[0:2][::-1])
+
+ blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+
+ pred = self._run_model(blob)
+
+        pred = pred.reshape((-1, 2))
+        if self.lmk_num < pred.shape[0]:
+            pred = pred[-self.lmk_num:, :]  # keep only the last lmk_num points
+        pred[:, 0:2] += 1
+        pred[:, 0:2] *= (self.input_size[0] // 2)
+
+ IM = cv2.invertAffineTransform(M)
+ pred = trans_points2d(pred, IM)
+ return pred
+
+ def __call__(self, img, bbox):
+ if self.model_type == "ori":
+ pred = self.model.get(img, bbox)
+ else:
+ pred = self.get(img, bbox)
+
+ return pred
diff --git a/core/aux_models/landmark203.py b/core/aux_models/landmark203.py
new file mode 100644
index 0000000000000000000000000000000000000000..72f6ef1898e0d2ed54d9b68b71ba72a1be0a6da2
--- /dev/null
+++ b/core/aux_models/landmark203.py
@@ -0,0 +1,58 @@
+import numpy as np
+from ..utils.load_model import load_model
+
+
+def _transform_pts(pts, M):
+ """ conduct similarity or affine transformation to the pts
+ pts: Nx2 ndarray
+ M: 2x3 matrix or 3x3 matrix
+ return: Nx2
+ """
+ return pts @ M[:2, :2].T + M[:2, 2]
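+# (equivalent to applying the homogeneous map [x', y', 1]^T = M @ [x, y, 1]^T
+#  to every row of pts, keeping only the first two components)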
+
+
+class Landmark203:
+ def __init__(self, model_path, device="cuda", **kwargs):
+ kwargs["model_file"] = model_path
+ kwargs["module_name"] = "Landmark203"
+ kwargs["package_name"] = "..aux_models.modules"
+
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
+ self.device = device
+
+ self.output_names = ["landmarks"]
+ self.dsize = 224
+
+ def _run_model(self, inp):
+ if self.model_type == "onnx":
+ out_pts = self.model.run(None, {"input": inp})[0]
+ elif self.model_type == "tensorrt":
+ self.model.setup({"input": inp})
+ self.model.infer()
+ out_pts = self.model.buffer[self.output_names[0]][0]
+ else:
+ raise ValueError(f"Unsupported model type: {self.model_type}")
+ return out_pts
+
+ def run(self, img_crop_rgb, M_c2o=None):
+ # img_crop_rgb: 224x224
+
+        inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...]  # HxWx3 RGB uint8 -> 1x3xHxW float32 in [0, 1]
+
+ out_pts = self._run_model(inp)
+
+ # 2d landmarks 203 points
+ lmk = out_pts[0].reshape(-1, 2) * self.dsize # scale to 0-224
+ if M_c2o is not None:
+ lmk = _transform_pts(lmk, M=M_c2o)
+
+ return lmk
+
+ def __call__(self, img_crop_rgb, M_c2o=None):
+ if self.model_type == "ori":
+ lmk = self.model.run(img_crop_rgb, M_c2o)
+ else:
+ lmk = self.run(img_crop_rgb, M_c2o)
+
+ return lmk
+
\ No newline at end of file
diff --git a/core/aux_models/mediapipe_landmark478.py b/core/aux_models/mediapipe_landmark478.py
new file mode 100644
index 0000000000000000000000000000000000000000..305545b25cef270b5f00e453f13ef4b29a12faff
--- /dev/null
+++ b/core/aux_models/mediapipe_landmark478.py
@@ -0,0 +1,118 @@
+from enum import Enum
+import numpy as np
+
+from ..utils.load_model import load_model
+from .blaze_face import BlazeFace
+from .face_mesh import FaceMesh
+
+
+class SizeMode(Enum):
+ DEFAULT = 0
+ SQUARE_LONG = 1
+ SQUARE_SHORT = 2
+
+
+def _select_roi_size(
+    bbox: np.ndarray,  # (x1, y1, x2, y2)
+    image_size,  # (w, h); unused, kept for API symmetry
+    size_mode: SizeMode,
+):
+    """Return the size of an ROI based on the bounding box and size mode"""
+    width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
+ if size_mode == SizeMode.SQUARE_LONG:
+ long_size = max(width, height)
+ width, height = long_size, long_size
+ elif size_mode == SizeMode.SQUARE_SHORT:
+ short_side = min(width, height)
+ width, height = short_side, short_side
+ return width, height
+
+
+def bbox_to_roi(
+ bbox: np.ndarray,
+ image_size, # w,h
+ rotation_keypoints=None,
+ scale=(1.0, 1.0), # w, h
+ size_mode: SizeMode = SizeMode.SQUARE_LONG,
+):
+ PI = np.pi
+ TWO_PI = 2 * np.pi
+ # select ROI dimensions
+ width, height = _select_roi_size(bbox, image_size, size_mode)
+ scale_x, scale_y = scale
+ # calculate ROI size and -centre
+ width, height = width * scale_x, height * scale_y
+ cx = (bbox[0] + bbox[2]) / 2
+ cy = (bbox[1] + bbox[3]) / 2
+    # calculate rotation if keypoints are provided
+    if rotation_keypoints is None or len(rotation_keypoints) < 2:
+        return np.array([cx, cy, width, height, 0])
+    x0, y0 = rotation_keypoints[0]
+    x1, y1 = rotation_keypoints[1]
+    angle = -np.arctan2(y0 - y1, x1 - x0)
+    # wrap the angle to [-PI, PI)
+    rotation = angle - TWO_PI * np.floor((angle + PI) / TWO_PI)
+ return np.array([cx, cy, width, height, rotation])
+
+
+class Landmark478:
+ def __init__(self, blaze_face_model_path="", face_mesh_model_path="", device="cuda", **kwargs):
+ if kwargs.get("force_ori_type", False):
+ assert "task_path" in kwargs
+ kwargs["module_name"] = "Landmark478"
+ kwargs["package_name"] = "..aux_models.modules"
+ self.model, self.model_type = load_model("", device=device, **kwargs)
+ else:
+ self.blaze_face = BlazeFace(blaze_face_model_path, device)
+ self.face_mesh = FaceMesh(face_mesh_model_path, device)
+ self.model_type = ""
+
+ def get(self, image):
+ bboxes = self.blaze_face(image)
+ if len(bboxes) == 0:
+ return None
+ bbox = bboxes[0]
+ scale = (image.shape[1] / 128.0, image.shape[0] / 128.0)
+
+        # _decode_boxes returns, per detection:
+        #   [0:4]   box corners (xmin, ymin, xmax, ymax) in 128x128 input pixels
+        #   [4:16]  six facial keypoints as (x, y) pairs in normalized [0, 1]
+        #           coordinates: right eye, left eye, nose, mouth, right ear,
+        #           left ear (labeled from the subject's perspective, so their
+        #           right appears on the left of the image)
+        #   [16]    the confidence score of the detection
+        # Map the box corners and the eye keypoints back to the input image.
+        bbox[0] = bbox[0] * scale[0]  # xmin
+        bbox[1] = bbox[1] * scale[1]  # ymin
+        bbox[2] = bbox[2] * scale[0]  # xmax
+        bbox[3] = bbox[3] * scale[1]  # ymax
+        right_eye = (bbox[4] * image.shape[1], bbox[5] * image.shape[0])
+        left_eye = (bbox[6] * image.shape[1], bbox[7] * image.shape[0])
+
+        roi = bbox_to_roi(
+            bbox,
+            (image.shape[1], image.shape[0]),
+            rotation_keypoints=[right_eye, left_eye],
+            scale=(1.5, 1.5),
+            size_mode=SizeMode.SQUARE_LONG,
+        )
+
+ mesh = self.face_mesh(image, roi)
+ mesh = mesh / (image.shape[1], image.shape[0], image.shape[1])
+ return mesh
+
+    def __call__(self, image):
+        if self.model_type == "ori":
+            det = self.model.detect_from_npimage(image.copy())
+            lmk = self.model.mplmk_to_nplmk(det)
+            return lmk
+        else:
+            lmk = self.get(image)
+            if lmk is None:  # no face detected
+                return None
+            lmk = lmk.reshape(1, -1, 3).astype(np.float32)
+            return lmk
diff --git a/core/aux_models/modules/__init__.py b/core/aux_models/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2542938c6fcffcc75d908fbe3c9d1c547fcadec4
--- /dev/null
+++ b/core/aux_models/modules/__init__.py
@@ -0,0 +1,5 @@
+from .retinaface import RetinaFace
+from .landmark106 import Landmark106
+from .landmark203 import Landmark203
+from .landmark478 import Landmark478
+from .hubert_stream import HubertStreamingONNX
\ No newline at end of file
diff --git a/core/aux_models/modules/hubert_stream.py b/core/aux_models/modules/hubert_stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..04f0060f3a6934fb273c357e3a69b2263f4d67ab
--- /dev/null
+++ b/core/aux_models/modules/hubert_stream.py
@@ -0,0 +1,21 @@
+import onnxruntime
+
+
+class HubertStreamingONNX:
+ def __init__(self, model_file, device="cuda"):
+ if device == "cuda":
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ else:
+ providers = ["CPUExecutionProvider"]
+
+ self.session = onnxruntime.InferenceSession(model_file, providers=providers)
+
+ def forward_chunk(self, input_values):
+ encoding_out = self.session.run(
+ None,
+ {"input_values": input_values.reshape(1, -1)}
+ )[0]
+ return encoding_out
+
+
diff --git a/core/aux_models/modules/landmark106.py b/core/aux_models/modules/landmark106.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc6210dabed2c7493775c3483bb7a6e861130fb8
--- /dev/null
+++ b/core/aux_models/modules/landmark106.py
@@ -0,0 +1,83 @@
+# insightface
+from __future__ import division
+import onnxruntime
+import cv2
+import numpy as np
+from skimage import transform as trans
+
+
+def transform(data, center, output_size, scale, rotation):
+ scale_ratio = scale
+ rot = float(rotation) * np.pi / 180.0
+
+ t1 = trans.SimilarityTransform(scale=scale_ratio)
+ cx = center[0] * scale_ratio
+ cy = center[1] * scale_ratio
+ t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
+ t3 = trans.SimilarityTransform(rotation=rot)
+ t4 = trans.SimilarityTransform(translation=(output_size / 2,
+ output_size / 2))
+ t = t1 + t2 + t3 + t4
+ M = t.params[0:2]
+ cropped = cv2.warpAffine(data,
+ M, (output_size, output_size),
+ borderValue=0.0)
+ return cropped, M
+
+
+def trans_points2d(pts, M):
+ new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
+ for i in range(pts.shape[0]):
+ pt = pts[i]
+ new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
+ new_pt = np.dot(M, new_pt)
+ new_pts[i] = new_pt[0:2]
+
+ return new_pts
+
+
+
+class Landmark106:
+ def __init__(self, model_file, device="cuda"):
+ if device == "cuda":
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ else:
+ providers = ["CPUExecutionProvider"]
+ self.session = onnxruntime.InferenceSession(model_file, providers=providers)
+
+ self.input_mean = 0.0
+ self.input_std = 1.0
+ self.input_size = (192, 192)
+ input_cfg = self.session.get_inputs()[0]
+ input_name = input_cfg.name
+ outputs = self.session.get_outputs()
+ output_names = []
+ for out in outputs:
+ output_names.append(out.name)
+ self.input_name = input_name
+ self.output_names = output_names
+ self.lmk_num = 106
+
+ def get(self, img, bbox):
+ w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
+ center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
+ rotate = 0
+ _scale = self.input_size[0] / (max(w, h)*1.5)
+
+ aimg, M = transform(img, center, self.input_size[0], _scale, rotate)
+ input_size = tuple(aimg.shape[0:2][::-1])
+
+ blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+
+ pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
+
+        pred = pred.reshape((-1, 2))
+        if self.lmk_num < pred.shape[0]:
+            pred = pred[-self.lmk_num:, :]  # keep only the last lmk_num points
+        pred[:, 0:2] += 1
+        pred[:, 0:2] *= (self.input_size[0] // 2)
+
+ IM = cv2.invertAffineTransform(M)
+ pred = trans_points2d(pred, IM)
+ return pred
+
diff --git a/core/aux_models/modules/landmark203.py b/core/aux_models/modules/landmark203.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ff41f1f33e5b155e86113802fa8490ba30eb3eb
--- /dev/null
+++ b/core/aux_models/modules/landmark203.py
@@ -0,0 +1,42 @@
+import onnxruntime
+import numpy as np
+
+
+def _transform_pts(pts, M):
+ """ conduct similarity or affine transformation to the pts
+ pts: Nx2 ndarray
+ M: 2x3 matrix or 3x3 matrix
+ return: Nx2
+ """
+ return pts @ M[:2, :2].T + M[:2, 2]
+
+
+class Landmark203:
+ def __init__(self, model_file, device="cuda"):
+ if device == "cuda":
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ else:
+ providers = ["CPUExecutionProvider"]
+ self.session = onnxruntime.InferenceSession(model_file, providers=providers)
+
+ self.dsize = 224
+
+ def _run(self, inp):
+ out = self.session.run(None, {'input': inp})
+ return out
+
+ def run(self, img_crop_rgb, M_c2o=None):
+ # img_crop_rgb: 224x224
+
+        inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...]  # HxWx3 RGB uint8 -> 1x3xHxW float32 in [0, 1]
+
+ out_lst = self._run(inp)
+ out_pts = out_lst[2]
+
+ # 2d landmarks 203 points
+ lmk = out_pts[0].reshape(-1, 2) * self.dsize # scale to 0-224
+ if M_c2o is not None:
+ lmk = _transform_pts(lmk, M=M_c2o)
+
+ return lmk
+
diff --git a/core/aux_models/modules/landmark478.py b/core/aux_models/modules/landmark478.py
new file mode 100644
index 0000000000000000000000000000000000000000..2614a02b9eda7a09f97646dcb5a1fe218c3f5df0
--- /dev/null
+++ b/core/aux_models/modules/landmark478.py
@@ -0,0 +1,35 @@
+import numpy as np
+import mediapipe as mp
+from mediapipe.tasks.python import vision, BaseOptions
+
+
+class Landmark478:
+ def __init__(self, task_path):
+ base_options = BaseOptions(model_asset_path=task_path)
+ options = vision.FaceLandmarkerOptions(
+ base_options=base_options,
+ output_face_blendshapes=True,
+ output_facial_transformation_matrixes=True,
+ num_faces=1,
+ )
+ detector = vision.FaceLandmarker.create_from_options(options)
+ self.detector = detector
+
+ def detect_from_imp(self, imp):
+ image = mp.Image.create_from_file(imp)
+ detection_result = self.detector.detect(image)
+ return detection_result
+
+ def detect_from_npimage(self, img):
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
+ detection_result = self.detector.detect(image)
+ return detection_result
+
+ @staticmethod
+ def mplmk_to_nplmk(results):
+ face_landmarks_list = results.face_landmarks
+ np_lms = []
+ for face_lms in face_landmarks_list:
+ lms = [[lm.x, lm.y, lm.z] for lm in face_lms]
+ np_lms.append(lms)
+ return np.array(np_lms).astype(np.float32)
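+
+
+# Usage sketch ("face_landmarker.task" is a placeholder for the MediaPipe
+# FaceLandmarker task bundle):
+#   lm478 = Landmark478("face_landmarker.task")
+#   result = lm478.detect_from_npimage(rgb_uint8_image)
+#   lmk = Landmark478.mplmk_to_nplmk(result)  # (num_faces, 478, 3), normalized xyz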
diff --git a/core/aux_models/modules/retinaface.py b/core/aux_models/modules/retinaface.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca5f33300d0f535c1d715e48ee860010f2451320
--- /dev/null
+++ b/core/aux_models/modules/retinaface.py
@@ -0,0 +1,215 @@
+# insightface
+from __future__ import division
+import onnxruntime
+import cv2
+import numpy as np
+
+
+def distance2bbox(points, distance, max_shape=None):
+    """Decode distance predictions to bounding boxes.
+
+    Args:
+        points (ndarray): Shape (n, 2), [x, y].
+        distance (ndarray): Distance from the given point to the 4
+            boundaries (left, top, right, bottom).
+        max_shape (tuple): Shape of the image.
+
+    Returns:
+        ndarray: Decoded bboxes.
+    """
+    x1 = points[:, 0] - distance[:, 0]
+    y1 = points[:, 1] - distance[:, 1]
+    x2 = points[:, 0] + distance[:, 2]
+    y2 = points[:, 1] + distance[:, 3]
+    if max_shape is not None:
+        x1 = np.clip(x1, 0, max_shape[1])
+        y1 = np.clip(y1, 0, max_shape[0])
+        x2 = np.clip(x2, 0, max_shape[1])
+        y2 = np.clip(y2, 0, max_shape[0])
+    return np.stack([x1, y1, x2, y2], axis=-1)
+
+
+def distance2kps(points, distance, max_shape=None):
+    """Decode distance predictions to keypoints.
+
+    Args:
+        points (ndarray): Shape (n, 2), [x, y].
+        distance (ndarray): Distances from the given point to the
+            keypoints, interleaved as (dx, dy) pairs.
+        max_shape (tuple): Shape of the image.
+
+    Returns:
+        ndarray: Decoded keypoints.
+    """
+    preds = []
+    for i in range(0, distance.shape[1], 2):
+        px = points[:, i % 2] + distance[:, i]
+        py = points[:, i % 2 + 1] + distance[:, i + 1]
+        if max_shape is not None:
+            px = np.clip(px, 0, max_shape[1])
+            py = np.clip(py, 0, max_shape[0])
+        preds.append(px)
+        preds.append(py)
+    return np.stack(preds, axis=-1)
+
+
+class RetinaFace:
+ def __init__(self, model_file, device="cuda"):
+ if device == "cuda":
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ else:
+ providers = ["CPUExecutionProvider"]
+ self.session = onnxruntime.InferenceSession(model_file, providers=providers)
+
+ self.center_cache = {}
+ self.nms_thresh = 0.4
+ self.det_thresh = 0.5
+ self._init_vars()
+
+ def _init_vars(self):
+ self.input_size = (512, 512)
+ input_cfg = self.session.get_inputs()[0]
+ input_name = input_cfg.name
+ outputs = self.session.get_outputs()
+ output_names = []
+ for o in outputs:
+ output_names.append(o.name)
+ self.input_name = input_name
+ self.output_names = output_names
+ self.input_mean = 127.5
+ self.input_std = 128.0
+ self._anchor_ratio = 1.0
+ self.fmc = 3
+ self._feat_stride_fpn = [8, 16, 32]
+ self._num_anchors = 2
+ self.use_kps = True
+
+ def forward(self, img, threshold):
+ scores_list = []
+ bboxes_list = []
+ kpss_list = []
+ input_size = tuple(img.shape[0:2][::-1])
+ blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ net_outs = self.session.run(self.output_names, {self.input_name : blob})
+
+ input_height = blob.shape[2]
+ input_width = blob.shape[3]
+ fmc = self.fmc
+ for idx, stride in enumerate(self._feat_stride_fpn):
+ scores = net_outs[idx]
+ bbox_preds = net_outs[idx+fmc]
+ bbox_preds = bbox_preds * stride
+ if self.use_kps:
+ kps_preds = net_outs[idx+fmc*2] * stride
+ height = input_height // stride
+ width = input_width // stride
+ # K = height * width
+ key = (height, width, stride)
+ if key in self.center_cache:
+ anchor_centers = self.center_cache[key]
+ else:
+ #solution-3:
+ anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
+ anchor_centers = (anchor_centers * stride).reshape( (-1, 2) )
+ if self._num_anchors>1:
+ anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) )
+ if len(self.center_cache)<100:
+ self.center_cache[key] = anchor_centers
+
+ pos_inds = np.where(scores>=threshold)[0]
+ bboxes = distance2bbox(anchor_centers, bbox_preds)
+ pos_scores = scores[pos_inds]
+ pos_bboxes = bboxes[pos_inds]
+ scores_list.append(pos_scores)
+ bboxes_list.append(pos_bboxes)
+ if self.use_kps:
+ kpss = distance2kps(anchor_centers, kps_preds)
+ kpss = kpss.reshape( (kpss.shape[0], -1, 2) )
+ pos_kpss = kpss[pos_inds]
+ kpss_list.append(pos_kpss)
+ return scores_list, bboxes_list, kpss_list
+
+
+ def detect(self, img, input_size=None, max_num=0, metric='default', det_thresh=None):
+ input_size = self.input_size if input_size is None else input_size
+ det_thresh = self.det_thresh if det_thresh is None else det_thresh
+
+ im_ratio = float(img.shape[0]) / img.shape[1]
+ model_ratio = float(input_size[1]) / input_size[0]
+ if im_ratio>model_ratio:
+ new_height = input_size[1]
+ new_width = int(new_height / im_ratio)
+ else:
+ new_width = input_size[0]
+ new_height = int(new_width * im_ratio)
+ det_scale = float(new_height) / img.shape[0]
+ resized_img = cv2.resize(img, (new_width, new_height))
+ det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 )
+ det_img[:new_height, :new_width, :] = resized_img
+
+ scores_list, bboxes_list, kpss_list = self.forward(det_img, det_thresh)
+
+ scores = np.vstack(scores_list)
+ scores_ravel = scores.ravel()
+ order = scores_ravel.argsort()[::-1]
+ bboxes = np.vstack(bboxes_list) / det_scale
+ if self.use_kps:
+ kpss = np.vstack(kpss_list) / det_scale
+ pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
+ pre_det = pre_det[order, :]
+ keep = self.nms(pre_det)
+ det = pre_det[keep, :]
+ if self.use_kps:
+ kpss = kpss[order,:,:]
+ kpss = kpss[keep,:,:]
+ else:
+ kpss = None
+ if max_num > 0 and det.shape[0] > max_num:
+ area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
+ img_center = img.shape[0] // 2, img.shape[1] // 2
+ offsets = np.vstack([
+ (det[:, 0] + det[:, 2]) / 2 - img_center[1],
+ (det[:, 1] + det[:, 3]) / 2 - img_center[0]
+ ])
+ offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
+ if metric=='max':
+ values = area
+ else:
+                values = area - offset_dist_squared * 2.0  # extra weight on centering
+            bindex = np.argsort(values)[::-1]  # rank faces, largest / most central first
+ bindex = bindex[0:max_num]
+ det = det[bindex, :]
+ if kpss is not None:
+ kpss = kpss[bindex, :]
+ return det, kpss
+
+ def nms(self, dets):
+ thresh = self.nms_thresh
+ x1 = dets[:, 0]
+ y1 = dets[:, 1]
+ x2 = dets[:, 2]
+ y2 = dets[:, 3]
+ scores = dets[:, 4]
+
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+ inds = np.where(ovr <= thresh)[0]
+ order = order[inds + 1]
+
+ return keep
+
diff --git a/core/models/appearance_extractor.py b/core/models/appearance_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..906ca1211b936c3a3e5bc94ec1da76dd8334d0fe
--- /dev/null
+++ b/core/models/appearance_extractor.py
@@ -0,0 +1,29 @@
+import numpy as np
+import torch
+from ..utils.load_model import load_model
+
+
+class AppearanceExtractor:
+ def __init__(self, model_path, device="cuda"):
+ kwargs = {
+ "module_name": "AppearanceFeatureExtractor",
+ }
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
+ self.device = device
+
+ def __call__(self, image):
+ """
+ image: np.ndarray, shape (1, 3, 256, 256), float32, range [0, 1]
+ """
+ if self.model_type == "onnx":
+ pred = self.model.run(None, {"image": image})[0]
+ elif self.model_type == "tensorrt":
+ self.model.setup({"image": image})
+ self.model.infer()
+ pred = self.model.buffer["pred"][0].copy()
+ elif self.model_type == 'pytorch':
+ with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True):
+ pred = self.model(torch.from_numpy(image).to(self.device)).float().cpu().numpy()
+ else:
+ raise ValueError(f"Unsupported model type: {self.model_type}")
+ return pred
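+
+
+# Usage sketch (the model path is a placeholder):
+#   extractor = AppearanceExtractor("appearance_extractor.onnx", device="cuda")
+#   f_s = extractor(rgb_256_bchw)  # rgb_256_bchw: (1, 3, 256, 256) float32 in [0, 1]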
diff --git a/core/models/decoder.py b/core/models/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..54f6336bdd3e4cf02410798274d365f38ec5c941
--- /dev/null
+++ b/core/models/decoder.py
@@ -0,0 +1,30 @@
+import numpy as np
+import torch
+from ..utils.load_model import load_model
+
+
+class Decoder:
+ def __init__(self, model_path, device="cuda"):
+ kwargs = {
+ "module_name": "SPADEDecoder",
+ }
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
+ self.device = device
+
+ def __call__(self, feature):
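+        # feature: np.ndarray, typically the warp-network output volume;
+        # returns an RGB frame as (h, w, 3) float32 in [0, 255]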
+
+ if self.model_type == "onnx":
+ pred = self.model.run(None, {"feature": feature})[0]
+ elif self.model_type == "tensorrt":
+ self.model.setup({"feature": feature})
+ self.model.infer()
+ pred = self.model.buffer["output"][0].copy()
+ elif self.model_type == 'pytorch':
+ with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True):
+ pred = self.model(torch.from_numpy(feature).to(self.device)).float().cpu().numpy()
+ else:
+ raise ValueError(f"Unsupported model type: {self.model_type}")
+
+ pred = np.transpose(pred[0], [1, 2, 0]).clip(0, 1) * 255 # [h, w, c]
+
+ return pred
diff --git a/core/models/lmdm.py b/core/models/lmdm.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6467fa8f77e072165c461a16afa3525d8cebd27
--- /dev/null
+++ b/core/models/lmdm.py
@@ -0,0 +1,140 @@
+import numpy as np
+import torch
+from ..utils.load_model import load_model
+
+
+def make_beta(n_timestep, cosine_s=8e-3):
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+ return betas.numpy()
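+
+# The cosine schedule above follows Nichol & Dhariwal (2021):
+#   alpha_bar(t) = cos^2(((t/T + s) / (1 + s)) * (pi/2)), normalized by alpha_bar(0)
+#   beta_t = clip(1 - alpha_bar(t) / alpha_bar(t-1), 0, 0.999)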
+
+
+class LMDM:
+ def __init__(self, model_path, device="cuda", **kwargs):
+ kwargs["module_name"] = "LMDM"
+
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
+ self.device = device
+
+ self.motion_feat_dim = kwargs.get("motion_feat_dim", 265)
+ self.audio_feat_dim = kwargs.get("audio_feat_dim", 1024+35)
+ self.seq_frames = kwargs.get("seq_frames", 80)
+
+ if self.model_type == "pytorch":
+ pass
+ else:
+ self._init_np()
+
+ def setup(self, sampling_timesteps):
+ if self.model_type == "pytorch":
+ self.model.setup(sampling_timesteps)
+ else:
+ self._setup_np(sampling_timesteps)
+
+ def _init_np(self):
+ self.sampling_timesteps = None
+ self.n_timestep = 1000
+
+ betas = torch.Tensor(make_beta(n_timestep=self.n_timestep))
+ alphas = 1.0 - betas
+ self.alphas_cumprod = torch.cumprod(alphas, axis=0).cpu().numpy()
+
+ def _setup_np(self, sampling_timesteps=50):
+ if self.sampling_timesteps == sampling_timesteps:
+ return
+
+ self.sampling_timesteps = sampling_timesteps
+
+ total_timesteps = self.n_timestep
+ eta = 1
+ shape = (1, self.seq_frames, self.motion_feat_dim)
+
+ times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
+ times = list(reversed(times.int().tolist()))
+ self.time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
+
+ self.time_cond_list = []
+ self.alpha_next_sqrt_list = []
+ self.sigma_list = []
+ self.c_list = []
+ self.noise_list = []
+
+ for time, time_next in self.time_pairs:
+ time_cond = np.full((1,), time, dtype=np.int64)
+ self.time_cond_list.append(time_cond)
+ if time_next < 0:
+ continue
+
+ alpha = self.alphas_cumprod[time]
+ alpha_next = self.alphas_cumprod[time_next]
+
+ sigma = eta * np.sqrt((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha))
+ c = np.sqrt(1 - alpha_next - sigma ** 2)
+ noise = np.random.randn(*shape).astype(np.float32)
+
+ self.alpha_next_sqrt_list.append(np.sqrt(alpha_next))
+ self.sigma_list.append(sigma)
+ self.c_list.append(c)
+ self.noise_list.append(noise)
+
+ def _one_step(self, x, cond_frame, cond, time_cond):
+ if self.model_type == "onnx":
+ pred = self.model.run(None, {"x": x, "cond_frame": cond_frame, "cond": cond, "time_cond": time_cond})
+ pred_noise, x_start = pred[0], pred[1]
+ elif self.model_type == "tensorrt":
+ self.model.setup({"x": x, "cond_frame": cond_frame, "cond": cond, "time_cond": time_cond})
+ self.model.infer()
+ pred_noise, x_start = self.model.buffer["pred_noise"][0], self.model.buffer["x_start"][0]
+ elif self.model_type == "pytorch":
+ with torch.no_grad():
+ pred_noise, x_start = self.model(x, cond_frame, cond, time_cond)
+ else:
+ raise ValueError(f"Unsupported model type: {self.model_type}")
+
+ return pred_noise, x_start
+
+ def _call_np(self, kp_cond, aud_cond, sampling_timesteps):
+ self._setup_np(sampling_timesteps)
+
+ cond_frame = kp_cond
+ cond = aud_cond
+
+ x = np.random.randn(1, self.seq_frames, self.motion_feat_dim).astype(np.float32)
+
+ x_start = None
+ i = 0
+ for _, time_next in self.time_pairs:
+ time_cond = self.time_cond_list[i]
+ pred_noise, x_start = self._one_step(x, cond_frame, cond, time_cond)
+ if time_next < 0:
+ x = x_start
+ continue
+
+ alpha_next_sqrt = self.alpha_next_sqrt_list[i]
+ c = self.c_list[i]
+ sigma = self.sigma_list[i]
+ noise = self.noise_list[i]
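+            # DDIM step (eta = 1): x <- sqrt(alpha_next) * x0_hat
+            #   + sqrt(1 - alpha_next - sigma^2) * eps_hat + sigma * z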
+ x = x_start * alpha_next_sqrt + c * pred_noise + sigma * noise
+
+ i += 1
+
+ return x
+
+ def __call__(self, kp_cond, aud_cond, sampling_timesteps):
+ if self.model_type == "pytorch":
+ pred_kp_seq = self.model.ddim_sample(
+ torch.from_numpy(kp_cond).to(self.device),
+ torch.from_numpy(aud_cond).to(self.device),
+ sampling_timesteps,
+ ).cpu().numpy()
+ else:
+ pred_kp_seq = self._call_np(kp_cond, aud_cond, sampling_timesteps)
+ return pred_kp_seq
+
+
diff --git a/core/models/modules/LMDM.py b/core/models/modules/LMDM.py
new file mode 100644
index 0000000000000000000000000000000000000000..323ff25c61e85eb6344b75ba19338eb934853613
--- /dev/null
+++ b/core/models/modules/LMDM.py
@@ -0,0 +1,154 @@
+# Latent Motion Diffusion Model
+import torch
+import torch.nn as nn
+from .lmdm_modules.model import MotionDecoder
+from .lmdm_modules.utils import extract, make_beta_schedule
+
+
+class LMDM(nn.Module):
+ def __init__(
+ self,
+ motion_feat_dim=265,
+ audio_feat_dim=1024+35,
+ seq_frames=80,
+ checkpoint='',
+ device='cuda',
+ clip_denoised=False, # clip denoised (-1,1)
+ multi_cond_frame=False,
+ ):
+ super().__init__()
+
+ self.motion_feat_dim = motion_feat_dim
+ self.audio_feat_dim = audio_feat_dim
+ self.seq_frames = seq_frames
+ self.device = device
+
+ self.n_timestep = 1000
+ self.clip_denoised = clip_denoised
+ self.guidance_weight = 2
+
+ self.model = MotionDecoder(
+ nfeats=motion_feat_dim,
+ seq_len=seq_frames,
+ latent_dim=512,
+ ff_size=1024,
+ num_layers=8,
+ num_heads=8,
+ dropout=0.1,
+ cond_feature_dim=audio_feat_dim,
+ multi_cond_frame=multi_cond_frame,
+ )
+
+ self.init_diff()
+
+ self.sampling_timesteps = None
+
+ def init_diff(self):
+ n_timestep = self.n_timestep
+ betas = torch.Tensor(
+ make_beta_schedule(schedule="cosine", n_timestep=n_timestep)
+ )
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, axis=0)
+
+ self.register_buffer("alphas_cumprod", alphas_cumprod)
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
+ )
+ self.register_buffer("sqrt_recip1m_alphas_cumprod", torch.sqrt(1.0 / (1.0 - alphas_cumprod)))
+
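+    # From the forward process x_t = sqrt(a_t) * x0 + sqrt(1 - a_t) * eps, with
+    # a_t = alphas_cumprod[t], solving for eps gives
+    #   eps = x_t / sqrt(1 - a_t) - x0 / sqrt(1/a_t - 1)
+    # i.e. exactly a * x_t - x0 / b with the two buffers registered above.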
+ def predict_noise_from_start(self, x_t, t, x0):
+ a = extract(self.sqrt_recip1m_alphas_cumprod, t, x_t.shape)
+ b = extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+ return (a * x_t - x0 / b)
+
+ def maybe_clip(self, x):
+ if self.clip_denoised:
+ return torch.clamp(x, min=-1., max=1.)
+ else:
+ return x
+
+ def model_predictions(self, x, cond_frame, cond, t):
+ weight = self.guidance_weight
+ x_start = self.model.guided_forward(x, cond_frame, cond, t, weight)
+ x_start = self.maybe_clip(x_start)
+ pred_noise = self.predict_noise_from_start(x, t, x_start)
+ return pred_noise, x_start
+
+ @torch.no_grad()
+ def forward(self, x, cond_frame, cond, time_cond):
+ pred_noise, x_start = self.model_predictions(x, cond_frame, cond, time_cond)
+ return pred_noise, x_start
+
+ def load_model(self, ckpt_path):
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
+ self.model.load_state_dict(checkpoint["model_state_dict"])
+ self.eval()
+ return self
+
+ def setup(self, sampling_timesteps=50):
+ if self.sampling_timesteps == sampling_timesteps:
+ return
+
+ self.sampling_timesteps = sampling_timesteps
+
+ total_timesteps = self.n_timestep
+ device = self.device
+ eta = 1
+ shape = (1, self.seq_frames, self.motion_feat_dim)
+
+ times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
+ times = list(reversed(times.int().tolist()))
+ self.time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
+
+ self.time_cond_list = []
+ self.alpha_next_sqrt_list = []
+ self.sigma_list = []
+ self.c_list = []
+ self.noise_list = []
+
+ for time, time_next in self.time_pairs:
+ time_cond = torch.full((1,), time, device=device, dtype=torch.long)
+ self.time_cond_list.append(time_cond)
+ if time_next < 0:
+ continue
+ alpha = self.alphas_cumprod[time]
+ alpha_next = self.alphas_cumprod[time_next]
+
+ sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+ c = (1 - alpha_next - sigma ** 2).sqrt()
+ noise = torch.randn(shape, device=device)
+
+ self.alpha_next_sqrt_list.append(alpha_next.sqrt())
+ self.sigma_list.append(sigma)
+ self.c_list.append(c)
+ self.noise_list.append(noise)
+
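+    # ddim_sample consumes the lists precomputed here: one reverse DDIM step per
+    # time pair; the final (0, -1) pair simply returns the predicted x0.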
+ @torch.no_grad()
+ def ddim_sample(self, kp_cond, aud_cond, sampling_timesteps):
+ self.setup(sampling_timesteps)
+
+ cond_frame = kp_cond
+ cond = aud_cond
+
+ shape = (1, self.seq_frames, self.motion_feat_dim)
+ x = torch.randn(shape, device=self.device)
+
+ x_start = None
+ i = 0
+ for _, time_next in self.time_pairs:
+ time_cond = self.time_cond_list[i]
+ pred_noise, x_start = self.model_predictions(x, cond_frame, cond, time_cond)
+ if time_next < 0:
+ x = x_start
+ continue
+
+ alpha_next_sqrt = self.alpha_next_sqrt_list[i]
+ c = self.c_list[i]
+ sigma = self.sigma_list[i]
+ noise = self.noise_list[i]
+ x = x_start * alpha_next_sqrt + c * pred_noise + sigma * noise
+
+ i += 1
+ return x # pred_kp_seq
+
diff --git a/core/models/modules/__init__.py b/core/models/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..925672f5f1221820a5dbb483af1f8ea6c492e613
--- /dev/null
+++ b/core/models/modules/__init__.py
@@ -0,0 +1,6 @@
+from .appearance_feature_extractor import AppearanceFeatureExtractor
+from .motion_extractor import MotionExtractor
+from .warping_network import WarpingNetwork
+from .spade_generator import SPADEDecoder
+from .stitching_network import StitchingNetwork
+from .LMDM import LMDM
\ No newline at end of file
diff --git a/core/models/modules/appearance_feature_extractor.py b/core/models/modules/appearance_feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc5c1a72a613762d11a04fe9316859a78e591fec
--- /dev/null
+++ b/core/models/modules/appearance_feature_extractor.py
@@ -0,0 +1,74 @@
+# coding: utf-8
+
+"""
+Appearance extractor (F) defined in the paper, which maps the source image s to a 3D appearance feature volume.
+"""
+
+import torch
+from torch import nn
+from .util import SameBlock2d, DownBlock2d, ResBlock3d
+
+
+class AppearanceFeatureExtractor(nn.Module):
+
+ def __init__(
+ self,
+ image_channel=3,
+ block_expansion=64,
+ num_down_blocks=2,
+ max_features=512,
+ reshape_channel=32,
+ reshape_depth=16,
+ num_resblocks=6,
+ ):
+ super(AppearanceFeatureExtractor, self).__init__()
+ self.image_channel = image_channel
+ self.block_expansion = block_expansion
+ self.num_down_blocks = num_down_blocks
+ self.max_features = max_features
+ self.reshape_channel = reshape_channel
+ self.reshape_depth = reshape_depth
+
+ self.first = SameBlock2d(
+ image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)
+ )
+
+ down_blocks = []
+ for i in range(num_down_blocks):
+ in_features = min(max_features, block_expansion * (2**i))
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+ down_blocks.append(
+ DownBlock2d(
+ in_features, out_features, kernel_size=(3, 3), padding=(1, 1)
+ )
+ )
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ self.second = nn.Conv2d(
+ in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1
+ )
+
+ self.resblocks_3d = torch.nn.Sequential()
+ for i in range(num_resblocks):
+ self.resblocks_3d.add_module(
+ "3dr" + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)
+ )
+
+ def forward(self, source_image):
+ out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256
+
+ for i in range(len(self.down_blocks)):
+ out = self.down_blocks[i](out)
+ out = self.second(out)
+ bs, c, h, w = out.shape # ->Bx512x64x64
+
+ f_s = out.view(
+ bs, self.reshape_channel, self.reshape_depth, h, w
+ ) # ->Bx32x16x64x64
+ f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64
+ return f_s
+
+ def load_model(self, ckpt_path):
+ self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
+ self.eval()
+ return self
diff --git a/core/models/modules/convnextv2.py b/core/models/modules/convnextv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..1004d6ef08cb0d1c4b6a601f361f83a60833efee
--- /dev/null
+++ b/core/models/modules/convnextv2.py
@@ -0,0 +1,150 @@
+# coding: utf-8
+
+"""
+This module adapts ConvNeXtV2 for the extraction of implicit keypoints, poses, and expression deformations.
+"""
+
+import torch
+import torch.nn as nn
+# from timm.models.layers import trunc_normal_, DropPath
+from .util import LayerNorm, DropPath, trunc_normal_, GRN
+
+__all__ = ['convnextv2_tiny']
+
+
+class Block(nn.Module):
+ """ ConvNeXtV2 Block.
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ """
+
+ def __init__(self, dim, drop_path=0.):
+ super().__init__()
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.grn = GRN(4 * dim)
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x):
+ input = x
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.grn(x)
+ x = self.pwconv2(x)
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+
+class ConvNeXtV2(nn.Module):
+ """ ConvNeXt V2
+
+ Args:
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+ """
+
+ def __init__(
+ self,
+ in_chans=3,
+ depths=[3, 3, 9, 3],
+ dims=[96, 192, 384, 768],
+ drop_path_rate=0.,
+ **kwargs
+ ):
+ super().__init__()
+ self.depths = depths
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
+ stem = nn.Sequential(
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
+ )
+ self.downsample_layers.append(stem)
+ for i in range(3):
+ downsample_layer = nn.Sequential(
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
+ )
+ self.downsample_layers.append(downsample_layer)
+
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+ cur = 0
+ for i in range(4):
+ stage = nn.Sequential(
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
+ )
+ self.stages.append(stage)
+ cur += depths[i]
+
+ self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
+
+ # NOTE: the output semantic items
+ num_bins = kwargs.get('num_bins', 66)
+ num_kp = kwargs.get('num_kp', 24) # the number of implicit keypoints
+ self.fc_kp = nn.Linear(dims[-1], 3 * num_kp) # implicit keypoints
+
+ self.fc_scale = nn.Linear(dims[-1], 1) # scale
+ self.fc_pitch = nn.Linear(dims[-1], num_bins) # pitch bins
+ self.fc_yaw = nn.Linear(dims[-1], num_bins) # yaw bins
+ self.fc_roll = nn.Linear(dims[-1], num_bins) # roll bins
+ self.fc_t = nn.Linear(dims[-1], 3) # translation
+ self.fc_exp = nn.Linear(dims[-1], 3 * num_kp) # expression / delta
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ trunc_normal_(m.weight, std=.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward_features(self, x):
+ for i in range(4):
+ x = self.downsample_layers[i](x)
+ x = self.stages[i](x)
+ return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
+
+ def forward(self, x):
+ x = self.forward_features(x)
+
+ # implicit keypoints
+ kp = self.fc_kp(x)
+
+ # pose and expression deformation
+ pitch = self.fc_pitch(x)
+ yaw = self.fc_yaw(x)
+ roll = self.fc_roll(x)
+ t = self.fc_t(x)
+ exp = self.fc_exp(x)
+ scale = self.fc_scale(x)
+
+ # ret_dct = {
+ # 'pitch': pitch,
+ # 'yaw': yaw,
+ # 'roll': roll,
+ # 't': t,
+ # 'exp': exp,
+ # 'scale': scale,
+
+ # 'kp': kp, # canonical keypoint
+ # }
+
+ # return ret_dct
+ return pitch, yaw, roll, t, exp, scale, kp
+
+
+def convnextv2_tiny(**kwargs):
+ model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+ return model
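+
+
+if __name__ == "__main__":
+    # Illustrative smoke test (not used by the pipeline): run the tiny backbone
+    # on a dummy 256x256 image and check the head output shapes.
+    model = convnextv2_tiny(num_kp=21)
+    pitch, yaw, roll, t, exp, scale, kp = model(torch.randn(1, 3, 256, 256))
+    print(kp.shape, exp.shape, pitch.shape)  # (1, 63) (1, 63) (1, 66)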
diff --git a/core/models/modules/dense_motion.py b/core/models/modules/dense_motion.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1a7f9a5a1b5463c4a1ddcee2bf36dcd34735706
--- /dev/null
+++ b/core/models/modules/dense_motion.py
@@ -0,0 +1,104 @@
+# coding: utf-8
+
+"""
+The module that predicts a dense motion field from the sparse motion representation given by kp_source and kp_driving.
+"""
+
+from torch import nn
+import torch.nn.functional as F
+import torch
+from .util import Hourglass, make_coordinate_grid, kp2gaussian
+
+
+class DenseMotionNetwork(nn.Module):
+ def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True):
+ super(DenseMotionNetwork, self).__init__()
+ self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) # ~60+G
+
+ self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) # 65G! NOTE: computation cost is large
+ self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) # 0.8G
+ self.norm = nn.BatchNorm3d(compress, affine=True)
+ self.num_kp = num_kp
+ self.flag_estimate_occlusion_map = estimate_occlusion_map
+
+ if self.flag_estimate_occlusion_map:
+ self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
+ else:
+ self.occlusion = None
+
+ def create_sparse_motions(self, feature, kp_driving, kp_source):
+ bs, _, d, h, w = feature.shape # (bs, 4, 16, 64, 64)
+ identity_grid = make_coordinate_grid((d, h, w), ref=kp_source) # (16, 64, 64, 3)
+ identity_grid = identity_grid.view(1, 1, d, h, w, 3) # (1, 1, d=16, h=64, w=64, 3)
+ coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3)
+
+ k = coordinate_grid.shape[1]
+
+        # NOTE: a first-order flow term is missing here
+ driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
+
+ # adding background feature
+ identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
+ sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) # (bs, 1+num_kp, d, h, w, 3)
+ return sparse_motions
+
+ def create_deformed_feature(self, feature, sparse_motions):
+ bs, _, d, h, w = feature.shape
+ feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
+ feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
+ sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3)
+ sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False)
+ sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
+
+ return sparse_deformed
+
+ def create_heatmap_representations(self, feature, kp_driving, kp_source):
+ spatial_size = feature.shape[3:] # (d=16, h=64, w=64)
+ gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
+ gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
+ heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w)
+
+ # adding background feature
+ zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.dtype).to(heatmap.device)
+ heatmap = torch.cat([zeros, heatmap], dim=1)
+ heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w)
+ return heatmap
+
+ def forward(self, feature, kp_driving, kp_source):
+ bs, _, d, h, w = feature.shape # (bs, 32, 16, 64, 64)
+
+ feature = self.compress(feature) # (bs, 4, 16, 64, 64)
+ feature = self.norm(feature) # (bs, 4, 16, 64, 64)
+ feature = F.relu(feature) # (bs, 4, 16, 64, 64)
+
+ out_dict = dict()
+
+ # 1. deform 3d feature
+ sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) # (bs, 1+num_kp, d, h, w, 3)
+ deformed_feature = self.create_deformed_feature(feature, sparse_motion) # (bs, 1+num_kp, c=4, d=16, h=64, w=64)
+
+ # 2. (bs, 1+num_kp, d, h, w)
+ heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) # (bs, 1+num_kp, 1, d, h, w)
+
+ input = torch.cat([heatmap, deformed_feature], dim=2) # (bs, 1+num_kp, c=5, d=16, h=64, w=64)
+ input = input.view(bs, -1, d, h, w) # (bs, (1+num_kp)*c=105, d=16, h=64, w=64)
+
+ prediction = self.hourglass(input)
+
+ mask = self.mask(prediction)
+ mask = F.softmax(mask, dim=1) # (bs, 1+num_kp, d=16, h=64, w=64)
+ out_dict['mask'] = mask
+ mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
+ sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
+ deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) mask take effect in this place
+ deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
+
+ out_dict['deformation'] = deformation
+
+ if self.flag_estimate_occlusion_map:
+ bs, _, d, h, w = prediction.shape
+ prediction_reshape = prediction.view(bs, -1, h, w)
+ occlusion_map = torch.sigmoid(self.occlusion(prediction_reshape)) # Bx1x64x64
+ out_dict['occlusion_map'] = occlusion_map
+
+ return out_dict
diff --git a/core/models/modules/lmdm_modules/model.py b/core/models/modules/lmdm_modules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ad7eb6d77074c931ac986d3a81bdf8bcf5f48c0
--- /dev/null
+++ b/core/models/modules/lmdm_modules/model.py
@@ -0,0 +1,398 @@
+from typing import Callable, Optional, Union
+import torch
+import torch.nn as nn
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from torch import Tensor
+from torch.nn import functional as F
+
+from .rotary_embedding_torch import RotaryEmbedding
+from .utils import PositionalEncoding, SinusoidalPosEmb, prob_mask_like
+
+
+class DenseFiLM(nn.Module):
+ """Feature-wise linear modulation (FiLM) generator."""
+
+ def __init__(self, embed_channels):
+ super().__init__()
+ self.embed_channels = embed_channels
+ self.block = nn.Sequential(
+ nn.Mish(), nn.Linear(embed_channels, embed_channels * 2)
+ )
+
+ def forward(self, position):
+ pos_encoding = self.block(position)
+ pos_encoding = rearrange(pos_encoding, "b c -> b 1 c")
+ scale_shift = pos_encoding.chunk(2, dim=-1)
+ return scale_shift
+
+
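+# FiLM modulation: y = (1 + scale) * x + shift, where scale/shift have shape
+# (b, 1, c) and broadcast over the time dimension.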
+def featurewise_affine(x, scale_shift):
+ scale, shift = scale_shift
+ return (scale + 1) * x + shift
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ dim_feedforward: int = 2048,
+ dropout: float = 0.1,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ layer_norm_eps: float = 1e-5,
+ batch_first: bool = False,
+ norm_first: bool = True,
+ device=None,
+ dtype=None,
+ rotary=None,
+ ) -> None:
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(
+ d_model, nhead, dropout=dropout, batch_first=batch_first
+ )
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm_first = norm_first
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.activation = activation
+
+ self.rotary = rotary
+ self.use_rotary = rotary is not None
+
+ def forward(
+ self,
+ src: Tensor,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ x = src
+ if self.norm_first:
+ x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
+ x = x + self._ff_block(self.norm2(x))
+ else:
+ x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
+ x = self.norm2(x + self._ff_block(x))
+
+ return x
+
+ # self-attention block
+ def _sa_block(
+ self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
+ ) -> Tensor:
+ qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
+ x = self.self_attn(
+ qk,
+ qk,
+ x,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=False,
+ )[0]
+ return self.dropout1(x)
+
+ # feed forward block
+ def _ff_block(self, x: Tensor) -> Tensor:
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+ return self.dropout2(x)
+
+
+class FiLMTransformerDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation=F.relu,
+ layer_norm_eps=1e-5,
+ batch_first=False,
+ norm_first=True,
+ device=None,
+ dtype=None,
+ rotary=None,
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(
+ d_model, nhead, dropout=dropout, batch_first=batch_first
+ )
+ self.multihead_attn = nn.MultiheadAttention(
+ d_model, nhead, dropout=dropout, batch_first=batch_first
+ )
+ # Feedforward
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm_first = norm_first
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+ self.activation = activation
+
+ self.film1 = DenseFiLM(d_model)
+ self.film2 = DenseFiLM(d_model)
+ self.film3 = DenseFiLM(d_model)
+
+ self.rotary = rotary
+ self.use_rotary = rotary is not None
+
+ # x, cond, t
+ def forward(
+ self,
+ tgt,
+ memory,
+ t,
+ tgt_mask=None,
+ memory_mask=None,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=None,
+ ):
+ x = tgt
+ if self.norm_first:
+ # self-attention -> film -> residual
+ x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
+ x = x + featurewise_affine(x_1, self.film1(t))
+ # cross-attention -> film -> residual
+ x_2 = self._mha_block(
+ self.norm2(x), memory, memory_mask, memory_key_padding_mask
+ )
+ x = x + featurewise_affine(x_2, self.film2(t))
+ # feedforward -> film -> residual
+ x_3 = self._ff_block(self.norm3(x))
+ x = x + featurewise_affine(x_3, self.film3(t))
+ else:
+ x = self.norm1(
+ x
+ + featurewise_affine(
+ self._sa_block(x, tgt_mask, tgt_key_padding_mask), self.film1(t)
+ )
+ )
+ x = self.norm2(
+ x
+ + featurewise_affine(
+ self._mha_block(x, memory, memory_mask, memory_key_padding_mask),
+ self.film2(t),
+ )
+ )
+ x = self.norm3(x + featurewise_affine(self._ff_block(x), self.film3(t)))
+ return x
+
+ # self-attention block
+ # qkv
+ def _sa_block(self, x, attn_mask, key_padding_mask):
+ qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
+ x = self.self_attn(
+ qk,
+ qk,
+ x,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=False,
+ )[0]
+ return self.dropout1(x)
+
+ # multihead attention block
+ # qkv
+ def _mha_block(self, x, mem, attn_mask, key_padding_mask):
+ q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
+ k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
+ x = self.multihead_attn(
+ q,
+ k,
+ mem,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=False,
+ )[0]
+ return self.dropout2(x)
+
+ # feed forward block
+ def _ff_block(self, x):
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+ return self.dropout3(x)
+
+
+class DecoderLayerStack(nn.Module):
+ def __init__(self, stack):
+ super().__init__()
+ self.stack = stack
+
+ def forward(self, x, cond, t):
+ for layer in self.stack:
+ x = layer(x, cond, t)
+ return x
+
+
+class MotionDecoder(nn.Module):
+ def __init__(
+ self,
+ nfeats: int,
+ seq_len: int = 100, # 4 seconds, 25 fps
+ latent_dim: int = 256,
+ ff_size: int = 1024,
+ num_layers: int = 4,
+ num_heads: int = 4,
+ dropout: float = 0.1,
+ cond_feature_dim: int = 4800,
+ activation: Callable[[Tensor], Tensor] = F.gelu,
+ use_rotary=True,
+ multi_cond_frame=False,
+ **kwargs
+ ) -> None:
+
+ super().__init__()
+
+ self.multi_cond_frame = multi_cond_frame
+
+ output_feats = nfeats
+
+ # positional embeddings
+ self.rotary = None
+ self.abs_pos_encoding = nn.Identity()
+ # if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity)
+ if use_rotary:
+ self.rotary = RotaryEmbedding(dim=latent_dim)
+ else:
+ self.abs_pos_encoding = PositionalEncoding(
+ latent_dim, dropout, batch_first=True
+ )
+
+ # time embedding processing
+ self.time_mlp = nn.Sequential(
+            SinusoidalPosEmb(latent_dim),  # fixed (non-learned) sinusoidal embedding
+ nn.Linear(latent_dim, latent_dim * 4),
+ nn.Mish(),
+ )
+
+ self.to_time_cond = nn.Sequential(nn.Linear(latent_dim * 4, latent_dim),)
+
+ self.to_time_tokens = nn.Sequential(
+ nn.Linear(latent_dim * 4, latent_dim * 2), # 2 time tokens
+ Rearrange("b (r d) -> b r d", r=2),
+ )
+
+ # null embeddings for guidance dropout
+ self.null_cond_embed = nn.Parameter(torch.randn(1, seq_len, latent_dim))
+ self.null_cond_hidden = nn.Parameter(torch.randn(1, latent_dim))
+
+ self.norm_cond = nn.LayerNorm(latent_dim)
+
+ # input projection
+ if self.multi_cond_frame:
+ self.input_projection = nn.Linear(nfeats * 2 + 1, latent_dim)
+ else:
+ self.input_projection = nn.Linear(nfeats * 2, latent_dim)
+ self.cond_encoder = nn.Sequential()
+ for _ in range(2):
+ self.cond_encoder.append(
+ TransformerEncoderLayer(
+ d_model=latent_dim,
+ nhead=num_heads,
+ dim_feedforward=ff_size,
+ dropout=dropout,
+ activation=activation,
+ batch_first=True,
+ rotary=self.rotary,
+ )
+ )
+ # conditional projection
+ self.cond_projection = nn.Linear(cond_feature_dim, latent_dim)
+ self.non_attn_cond_projection = nn.Sequential(
+ nn.LayerNorm(latent_dim),
+ nn.Linear(latent_dim, latent_dim),
+ nn.SiLU(),
+ nn.Linear(latent_dim, latent_dim),
+ )
+ # decoder
+ decoderstack = nn.ModuleList([])
+ for _ in range(num_layers):
+ decoderstack.append(
+ FiLMTransformerDecoderLayer(
+ latent_dim,
+ num_heads,
+ dim_feedforward=ff_size,
+ dropout=dropout,
+ activation=activation,
+ batch_first=True,
+ rotary=self.rotary,
+ )
+ )
+
+ self.seqTransDecoder = DecoderLayerStack(decoderstack)
+
+ self.final_layer = nn.Linear(latent_dim, output_feats)
+
+ self.epsilon = 0.00001
+
+ def guided_forward(self, x, cond_frame, cond_embed, times, guidance_weight):
+ unc = self.forward(x, cond_frame, cond_embed, times, cond_drop_prob=1)
+ conditioned = self.forward(x, cond_frame, cond_embed, times, cond_drop_prob=0)
+
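+        # Classifier-free guidance: extrapolate from the unconditional prediction
+        # toward the conditional one; LMDM calls this with guidance_weight = 2.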
+ return unc + (conditioned - unc) * guidance_weight
+
+ def forward(
+ self, x: Tensor, cond_frame: Tensor, cond_embed: Tensor, times: Tensor, cond_drop_prob: float = 0.0
+ ):
+ batch_size, device = x.shape[0], x.device
+
+ # concat last frame, project to latent space
+ # cond_frame: [b, dim] | [b, n, dim+1]
+ if self.multi_cond_frame:
+ # [b, n, dim+1] (+1 mask)
+ x = torch.cat([x, cond_frame], dim=-1)
+ else:
+ # [b, dim]
+ x = torch.cat([x, cond_frame.unsqueeze(1).repeat(1, x.shape[1], 1)], dim=-1)
+ x = self.input_projection(x)
+ # add the positional embeddings of the input sequence to provide temporal information
+ x = self.abs_pos_encoding(x)
+
+ # create audio conditional embedding with conditional dropout
+ keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)
+ keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
+ keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
+
+ cond_tokens = self.cond_projection(cond_embed)
+ # encode tokens
+ cond_tokens = self.abs_pos_encoding(cond_tokens)
+ cond_tokens = self.cond_encoder(cond_tokens)
+
+ null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
+ cond_tokens = torch.where(keep_mask_embed, cond_tokens, null_cond_embed)
+
+ mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
+ cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)
+
+ # create the diffusion timestep embedding, add the extra audio projection
+ t_hidden = self.time_mlp(times)
+
+ # project to attention and FiLM conditioning
+ t = self.to_time_cond(t_hidden)
+ t_tokens = self.to_time_tokens(t_hidden)
+
+ # FiLM conditioning
+ null_cond_hidden = self.null_cond_hidden.to(t.dtype)
+ cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
+ t += cond_hidden
+
+ # cross-attention conditioning
+ c = torch.cat((cond_tokens, t_tokens), dim=-2)
+ cond_tokens = self.norm_cond(c)
+
+ # Pass through the transformer decoder
+ # attending to the conditional embedding
+ output = self.seqTransDecoder(x, cond_tokens, t)
+
+ output = self.final_layer(output)
+
+ return output
\ No newline at end of file
diff --git a/core/models/modules/lmdm_modules/rotary_embedding_torch.py b/core/models/modules/lmdm_modules/rotary_embedding_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f9410d04bfce12cfd7a19c7a8de9ebbf9ce4540
--- /dev/null
+++ b/core/models/modules/lmdm_modules/rotary_embedding_torch.py
@@ -0,0 +1,132 @@
+from inspect import isfunction
+from math import log, pi
+
+import torch
+from einops import rearrange, repeat
+from torch import einsum, nn
+
+# helper functions
+
+
+def exists(val):
+ return val is not None
+
+
+def broadcat(tensors, dim=-1):
+ num_tensors = len(tensors)
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+ shape_len = list(shape_lens)[0]
+
+ dim = (dim + shape_len) if dim < 0 else dim
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
+
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+ assert all(
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
+ ), "invalid dimensions for broadcastable concatentation"
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
+ expanded_dims.insert(dim, (dim, dims[dim]))
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
+ return torch.cat(tensors, dim=dim)
+
+
+# rotary embedding helper functions
+
+
+def rotate_half(x):
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return rearrange(x, "... d r -> ... (d r)")
+
+
+def apply_rotary_emb(freqs, t, start_index=0):
+ freqs = freqs.to(t)
+ rot_dim = freqs.shape[-1]
+ end_index = start_index + rot_dim
+ assert (
+ rot_dim <= t.shape[-1]
+ ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
+ t_left, t, t_right = (
+ t[..., :start_index],
+ t[..., start_index:end_index],
+ t[..., end_index:],
+ )
+ t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
+ return torch.cat((t_left, t, t_right), dim=-1)
+
+
+# learned rotation helpers
+
+
+def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
+ if exists(freq_ranges):
+ rotations = einsum("..., f -> ... f", rotations, freq_ranges)
+ rotations = rearrange(rotations, "... r f -> ... (r f)")
+
+ rotations = repeat(rotations, "... n -> ... (n r)", r=2)
+ return apply_rotary_emb(rotations, t, start_index=start_index)
+
+
+# classes
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim,
+ custom_freqs=None,
+ freqs_for="lang",
+ theta=10000,
+ max_freq=10,
+ num_freqs=1,
+ learned_freq=False,
+ ):
+ super().__init__()
+ if exists(custom_freqs):
+ freqs = custom_freqs
+ elif freqs_for == "lang":
+ freqs = 1.0 / (
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+ )
+ elif freqs_for == "pixel":
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+ elif freqs_for == "constant":
+ freqs = torch.ones(num_freqs).float()
+ else:
+ raise ValueError(f"unknown modality {freqs_for}")
+
+ self.cache = dict()
+
+ if learned_freq:
+ self.freqs = nn.Parameter(freqs)
+ else:
+ self.register_buffer("freqs", freqs)
+
+ def rotate_queries_or_keys(self, t, seq_dim=-2):
+ device = t.device
+ seq_len = t.shape[seq_dim]
+ freqs = self.forward(
+ lambda: torch.arange(seq_len, device=device), cache_key=seq_len
+ )
+ return apply_rotary_emb(freqs, t)
+
+ def forward(self, t, cache_key=None):
+ if exists(cache_key) and cache_key in self.cache:
+ return self.cache[cache_key]
+
+ if isfunction(t):
+ t = t()
+
+ freqs = self.freqs
+
+ freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+
+ if exists(cache_key):
+ self.cache[cache_key] = freqs
+
+ return freqs
diff --git a/core/models/modules/lmdm_modules/utils.py b/core/models/modules/lmdm_modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..883cf5ba52cdadbb3033c870a95147486bd4487e
--- /dev/null
+++ b/core/models/modules/lmdm_modules/utils.py
@@ -0,0 +1,96 @@
+import math
+import numpy as np
+import torch
+from torch import nn
+
+
+# absolute positional embedding used for vanilla transformer sequential data
+class PositionalEncoding(nn.Module):
+ def __init__(self, d_model, dropout=0.1, max_len=500, batch_first=False):
+ super().__init__()
+ self.batch_first = batch_first
+
+ self.dropout = nn.Dropout(p=dropout)
+
+ pe = torch.zeros(max_len, d_model)
+ position = torch.arange(0, max_len).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0).transpose(0, 1)
+
+ self.register_buffer("pe", pe)
+
+ def forward(self, x):
+ if self.batch_first:
+ x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
+ else:
+ x = x + self.pe[: x.shape[0], :]
+ return self.dropout(x)
+
+
+# very similar positional embedding used for diffusion timesteps
+class SinusoidalPosEmb(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = x[:, None] * emb[None, :]
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
+
+
+# dropout mask
+def prob_mask_like(shape, prob, device):
+ if prob == 1:
+ return torch.ones(shape, device=device, dtype=torch.bool)
+ elif prob == 0:
+ return torch.zeros(shape, device=device, dtype=torch.bool)
+ else:
+ return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
+
+
+def extract(a, t, x_shape):
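+    # Gather the per-sample schedule entries a[t] (shape (b,)) and reshape them
+    # to (b, 1, ..., 1) so they broadcast against tensors of shape x_shape.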
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def make_beta_schedule(
+ schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
+):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(
+ linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64
+ )
+ ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(
+ linear_start, linear_end, n_timestep, dtype=torch.float64
+ )
+ elif schedule == "sqrt":
+ betas = (
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ ** 0.5
+ )
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
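+
+
+if __name__ == "__main__":
+    # Illustrative sanity check (not used by the pipeline): the cosine schedule
+    # should yield n_timestep betas clipped to [0, 0.999].
+    b = make_beta_schedule("cosine", 1000)
+    assert b.shape == (1000,) and (b >= 0).all() and (b <= 0.999).all()
+    print("cosine betas:", b[:3], "...", b[-3:])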
diff --git a/core/models/modules/motion_extractor.py b/core/models/modules/motion_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e6d1ac9b439e1280c01dd5a15aa47e930fdc380
--- /dev/null
+++ b/core/models/modules/motion_extractor.py
@@ -0,0 +1,25 @@
+# coding: utf-8
+
+"""
+Motion extractor (M), which directly predicts the canonical keypoints, head pose and expression deformation of the input image.
+"""
+
+from torch import nn
+import torch
+
+from .convnextv2 import convnextv2_tiny
+
+
+class MotionExtractor(nn.Module):
+ def __init__(self, num_kp=21, backbone="convnextv2_tiny"):
+ super(MotionExtractor, self).__init__()
+ self.detector = convnextv2_tiny(num_kp=num_kp, backbone=backbone)
+
+ def forward(self, x):
+ out = self.detector(x)
+ return out # pitch, yaw, roll, t, exp, scale, kp
+
+ def load_model(self, ckpt_path):
+ self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
+ self.eval()
+ return self
diff --git a/core/models/modules/spade_generator.py b/core/models/modules/spade_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7bf31087dd8f1de265e40198099a8a9e743d78a
--- /dev/null
+++ b/core/models/modules/spade_generator.py
@@ -0,0 +1,87 @@
+# coding: utf-8
+
+"""
+SPADE decoder (G) defined in the paper, which takes the warped feature as input to generate the animated image.
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from .util import SPADEResnetBlock
+
+
+class SPADEDecoder(nn.Module):
+ def __init__(
+ self,
+ upscale=2,
+ max_features=512,
+ block_expansion=64,
+ out_channels=64,
+ num_down_blocks=2,
+ ):
+        super().__init__()
+        self.upscale = upscale
+        # channels of the deepest encoder feature map: min(max_features, block_expansion * 2**num_down_blocks)
+        input_channels = min(max_features, block_expansion * (2 ** num_down_blocks))
+        norm_G = "spadespectralinstance"
+        label_num_channels = input_channels  # 256 with the defaults
+
+ self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1)
+ self.G_middle_0 = SPADEResnetBlock(
+ 2 * input_channels, 2 * input_channels, norm_G, label_num_channels
+ )
+ self.G_middle_1 = SPADEResnetBlock(
+ 2 * input_channels, 2 * input_channels, norm_G, label_num_channels
+ )
+ self.G_middle_2 = SPADEResnetBlock(
+ 2 * input_channels, 2 * input_channels, norm_G, label_num_channels
+ )
+ self.G_middle_3 = SPADEResnetBlock(
+ 2 * input_channels, 2 * input_channels, norm_G, label_num_channels
+ )
+ self.G_middle_4 = SPADEResnetBlock(
+ 2 * input_channels, 2 * input_channels, norm_G, label_num_channels
+ )
+ self.G_middle_5 = SPADEResnetBlock(
+ 2 * input_channels, 2 * input_channels, norm_G, label_num_channels
+ )
+ self.up_0 = SPADEResnetBlock(
+ 2 * input_channels, input_channels, norm_G, label_num_channels
+ )
+ self.up_1 = SPADEResnetBlock(
+ input_channels, out_channels, norm_G, label_num_channels
+ )
+ self.up = nn.Upsample(scale_factor=2)
+
+ if self.upscale is None or self.upscale <= 1:
+ self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1)
+ else:
+ self.conv_img = nn.Sequential(
+ nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
+ nn.PixelShuffle(upscale_factor=2),
+ )
+
+ def forward(self, feature):
+ seg = feature # Bx256x64x64
+ x = self.fc(feature) # Bx512x64x64
+ x = self.G_middle_0(x, seg)
+ x = self.G_middle_1(x, seg)
+ x = self.G_middle_2(x, seg)
+ x = self.G_middle_3(x, seg)
+ x = self.G_middle_4(x, seg)
+ x = self.G_middle_5(x, seg)
+
+ x = self.up(x) # Bx512x64x64 -> Bx512x128x128
+ x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128
+ x = self.up(x) # Bx256x128x128 -> Bx256x256x256
+ x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256
+
+ x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW
+ x = torch.sigmoid(x) # Bx3xHxW
+
+ return x
+
+ def load_model(self, ckpt_path):
+ self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
+ self.eval()
+ return self
diff --git a/core/models/modules/stitching_network.py b/core/models/modules/stitching_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..1db44d67c781fe2585bd5a4eeb5664ebcf0546f0
--- /dev/null
+++ b/core/models/modules/stitching_network.py
@@ -0,0 +1,65 @@
+# coding: utf-8
+
+"""
+Stitching module (S) and two retargeting modules (R) defined in the paper.
+
+- The stitching module pastes the animated portrait back into the original image space without pixel misalignment
+in the stitching region.
+
+- The eyes retargeting module is designed to address the issue of incomplete eye closure during cross-id reenactment, especially
+when a person with small eyes drives a person with larger eyes.
+
+- The lip retargeting module is designed similarly to the eye retargeting module, and can also normalize the input by ensuring that
+the lips are in a closed state, which facilitates better animation driving.
+"""
+import torch
+from torch import nn
+
+
+def remove_ddp_duplicate_key(state_dict):
+ from collections import OrderedDict
+ state_dict_new = OrderedDict()
+ for key in state_dict.keys():
+ state_dict_new[key.replace('module.', '')] = state_dict[key]
+ return state_dict_new
+
+
+class StitchingNetwork(nn.Module):
+ def __init__(self, input_size=126, hidden_sizes=[128, 128, 64], output_size=65):
+ super(StitchingNetwork, self).__init__()
+ layers = []
+ for i in range(len(hidden_sizes)):
+ if i == 0:
+ layers.append(nn.Linear(input_size, hidden_sizes[i]))
+ else:
+ layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
+ layers.append(nn.ReLU(inplace=True))
+ layers.append(nn.Linear(hidden_sizes[-1], output_size))
+ self.mlp = nn.Sequential(*layers)
+
+ def _forward(self, x):
+ return self.mlp(x)
+
+ def load_model(self, ckpt_path):
+ checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
+        self.load_state_dict(remove_ddp_duplicate_key(checkpoint['retarget_shoulder']))
+ self.eval()
+ return self
+
+ def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
+ """ conduct the stitching
+ kp_source: Bxnum_kpx3
+ kp_driving: Bxnum_kpx3
+ """
+ bs, num_kp = kp_source.shape[:2]
+ kp_driving_new = kp_driving.clone()
+ delta = self._forward(torch.cat([kp_source.view(bs, -1), kp_driving_new.view(bs, -1)], dim=1))
+        delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3)  # Bxnum_kpx3 expression deltas
+        delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2)  # Bx1x2 (tx, ty)
+ kp_driving_new += delta_exp
+ kp_driving_new[..., :2] += delta_tx_ty
+ return kp_driving_new
+
+ def forward(self, kp_source, kp_driving):
+ out = self.stitching(kp_source, kp_driving)
+ return out
\ No newline at end of file
diff --git a/core/models/modules/util.py b/core/models/modules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc6b925ff4d93dbb89d0d1e593bee15c888c39ee
--- /dev/null
+++ b/core/models/modules/util.py
@@ -0,0 +1,452 @@
+# coding: utf-8
+
+"""
+This file defines various neural network modules and utility functions, including convolutional and residual blocks,
+normalizations, and functions for spatial transformation and tensor manipulation.
+"""
+
+from torch import nn
+import torch.nn.functional as F
+import torch
+import torch.nn.utils.spectral_norm as spectral_norm
+import math
+import warnings
+import collections.abc
+from itertools import repeat
+
+def kp2gaussian(kp, spatial_size, kp_variance):
+ """
+ Transform a keypoint into gaussian like representation
+ """
+ mean = kp
+
+ coordinate_grid = make_coordinate_grid(spatial_size, mean)
+ number_of_leading_dimensions = len(mean.shape) - 1
+ shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
+ coordinate_grid = coordinate_grid.view(*shape)
+ repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
+ coordinate_grid = coordinate_grid.repeat(*repeats)
+
+ # Preprocess kp shape
+ shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
+ mean = mean.view(*shape)
+
+ mean_sub = (coordinate_grid - mean)
+
+ out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
+
+ return out
+
+
+def make_coordinate_grid(spatial_size, ref, **kwargs):
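+    """
+    Build a (d, h, w, 3) grid of xyz coordinates normalized to [-1, 1] (using the
+    right-down-in convention noted below), with dtype/device taken from `ref`.
+    """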
+ d, h, w = spatial_size
+ x = torch.arange(w).type(ref.dtype).to(ref.device)
+ y = torch.arange(h).type(ref.dtype).to(ref.device)
+ z = torch.arange(d).type(ref.dtype).to(ref.device)
+
+ # NOTE: must be right-down-in
+ x = (2 * (x / (w - 1)) - 1) # the x axis faces to the right
+ y = (2 * (y / (h - 1)) - 1) # the y axis faces to the bottom
+ z = (2 * (z / (d - 1)) - 1) # the z axis faces to the inner
+
+ yy = y.view(1, -1, 1).repeat(d, 1, w)
+ xx = x.view(1, 1, -1).repeat(d, h, 1)
+ zz = z.view(-1, 1, 1).repeat(1, h, w)
+
+ meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
+
+ return meshed
+
+
+class ConvT2d(nn.Module):
+ """
+ Upsampling block for use in decoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1):
+ super(ConvT2d, self).__init__()
+
+ self.convT = nn.ConvTranspose2d(in_features, out_features, kernel_size=kernel_size, stride=stride,
+ padding=padding, output_padding=output_padding)
+ self.norm = nn.InstanceNorm2d(out_features)
+
+ def forward(self, x):
+ out = self.convT(x)
+ out = self.norm(out)
+ out = F.leaky_relu(out)
+ return out
+
+
+class ResBlock3d(nn.Module):
+ """
+ Res block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, kernel_size, padding):
+ super(ResBlock3d, self).__init__()
+ self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
+ self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
+ self.norm1 = nn.BatchNorm3d(in_features, affine=True)
+ self.norm2 = nn.BatchNorm3d(in_features, affine=True)
+
+ def forward(self, x):
+ out = self.norm1(x)
+ out = F.relu(out)
+ out = self.conv1(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+ out = self.conv2(out)
+ out += x
+ return out
+
+
+class UpBlock3d(nn.Module):
+ """
+ Upsampling block for use in decoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(UpBlock3d, self).__init__()
+
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = nn.BatchNorm3d(out_features, affine=True)
+
+ def forward(self, x):
+ out = F.interpolate(x, scale_factor=(1, 2, 2))
+ out = self.conv(out)
+ out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+
+class DownBlock2d(nn.Module):
+ """
+ Downsampling block for use in encoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(DownBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
+ self.norm = nn.BatchNorm2d(out_features, affine=True)
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = F.relu(out)
+ out = self.pool(out)
+ return out
+
+
+class DownBlock3d(nn.Module):
+ """
+ Downsampling block for use in encoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(DownBlock3d, self).__init__()
+ '''
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups, stride=(1, 2, 2))
+ '''
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = nn.BatchNorm3d(out_features, affine=True)
+ self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = F.relu(out)
+ out = self.pool(out)
+ return out
+
+
+class SameBlock2d(nn.Module):
+ """
+ Simple block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
+ super(SameBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
+ self.norm = nn.BatchNorm2d(out_features, affine=True)
+ if lrelu:
+ self.ac = nn.LeakyReLU()
+ else:
+ self.ac = nn.ReLU()
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = self.ac(out)
+ return out
+
+
+class Encoder(nn.Module):
+ """
+ Hourglass Encoder
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Encoder, self).__init__()
+
+ down_blocks = []
+ for i in range(num_blocks):
+ down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ def forward(self, x):
+ outs = [x]
+ for down_block in self.down_blocks:
+ outs.append(down_block(outs[-1]))
+ return outs
+
+
+class Decoder(nn.Module):
+ """
+ Hourglass Decoder
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Decoder, self).__init__()
+
+ up_blocks = []
+
+ for i in range(num_blocks)[::-1]:
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
+ out_filters = min(max_features, block_expansion * (2 ** i))
+ up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
+
+ self.up_blocks = nn.ModuleList(up_blocks)
+ self.out_filters = block_expansion + in_features
+
+ self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
+ self.norm = nn.BatchNorm3d(self.out_filters, affine=True)
+
+ def forward(self, x):
+ out = x.pop()
+ for up_block in self.up_blocks:
+ out = up_block(out)
+ skip = x.pop()
+ out = torch.cat([out, skip], dim=1)
+ out = self.conv(out)
+ out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+
+class Hourglass(nn.Module):
+ """
+ Hourglass architecture.
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Hourglass, self).__init__()
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
+ self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
+ self.out_filters = self.decoder.out_filters
+
+ def forward(self, x):
+ return self.decoder(self.encoder(x))
+
+
+class SPADE(nn.Module):
+ def __init__(self, norm_nc, label_nc):
+ super().__init__()
+
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+ nhidden = 128
+
+ self.mlp_shared = nn.Sequential(
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
+ nn.ReLU())
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
+
+ def forward(self, x, segmap):
+ normalized = self.param_free_norm(x)
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
+ actv = self.mlp_shared(segmap)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+ out = normalized * (1 + gamma) + beta
+ return out
+
+
+class SPADEResnetBlock(nn.Module):
+ def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
+ super().__init__()
+ # Attributes
+ self.learned_shortcut = (fin != fout)
+ fmiddle = min(fin, fout)
+ self.use_se = use_se
+ # create conv layers
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
+ if self.learned_shortcut:
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
+ # apply spectral norm if specified
+ if 'spectral' in norm_G:
+ self.conv_0 = spectral_norm(self.conv_0)
+ self.conv_1 = spectral_norm(self.conv_1)
+ if self.learned_shortcut:
+ self.conv_s = spectral_norm(self.conv_s)
+ # define normalization layers
+ self.norm_0 = SPADE(fin, label_nc)
+ self.norm_1 = SPADE(fmiddle, label_nc)
+ if self.learned_shortcut:
+ self.norm_s = SPADE(fin, label_nc)
+
+ def forward(self, x, seg1):
+ x_s = self.shortcut(x, seg1)
+ dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
+ dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, seg1):
+ if self.learned_shortcut:
+ x_s = self.conv_s(self.norm_s(x, seg1))
+ else:
+ x_s = x
+ return x_s
+
+ def actvn(self, x):
+ return F.leaky_relu(x, 2e-1)
+
+
+def filter_state_dict(state_dict, remove_name='fc'):
+ new_state_dict = {}
+ for key in state_dict:
+ if remove_name in key:
+ continue
+ new_state_dict[key] = state_dict[key]
+ return new_state_dict
+
+
+class GRN(nn.Module):
+ """ GRN (Global Response Normalization) layer
+ """
+
+ def __init__(self, dim):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+ def forward(self, x):
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma * (x * Nx) + self.beta + x
+
+
+class LayerNorm(nn.Module):
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+ with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError
+ self.normalized_shape = (normalized_shape, )
+
+ def forward(self, x):
+ if self.data_format == "channels_last":
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ elif self.data_format == "channels_first":
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def drop_path(x, drop_prob=0., training=False, scale_by_keep=True):
+ """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+
+ def __init__(self, drop_prob=None, scale_by_keep=True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+ return parse
+
+to_2tuple = _ntuple(2)
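+
+# Usage sketch (illustrative only; tensor sizes are hypothetical):
+#   w = torch.empty(768, 768)
+#   trunc_normal_(w, std=0.02)    # truncated-normal init, values clamped to [-2, 2]
+#   ks = to_2tuple(3)             # (3, 3); pairs pass through: to_2tuple((3, 5)) -> (3, 5)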
diff --git a/core/models/modules/warping_network.py b/core/models/modules/warping_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..3895b899eb574a513557f9345df8f69bab552444
--- /dev/null
+++ b/core/models/modules/warping_network.py
@@ -0,0 +1,87 @@
+# coding: utf-8
+
+"""
+Warping field estimator (W) defined in the paper, which generates a warping field from the implicit
+keypoint representations x_s and x_d, and uses this flow field to warp the source feature volume f_s.
+"""
+import torch
+from torch import nn
+import torch.nn.functional as F
+from .util import SameBlock2d
+from .dense_motion import DenseMotionNetwork
+
+
+class WarpingNetwork(nn.Module):
+ def __init__(
+ self,
+ num_kp=21,
+ block_expansion=64,
+ max_features=512,
+ num_down_blocks=2,
+ reshape_channel=32,
+ estimate_occlusion_map=True,
+ **kwargs
+ ):
+ super(WarpingNetwork, self).__init__()
+
+ self.upscale = kwargs.get('upscale', 1)
+ self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True)
+
+ dense_motion_params = {
+ "block_expansion": 32,
+ "max_features": 1024,
+ "num_blocks": 5,
+ "reshape_depth": 16,
+ "compress": 4,
+ }
+
+ self.dense_motion_network = DenseMotionNetwork(
+ num_kp=num_kp,
+ feature_channel=reshape_channel,
+ estimate_occlusion_map=estimate_occlusion_map,
+ **dense_motion_params
+ )
+
+ self.third = SameBlock2d(max_features, block_expansion * (2 ** num_down_blocks), kernel_size=(3, 3), padding=(1, 1), lrelu=True)
+ self.fourth = nn.Conv2d(in_channels=block_expansion * (2 ** num_down_blocks), out_channels=block_expansion * (2 ** num_down_blocks), kernel_size=1, stride=1)
+
+ self.estimate_occlusion_map = estimate_occlusion_map
+
+ def deform_input(self, inp, deformation):
+ return F.grid_sample(inp, deformation, align_corners=False)
+
+ def forward(self, feature_3d, kp_source, kp_driving):
+ # Feature warper, Transforming feature representation according to deformation and occlusion
+ dense_motion = self.dense_motion_network(
+ feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source
+ )
+ if 'occlusion_map' in dense_motion:
+ occlusion_map = dense_motion['occlusion_map'] # Bx1x64x64
+ else:
+ occlusion_map = None
+
+ deformation = dense_motion['deformation'] # Bx16x64x64x3
+ out = self.deform_input(feature_3d, deformation) # Bx32x16x64x64
+
+ bs, c, d, h, w = out.shape # Bx32x16x64x64
+ out = out.view(bs, c * d, h, w) # -> Bx512x64x64
+ out = self.third(out) # -> Bx256x64x64
+ out = self.fourth(out) # -> Bx256x64x64
+
+ if self.flag_use_occlusion_map and (occlusion_map is not None):
+ out = out * occlusion_map
+
+ return out
+
+ def load_model(self, ckpt_path):
+ self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
+ self.eval()
+ return self
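+
+# Usage sketch (non-authoritative; the checkpoint path is a placeholder and the
+# shapes follow the comments in forward()):
+#   net = WarpingNetwork().load_model("warping_network.pth")
+#   f_s = torch.randn(1, 32, 16, 64, 64)    # source feature volume
+#   kp = torch.randn(1, 21, 3)              # implicit keypoints
+#   out = net(f_s, kp, kp)                  # Bx256x64x64 warped feature map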
diff --git a/core/models/motion_extractor.py b/core/models/motion_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4e05ac286334e99912af134eaef373a7d8d1969
--- /dev/null
+++ b/core/models/motion_extractor.py
@@ -0,0 +1,49 @@
+import numpy as np
+import torch
+from ..utils.load_model import load_model
+
+
+class MotionExtractor:
+ def __init__(self, model_path, device="cuda"):
+ kwargs = {
+ "module_name": "MotionExtractor",
+ }
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
+ self.device = device
+
+ self.output_names = [
+ "pitch",
+ "yaw",
+ "roll",
+ "t",
+ "exp",
+ "scale",
+ "kp",
+ ]
+
+ def __call__(self, image):
+ """
+ image: np.ndarray, shape (1, 3, 256, 256), RGB, 0-1
+ """
+ outputs = {}
+ if self.model_type == "onnx":
+ out_list = self.model.run(None, {"image": image})
+ for i, name in enumerate(self.output_names):
+ outputs[name] = out_list[i]
+ elif self.model_type == "tensorrt":
+ self.model.setup({"image": image})
+ self.model.infer()
+ for name in self.output_names:
+ outputs[name] = self.model.buffer[name][0].copy()
+ elif self.model_type == "pytorch":
+ with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True):
+ pred = self.model(torch.from_numpy(image).to(self.device))
+ for i, name in enumerate(self.output_names):
+ outputs[name] = pred[i].float().cpu().numpy()
+ else:
+ raise ValueError(f"Unsupported model type: {self.model_type}")
+ outputs["exp"] = outputs["exp"].reshape(1, -1)
+ outputs["kp"] = outputs["kp"].reshape(1, -1)
+ return outputs
+
+
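+# Usage sketch (shape-level example; the model path is a placeholder):
+#   me = MotionExtractor("motion_extractor.onnx")
+#   img = np.random.rand(1, 3, 256, 256).astype(np.float32)   # RGB, 0-1
+#   out = me(img)   # dict with pitch/yaw/roll/t/exp/scale/kp; "exp" and "kp" flattened to (1, -1)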
diff --git a/core/models/stitch_network.py b/core/models/stitch_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..56efa7ad8b1e79eafffd92121a15cbfc92af65dd
--- /dev/null
+++ b/core/models/stitch_network.py
@@ -0,0 +1,30 @@
+import numpy as np
+import torch
+from ..utils.load_model import load_model
+
+
+class StitchNetwork:
+ def __init__(self, model_path, device="cuda"):
+ kwargs = {
+ "module_name": "StitchingNetwork",
+ }
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
+ self.device = device
+
+ def __call__(self, kp_source, kp_driving):
+ if self.model_type == "onnx":
+ pred = self.model.run(None, {"kp_source": kp_source, "kp_driving": kp_driving})[0]
+ elif self.model_type == "tensorrt":
+ self.model.setup({"kp_source": kp_source, "kp_driving": kp_driving})
+ self.model.infer()
+ pred = self.model.buffer["out"][0].copy()
+ elif self.model_type == 'pytorch':
+ with torch.no_grad():
+ pred = self.model(
+ torch.from_numpy(kp_source).to(self.device),
+ torch.from_numpy(kp_driving).to(self.device)
+ ).cpu().numpy()
+ else:
+ raise ValueError(f"Unsupported model type: {self.model_type}")
+
+ return pred
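+
+# Usage sketch (the model path is a placeholder; keypoint shapes are an
+# assumption by analogy with WarpNetwork -- verify against the exported model):
+#   stitch = StitchNetwork("stitch_network.onnx")
+#   delta = stitch(kp_source, kp_driving)   # np.float32 keypoint arrays, e.g. (1, 21, 3)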
diff --git a/core/models/warp_network.py b/core/models/warp_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f14edf66c57567bebd5341a4c376dca0a646d46
--- /dev/null
+++ b/core/models/warp_network.py
@@ -0,0 +1,35 @@
+import numpy as np
+import torch
+from ..utils.load_model import load_model
+
+
+class WarpNetwork:
+ def __init__(self, model_path, device="cuda"):
+ kwargs = {
+ "module_name": "WarpingNetwork",
+ }
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
+ self.device = device
+
+ def __call__(self, feature_3d, kp_source, kp_driving):
+ """
+ feature_3d: np.ndarray, shape (1, 32, 16, 64, 64)
+ kp_source | kp_driving: np.ndarray, shape (1, 21, 3)
+ """
+ if self.model_type == "onnx":
+ pred = self.model.run(None, {"feature_3d": feature_3d, "kp_source": kp_source, "kp_driving": kp_driving})[0]
+ elif self.model_type == "tensorrt":
+ self.model.setup({"feature_3d": feature_3d, "kp_source": kp_source, "kp_driving": kp_driving})
+ self.model.infer()
+ pred = self.model.buffer["out"][0].copy()
+ elif self.model_type == 'pytorch':
+ with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True):
+ pred = self.model(
+ torch.from_numpy(feature_3d).to(self.device),
+ torch.from_numpy(kp_source).to(self.device),
+ torch.from_numpy(kp_driving).to(self.device)
+ ).float().cpu().numpy()
+ else:
+ raise ValueError(f"Unsupported model type: {self.model_type}")
+
+ return pred
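+
+# Usage sketch using the shapes documented in __call__ (the engine path is a
+# placeholder):
+#   warp = WarpNetwork("warp_network_fp16.engine")
+#   f3d = np.random.rand(1, 32, 16, 64, 64).astype(np.float32)
+#   kp = np.random.rand(1, 21, 3).astype(np.float32)
+#   out = warp(f3d, kp, kp)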
diff --git a/core/utils/blend/__init__.py b/core/utils/blend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9e91b0dfd73ec50b041d112d91aa1893a27c600
--- /dev/null
+++ b/core/utils/blend/__init__.py
@@ -0,0 +1,4 @@
+import pyximport
+pyximport.install()
+
+from .blend import blend_images_cy
\ No newline at end of file
diff --git a/core/utils/blend/blend.pyx b/core/utils/blend/blend.pyx
new file mode 100644
index 0000000000000000000000000000000000000000..68e4695dc644ad1f6e2cc0538a13c853fbef2f7a
--- /dev/null
+++ b/core/utils/blend/blend.pyx
@@ -0,0 +1,38 @@
+#cython: language_level=3
+import numpy as np
+cimport numpy as np
+
+cdef extern from "blend_impl.h":
+ void _blend_images_cy_impl(
+ const float* mask_warped,
+ const float* frame_warped,
+ const unsigned char* frame_rgb,
+ const int height,
+ const int width,
+ unsigned char* result
+ ) noexcept nogil
+
+def blend_images_cy(
+ np.ndarray[np.float32_t, ndim=2] mask_warped,
+ np.ndarray[np.float32_t, ndim=3] frame_warped,
+ np.ndarray[np.uint8_t, ndim=3] frame_rgb,
+ np.ndarray[np.uint8_t, ndim=3] result
+):
+ cdef int h = mask_warped.shape[0]
+ cdef int w = mask_warped.shape[1]
+
+    if not mask_warped.flags['C_CONTIGUOUS']:
+        mask_warped = np.ascontiguousarray(mask_warped)
+    if not frame_warped.flags['C_CONTIGUOUS']:
+        frame_warped = np.ascontiguousarray(frame_warped)
+    if not frame_rgb.flags['C_CONTIGUOUS']:
+        frame_rgb = np.ascontiguousarray(frame_rgb)
+    if not result.flags['C_CONTIGUOUS']:
+        # result is written in place; silently copying it would drop the output
+        raise ValueError("result must be C-contiguous")
+
+    with nogil:
+        _blend_images_cy_impl(
+            <const float*> mask_warped.data,
+            <const float*> frame_warped.data,
+            <const unsigned char*> frame_rgb.data,
+            h, w,
+            <unsigned char*> result.data
+        )
\ No newline at end of file
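+
+# For reference, a NumPy equivalent of the C kernel (a sketch useful for
+# testing the Cython path, not used in the pipeline):
+#   def blend_images_np(mask, warped, rgb):
+#       m = mask[:, :, None]
+#       return np.clip(m * warped + (1.0 - m) * rgb, 0, 255).astype(np.uint8)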
diff --git a/core/utils/blend/blend.pyxbld b/core/utils/blend/blend.pyxbld
new file mode 100644
index 0000000000000000000000000000000000000000..75e1d54ac97c8c38475250b6b441f54ed0b5c5b5
--- /dev/null
+++ b/core/utils/blend/blend.pyxbld
@@ -0,0 +1,11 @@
+import numpy as np
+import os
+
+def make_ext(modname, pyxfilename):
+ from distutils.extension import Extension
+
+ return Extension(name=modname,
+ sources=[pyxfilename, os.path.join(os.path.dirname(pyxfilename), "blend_impl.c")],
+ include_dirs=[np.get_include(), os.path.dirname(pyxfilename)],
+ extra_compile_args=["-O3", "-std=c99", "-march=native", "-ffast-math"],
+ )
diff --git a/core/utils/blend/blend_impl.c b/core/utils/blend/blend_impl.c
new file mode 100644
index 0000000000000000000000000000000000000000..bf0661d494057650cad4f5981db298a8a05f8478
--- /dev/null
+++ b/core/utils/blend/blend_impl.c
@@ -0,0 +1,34 @@
+#include "blend_impl.h"
+
+void _blend_images_cy_impl(
+ const float* mask_warped,
+ const float* frame_warped,
+ const unsigned char* frame_rgb,
+ const int height,
+ const int width,
+ unsigned char* result) {
+
+ const float* mask_pointer = mask_warped;
+ const float* frame_warped_pointer = frame_warped;
+ const unsigned char* frame_rgb_pointer = frame_rgb;
+ unsigned char* result_pointer = result;
+
+ for(int i = 0; i < height; i++) {
+ for(int j = 0; j < width; j++) {
+ float mask = *mask_pointer;
+ float mask_inv = 1.0f - mask;
+
+ float blended1 = mask * (*frame_warped_pointer) + mask_inv * (*frame_rgb_pointer);
+ float blended2 = mask * (*(frame_warped_pointer+1)) + mask_inv * (*(frame_rgb_pointer+1));
+ float blended3 = mask * (*(frame_warped_pointer+2)) + mask_inv * (*(frame_rgb_pointer+2));
+
+ *(result_pointer++) = blended1 > 255 ? 255 : (blended1 < 0) ? 0 : (unsigned char)blended1;
+ *(result_pointer++) = blended2 > 255 ? 255 : (blended2 < 0) ? 0 : (unsigned char)blended2;
+ *(result_pointer++) = blended3 > 255 ? 255 : (blended3 < 0) ? 0 : (unsigned char)blended3;
+
+ frame_warped_pointer+=3;
+ frame_rgb_pointer+=3;
+ mask_pointer++;
+ }
+ }
+}
\ No newline at end of file
diff --git a/core/utils/blend/blend_impl.h b/core/utils/blend/blend_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..854092525031c01a65dc7e288cc83324f59b7a1c
--- /dev/null
+++ b/core/utils/blend/blend_impl.h
@@ -0,0 +1,13 @@
+
+#ifndef __BLEND_IMAGES_CY_IMPL_H__
+#define __BLEND_IMAGES_CY_IMPL_H__
+
+void _blend_images_cy_impl(
+ const float* mask_warped,
+ const float* frame_warped,
+ const unsigned char* frame_rgb,
+ const int height,
+ const int width,
+ unsigned char* result);
+
+#endif // __BLEND_IMAGES_CY_IMPL_H__
\ No newline at end of file
diff --git a/core/utils/crop.py b/core/utils/crop.py
new file mode 100644
index 0000000000000000000000000000000000000000..e01d08d53fb1d212c40da1753149b33eb5b14deb
--- /dev/null
+++ b/core/utils/crop.py
@@ -0,0 +1,459 @@
+# coding: utf-8
+
+"""
+cropping functions and the related preprocessing utilities
+"""
+
+import numpy as np
+import os.path as osp
+from math import sin, cos, acos, degrees
+import cv2
+
+DTYPE = np.float32
+CV2_INTERP = cv2.INTER_LINEAR
+
+
+def _transform_img(img, M, dsize, flags=CV2_INTERP, borderMode=None):
+    """Apply a similarity or affine transform to the image (no border handling unless borderMode is given).
+    img: input image
+ M: 2x3 matrix or 3x3 matrix
+ dsize: target shape (width, height)
+ """
+    if isinstance(dsize, (tuple, list)):
+        _dsize = tuple(dsize)
+    else:
+        _dsize = (dsize, dsize)
+
+ if borderMode is not None:
+ return cv2.warpAffine(
+ img,
+ M[:2, :],
+ dsize=_dsize,
+ flags=flags,
+ borderMode=borderMode,
+ borderValue=(0, 0, 0),
+ )
+ else:
+ return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags)
+
+
+def _transform_pts(pts, M):
+    """Apply a similarity or affine transform to the points.
+ pts: Nx2 ndarray
+ M: 2x3 matrix or 3x3 matrix
+ return: Nx2
+ """
+ return pts @ M[:2, :2].T + M[:2, 2]
+
+
+def parse_pt2_from_pt101(pt101, use_lip=True):
+ """
+    derive the 2 control points from the 101 points, cancelling the roll
+    """
+    # the former version used the eye centers directly, but that was not robust; interpolation is used now
+ pt_left_eye = np.mean(pt101[[39, 42, 45, 48]], axis=0) # left eye center
+ pt_right_eye = np.mean(pt101[[51, 54, 57, 60]], axis=0) # right eye center
+
+ if use_lip:
+ # use lip
+ pt_center_eye = (pt_left_eye + pt_right_eye) / 2
+ pt_center_lip = (pt101[75] + pt101[81]) / 2
+ pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0)
+ else:
+ pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0)
+ return pt2
+
+
+def parse_pt2_from_pt106(pt106, use_lip=True):
+ """
+    derive the 2 control points from the 106 points, cancelling the roll
+ """
+ pt_left_eye = np.mean(pt106[[33, 35, 40, 39]], axis=0) # left eye center
+ pt_right_eye = np.mean(pt106[[87, 89, 94, 93]], axis=0) # right eye center
+
+ if use_lip:
+ # use lip
+ pt_center_eye = (pt_left_eye + pt_right_eye) / 2
+ pt_center_lip = (pt106[52] + pt106[61]) / 2
+ pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0)
+ else:
+ pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0)
+ return pt2
+
+
+def parse_pt2_from_pt203(pt203, use_lip=True):
+ """
+    derive the 2 control points from the 203 points, cancelling the roll
+ """
+ pt_left_eye = np.mean(pt203[[0, 6, 12, 18]], axis=0) # left eye center
+ pt_right_eye = np.mean(pt203[[24, 30, 36, 42]], axis=0) # right eye center
+ if use_lip:
+ # use lip
+ pt_center_eye = (pt_left_eye + pt_right_eye) / 2
+ pt_center_lip = (pt203[48] + pt203[66]) / 2
+ pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0)
+ else:
+ pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0)
+ return pt2
+
+
+def parse_pt2_from_pt68(pt68, use_lip=True):
+ """
+    derive the 2 control points from the 68 points, cancelling the roll
+ """
+ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55], dtype=np.int32) - 1
+ if use_lip:
+ pt5 = np.stack(
+ [
+ np.mean(pt68[lm_idx[[1, 2]], :], 0), # left eye
+ np.mean(pt68[lm_idx[[3, 4]], :], 0), # right eye
+ pt68[lm_idx[0], :], # nose
+ pt68[lm_idx[5], :], # lip
+ pt68[lm_idx[6], :], # lip
+ ],
+ axis=0,
+ )
+
+ pt2 = np.stack([(pt5[0] + pt5[1]) / 2, (pt5[3] + pt5[4]) / 2], axis=0)
+ else:
+ pt2 = np.stack(
+ [
+ np.mean(pt68[lm_idx[[1, 2]], :], 0), # left eye
+ np.mean(pt68[lm_idx[[3, 4]], :], 0), # right eye
+ ],
+ axis=0,
+ )
+
+ return pt2
+
+
+def parse_pt2_from_pt5(pt5, use_lip=True):
+ """
+    derive the 2 control points from the 5 points, cancelling the roll
+ """
+ if use_lip:
+ pt2 = np.stack([(pt5[0] + pt5[1]) / 2, (pt5[3] + pt5[4]) / 2], axis=0)
+ else:
+ pt2 = np.stack([pt5[0], pt5[1]], axis=0)
+ return pt2
+
+
+def parse_pt2_from_pt9(pt9, use_lip=True):
+ """
+    derive the 2 control points from the 9 points, cancelling the roll
+ ['right eye right', 'right eye left', 'left eye right', 'left eye left', 'nose tip', 'lip right', 'lip left', 'upper lip', 'lower lip']
+ """
+ if use_lip:
+ pt9 = np.stack(
+ [
+ (pt9[2] + pt9[3]) / 2, # left eye
+ (pt9[0] + pt9[1]) / 2, # right eye
+ pt9[4],
+ (pt9[5] + pt9[6]) / 2, # lip
+ ],
+ axis=0,
+ )
+        pt2 = np.stack([(pt9[0] + pt9[1]) / 2, pt9[3]], axis=0)  # [eye center, lip center]
+ else:
+ pt2 = np.stack(
+ [
+ (pt9[2] + pt9[3]) / 2,
+ (pt9[0] + pt9[1]) / 2,
+ ],
+ axis=0,
+ )
+
+ return pt2
+
+
+def parse_pt2_from_pt_x(pts, use_lip=True):
+ if pts.shape[0] == 101:
+ pt2 = parse_pt2_from_pt101(pts, use_lip=use_lip)
+ elif pts.shape[0] == 106:
+ pt2 = parse_pt2_from_pt106(pts, use_lip=use_lip)
+ elif pts.shape[0] == 68:
+ pt2 = parse_pt2_from_pt68(pts, use_lip=use_lip)
+ elif pts.shape[0] == 5:
+ pt2 = parse_pt2_from_pt5(pts, use_lip=use_lip)
+ elif pts.shape[0] == 203:
+ pt2 = parse_pt2_from_pt203(pts, use_lip=use_lip)
+ elif pts.shape[0] > 101:
+ # take the first 101 points
+ pt2 = parse_pt2_from_pt101(pts[:101], use_lip=use_lip)
+ elif pts.shape[0] == 9:
+ pt2 = parse_pt2_from_pt9(pts, use_lip=use_lip)
+ else:
+        raise ValueError(f"Unknown landmark shape: {pts.shape}")
+
+ if not use_lip:
+        # NOTE: to be compatible with the code below, pt2 must be rotated 90 degrees clockwise manually
+ v = pt2[1] - pt2[0]
+ pt2[1, 0] = pt2[0, 0] - v[1]
+ pt2[1, 1] = pt2[0, 1] + v[0]
+
+ return pt2
+
+
+def parse_rect_from_landmark(
+ pts,
+ scale=1.5,
+ need_square=True,
+ vx_ratio=0,
+ vy_ratio=0,
+ use_deg_flag=False,
+ **kwargs,
+):
+    """parse center, size and angle from 101/68/5/x landmarks
+    vx_ratio: offset ratio along the pupil-axis x direction, multiplied by size
+    vy_ratio: offset ratio along the pupil-axis y direction, multiplied by size; used to include more forehead area
+
+    the landmark layout is judged from pts.shape
+    """
+ pt2 = parse_pt2_from_pt_x(pts, use_lip=kwargs.get("use_lip", True))
+
+ uy = pt2[1] - pt2[0]
+ l = np.linalg.norm(uy)
+ if l <= 1e-3:
+ uy = np.array([0, 1], dtype=DTYPE)
+ else:
+ uy /= l
+ ux = np.array((uy[1], -uy[0]), dtype=DTYPE)
+
+    # rotation angle of the x-axis: clockwise is positive, counterclockwise is negative (image coordinate system)
+    angle = acos(ux[0])
+    if ux[1] < 0:
+        angle = -angle
+
+ # rotation matrix
+ M = np.array([ux, uy])
+
+ # calculate the size which contains the angle degree of the bbox, and the center
+ center0 = np.mean(pts, axis=0)
+ rpts = (pts - center0) @ M.T # (M @ P.T).T = P @ M.T
+ lt_pt = np.min(rpts, axis=0)
+ rb_pt = np.max(rpts, axis=0)
+ center1 = (lt_pt + rb_pt) / 2
+
+ size = rb_pt - lt_pt
+ if need_square:
+ m = max(size[0], size[1])
+ size[0] = m
+ size[1] = m
+
+ size *= scale # scale size
+ center = (
+ center0 + ux * center1[0] + uy * center1[1]
+ ) # counterclockwise rotation, equivalent to M.T @ center1.T
+ center = (
+ center + ux * (vx_ratio * size) + uy * (vy_ratio * size)
+ ) # considering the offset in vx and vy direction
+
+ if use_deg_flag:
+ angle = degrees(angle)
+
+ return center, size, angle
+
+
+def parse_bbox_from_landmark(pts, **kwargs):
+ center, size, angle = parse_rect_from_landmark(pts, **kwargs)
+ cx, cy = center
+ w, h = size
+
+ # calculate the vertex positions before rotation
+ bbox = np.array(
+ [
+ [cx - w / 2, cy - h / 2], # left, top
+ [cx + w / 2, cy - h / 2],
+ [cx + w / 2, cy + h / 2], # right, bottom
+ [cx - w / 2, cy + h / 2],
+ ],
+ dtype=DTYPE,
+ )
+
+ # construct rotation matrix
+ bbox_rot = bbox.copy()
+ R = np.array(
+ [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]], dtype=DTYPE
+ )
+
+ # calculate the relative position of each vertex from the rotation center, then rotate these positions, and finally add the coordinates of the rotation center
+ bbox_rot = (bbox_rot - center) @ R.T + center
+
+ return {
+ "center": center, # 2x1
+        "size": size,  # 2x1 (w, h)
+ "angle": angle, # rad, counterclockwise
+ "bbox": bbox, # 4x2
+ "bbox_rot": bbox_rot, # 4x2
+ }
+
+
+def crop_image_by_bbox(
+ img, bbox, lmk=None, dsize=512, angle=None, flag_rot=False, **kwargs
+):
+ left, top, right, bot = bbox
+ if int(right - left) != int(bot - top):
+ print(f"right-left {right-left} != bot-top {bot-top}")
+ size = right - left
+
+ src_center = np.array([(left + right) / 2, (top + bot) / 2], dtype=DTYPE)
+ tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE)
+
+ s = dsize / size # scale
+ if flag_rot and angle is not None:
+ costheta, sintheta = cos(angle), sin(angle)
+ cx, cy = src_center[0], src_center[1] # ori center
+ tcx, tcy = tgt_center[0], tgt_center[1] # target center
+        # similarity transform: scale and rotate about the source center, then translate to the target center
+ M_o2c = np.array(
+ [
+ [s * costheta, s * sintheta, tcx - s * (costheta * cx + sintheta * cy)],
+ [
+ -s * sintheta,
+ s * costheta,
+ tcy - s * (-sintheta * cx + costheta * cy),
+ ],
+ ],
+ dtype=DTYPE,
+ )
+ else:
+ M_o2c = np.array(
+ [
+ [s, 0, tgt_center[0] - s * src_center[0]],
+ [0, s, tgt_center[1] - s * src_center[1]],
+ ],
+ dtype=DTYPE,
+ )
+
+    if flag_rot and angle is None:
+        print("warning: angle is None, but flag_rot is True")
+
+ img_crop = _transform_img(
+ img, M_o2c, dsize=dsize, borderMode=kwargs.get("borderMode", None)
+ )
+ lmk_crop = _transform_pts(lmk, M_o2c) if lmk is not None else None
+
+ M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)])
+ M_c2o = np.linalg.inv(M_o2c)
+
+ return {
+ "img_crop": img_crop,
+ "lmk_crop": lmk_crop,
+ "M_o2c": M_o2c,
+ "M_c2o": M_c2o,
+ }
+
+
+def _estimate_similar_transform_from_pts(
+ pts, dsize, scale=1.5, vx_ratio=0, vy_ratio=-0.1, flag_do_rot=True, **kwargs
+):
+    """calculate the affine matrix mapping the original image to the cropped image from sparse points; its inverse maps the cropped image back to the original
+    pts: landmarks, 101 or 68 points or other layouts, Nx2
+    scale: the larger the scale factor, the smaller the face ratio
+    vx_ratio: x shift
+    vy_ratio: y shift; the smaller the value, the lower the face region
+    flag_do_rot: if true, apply roll correction
+    """
+ center, size, angle = parse_rect_from_landmark(
+ pts,
+ scale=scale,
+ vx_ratio=vx_ratio,
+ vy_ratio=vy_ratio,
+ use_lip=kwargs.get("use_lip", True),
+ )
+
+ s = dsize / size[0] # scale
+ tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE) # center of dsize
+
+ if flag_do_rot:
+ costheta, sintheta = cos(angle), sin(angle)
+ cx, cy = center[0], center[1] # ori center
+ tcx, tcy = tgt_center[0], tgt_center[1] # target center
+        # similarity transform: scale and rotate about the center, then translate to the target center
+ M_INV = np.array(
+ [
+ [s * costheta, s * sintheta, tcx - s * (costheta * cx + sintheta * cy)],
+ [
+ -s * sintheta,
+ s * costheta,
+ tcy - s * (-sintheta * cx + costheta * cy),
+ ],
+ ],
+ dtype=DTYPE,
+ )
+ else:
+ M_INV = np.array(
+ [
+ [s, 0, tgt_center[0] - s * center[0]],
+ [0, s, tgt_center[1] - s * center[1]],
+ ],
+ dtype=DTYPE,
+ )
+
+ M_INV_H = np.vstack([M_INV, np.array([0, 0, 1])])
+ M = np.linalg.inv(M_INV_H)
+
+ # M_INV is from the original image to the cropped image, M is from the cropped image to the original image
+ return M_INV, M[:2, ...]
+
+
+def crop_image(img, pts: np.ndarray, **kwargs):
+ dsize = kwargs.get("dsize", 224)
+ scale = kwargs.get("scale", 1.5) # 1.5 | 1.6
+ vy_ratio = kwargs.get("vy_ratio", -0.1) # -0.0625 | -0.1
+
+ pt_crop_flag = kwargs.get("pt_crop_flag", True)
+
+ M_INV, _ = _estimate_similar_transform_from_pts(
+ pts,
+ dsize=dsize,
+ scale=scale,
+ vy_ratio=vy_ratio,
+ flag_do_rot=kwargs.get("flag_do_rot", True),
+ )
+
+ img_crop = _transform_img(img, M_INV, dsize) # origin to crop
+ if pt_crop_flag:
+ pt_crop = _transform_pts(pts, M_INV)
+ else:
+ pt_crop = None
+
+ M_o2c = np.vstack([M_INV, np.array([0, 0, 1], dtype=DTYPE)])
+ M_c2o = np.linalg.inv(M_o2c)
+
+ ret_dct = {
+ "M_o2c": M_o2c, # from the original image to the cropped image 3x3
+ "M_c2o": M_c2o, # from the cropped image to the original image 3x3
+ "img_crop": img_crop, # the cropped image
+ "pt_crop": pt_crop, # the landmarks of the cropped image
+ }
+
+ return ret_dct
+
+
+def average_bbox_lst(bbox_lst):
+ if len(bbox_lst) == 0:
+ return None
+ bbox_arr = np.array(bbox_lst)
+ return np.mean(bbox_arr, axis=0).tolist()
+
+
+def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
+ """prepare mask for later image paste back"""
+ mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
+ mask_ori = mask_ori.astype(np.float32) / 255.0
+ return mask_ori
+
+
+def paste_back(img_crop, M_c2o, img_ori, mask_ori):
+ """paste back the image"""
+ dsize = (img_ori.shape[1], img_ori.shape[0])
+ result = _transform_img(img_crop, M_c2o, dsize=dsize)
+ result = np.clip(mask_ori * result + (1 - mask_ori) * img_ori, 0, 255).astype(
+ np.uint8
+ )
+ return result
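+
+
+# Usage sketch of the crop/paste round trip (variable names are placeholders;
+# the scale/vy_ratio values mirror the pipeline defaults used elsewhere):
+#   ret = crop_image(img, lmk, dsize=512, scale=2.3, vy_ratio=-0.125)
+#   img_crop, M_c2o = ret["img_crop"], ret["M_c2o"]
+#   mask_ori = prepare_paste_back(mask_crop, M_c2o, (img.shape[1], img.shape[0]))
+#   out = paste_back(edited_crop, M_c2o, img, mask_ori)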
diff --git a/core/utils/eye_info.py b/core/utils/eye_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..01fdc1e1d791e0777711642cc415d8c6bd90fdfe
--- /dev/null
+++ b/core/utils/eye_info.py
@@ -0,0 +1,111 @@
+import numpy as np
+from dataclasses import dataclass
+
+
+@dataclass
+class EyeIdxMP:
+ LO = [33]
+ LI = [133]
+ LD = [7, 163, 144, 145, 153, 154, 155] # O -> I
+ LU = [246, 161, 160, 159, 158, 157, 173] # O -> I
+ RO = [263]
+ RI = [362]
+ RD = [249, 390, 373, 374, 380, 381, 382] # O -> I
+ RU = [466, 388, 387, 386, 385, 384, 398] # O -> I
+
+ LW = [33, 133] # oi
+ LH0 = [145, 159]
+ LH1 = [144, 160]
+ LH2 = [153, 158]
+
+ RW = [263, 362] # oi
+ RH0 = [374, 386]
+ RH1 = [373, 387]
+ RH2 = [380, 385]
+
+ LB = [468] # eye ball
+ RB = [473]
+
+
+class EyeAttrUtilsByMP:
+ def __init__(self, lmks_mp):
+ self.IDX = EyeIdxMP()
+ self.lmks = lmks_mp # [n, 478, 3]
+
+ self.L_width = self._dist_idx(*self.IDX.LW) # [n]
+ self.R_width = self._dist_idx(*self.IDX.RW)
+
+ self.L_h0 = self._dist_idx(*self.IDX.LH0)
+ self.L_h1 = self._dist_idx(*self.IDX.LH1)
+ self.L_h2 = self._dist_idx(*self.IDX.LH2)
+
+ self.R_h0 = self._dist_idx(*self.IDX.RH0)
+ self.R_h1 = self._dist_idx(*self.IDX.RH1)
+ self.R_h2 = self._dist_idx(*self.IDX.RH2)
+
+ self.L_open = (self.L_h0 + self.L_h1 + self.L_h2) / (self.L_width + 1e-8) # [n]
+ self.R_open = (self.R_h0 + self.R_h1 + self.R_h2) / (self.R_width + 1e-8)
+
+ self.L_center = self._center_idx(*self.IDX.LW) # [n, 3/2]
+ self.R_center = self._center_idx(*self.IDX.RW)
+
+ self.L_ball = self.lmks[:, self.IDX.LB[0]] # [n, 3/2]
+ self.R_ball = self.lmks[:, self.IDX.RB[0]]
+
+ self.L_ball_direc = (self.L_ball - self.L_center) / (self.L_width[:, None] + 1e-8) # [n, 3/2]
+ self.R_ball_direc = (self.R_ball - self.R_center) / (self.R_width[:, None] + 1e-8)
+
+ self.L_eye_direc = self._direc_idx(*self.IDX.LW) # I->O
+ self.R_eye_direc = self._direc_idx(*self.IDX.RW)
+
+ self.L_ball_move_dist = self._dist(self.L_ball, self.L_center)
+ self.R_ball_move_dist = self._dist(self.R_ball, self.R_center)
+
+ self.L_ball_move_direc = self._direc(self.L_ball, self.L_center) - self.L_eye_direc
+ self.R_ball_move_direc = self._direc(self.R_ball, self.R_center) - self.R_eye_direc
+
+ self.L_ball_move = self.L_ball_move_direc * self.L_ball_move_dist[:, None]
+ self.R_ball_move = self.R_ball_move_direc * self.R_ball_move_dist[:, None]
+
+ def LR_open(self):
+ LR_open = np.stack([self.L_open, self.R_open], -1) # [n, 2]
+ return LR_open
+
+ def LR_ball_direc(self):
+ LR_ball_direc = np.stack([self.L_ball_direc, self.R_ball_direc], -1) # [n, 3, 2]
+ return LR_ball_direc
+
+ def LR_ball_move(self):
+ LR_ball_move = np.stack([self.L_ball_move, self.R_ball_move], -1)
+ return LR_ball_move
+
+ @staticmethod
+ def _dist(p1, p2):
+ # p1/p2: [n, 3/2]
+ return (((p1 - p2) ** 2).sum(-1)) ** 0.5 # [n]
+
+ @staticmethod
+ def _center(p1, p2):
+ return (p1 + p2) * 0.5 # [n, 3/2]
+
+ def _direc(self, p1, p2):
+ # p1 - p2, (2->1)
+ return (p1 - p2) / (self._dist(p1, p2)[:, None] + 1e-8)
+
+ def _dist_idx(self, idx1, idx2):
+ p1 = self.lmks[:, idx1]
+ p2 = self.lmks[:, idx2]
+ d = self._dist(p1, p2)
+ return d
+
+ def _center_idx(self, idx1, idx2):
+ p1 = self.lmks[:, idx1]
+ p2 = self.lmks[:, idx2]
+ c = self._center(p1, p2)
+ return c
+
+    def _direc_idx(self, idx1, idx2):
+        p1 = self.lmks[:, idx1]
+        p2 = self.lmks[:, idx2]
+        return self._direc(p1, p2)
\ No newline at end of file
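+
+# Usage sketch: given MediaPipe FaceMesh landmarks of shape [n, 478, 3],
+#   attrs = EyeAttrUtilsByMP(lmks)
+#   openness = attrs.LR_open()        # [n, 2] eye-openness ratios (left, right)
+#   ball_move = attrs.LR_ball_move()  # per-frame eyeball offsets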
diff --git a/core/utils/get_mask.py b/core/utils/get_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cb83658455011bbeb6e93b3461ef5b777946a52
--- /dev/null
+++ b/core/utils/get_mask.py
@@ -0,0 +1,67 @@
+import numpy as np
+
+
+def get_mask(W, H, ratio_w=0.9, ratio_h=0.9):
+ w = int(W * ratio_w)
+ h = int(H * ratio_h)
+
+ x1 = (W - w) // 2
+ x2 = x1 + w
+
+ y1 = (H - h) // 2
+ y2 = y1 + h
+
+ mask = np.ones((H, W), dtype=np.float32)
+
+ # top
+ row = np.linspace(0, 0, w)[None, :] # (1, w)
+ col = np.linspace(0, 1, y1)[:, None] # (y1, 1)
+ grad_t = np.sqrt(row**2 + col**2).astype(np.float32) # (y1, w)
+ mask[0:y1, x1:x2] = grad_t
+
+ # bottom
+ row = np.linspace(0, 0, w)[None, :] # (1, w)
+ col = np.linspace(1, 0, H - y2)[:, None] # (H-y2, 1)
+ grad_b = np.sqrt(row**2 + col**2).astype(np.float32) # (H-y2, w)
+ mask[y2:H, x1:x2] = grad_b
+
+ # left
+ row = np.linspace(0, 1, x1)[None, :] # (1, x1)
+ col = np.linspace(0, 0, h)[:, None] # (h, 1)
+ grad_l = np.sqrt(row**2 + col**2).astype(np.float32) # (h, x1)
+ mask[y1:y2, 0:x1] = grad_l
+
+ # right
+ row = np.linspace(1, 0, W - x2)[None, :] # (1, W-x2)
+ col = np.linspace(0, 0, h)[:, None] # (h, 1)
+ grad_r = np.sqrt(row**2 + col**2).astype(np.float32) # (h, W-x2)
+ mask[y1:y2, x2:W] = grad_r
+
+ # top left
+ row = np.linspace(1, 0, x1)[None, :] # (1, w)
+ col = np.linspace(1, 0, y1)[:, None] # (y1, 1)
+ grad_tl = np.sqrt(row**2 + col**2).astype(np.float32) # (y1, x1)
+ grad_tl = 1 - np.clip(grad_tl, 0, 1)
+ mask[0:y1, 0:x1] = grad_tl
+
+ # top right
+ row = np.linspace(0, 1, W - x2)[None, :] # (1, W-x2)
+ col = np.linspace(1, 0, y1)[:, None] # (y1, 1)
+ grad_tr = np.sqrt(row**2 + col**2).astype(np.float32) # (y1, W-x2)
+ grad_tr = 1 - np.clip(grad_tr, 0, 1)
+ mask[0:y1, x2:W] = grad_tr
+
+ # bottom left
+ row = np.linspace(1, 0, x1)[None, :] # (1, x1)
+ col = np.linspace(0, 1, H - y2)[:, None] # (H-y2, 1)
+ grad_bl = np.sqrt(row**2 + col**2).astype(np.float32) # (H-y2, x1)
+ grad_bl = 1 - np.clip(grad_bl, 0, 1)
+ mask[y2:H, 0:x1] = grad_bl
+
+ # bottom right
+ row = np.linspace(0, 1, W - x2)[None, :] # (1, W-x2)
+ col = np.linspace(0, 1, H - y2)[:, None] # (H-y2, 1)
+ grad_br = np.sqrt(row**2 + col**2).astype(np.float32) # (H-y2, W-x2)
+ grad_br = 1 - np.clip(grad_br, 0, 1)
+ mask[y2:H, x2:W] = grad_br
+ return mask[:, :, None]
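+
+
+# Illustration: get_mask(512, 512) yields a (512, 512, 1) float32 map that is
+# 1.0 inside the central 90% box and ramps linearly to 0.0 at the image border
+# (with radial falloff in the corners), i.e. a feathering weight for blending
+# a generated region back into a frame.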
diff --git a/core/utils/load_model.py b/core/utils/load_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b2fda8e9405117758f95faadbfff2be3407949f
--- /dev/null
+++ b/core/utils/load_model.py
@@ -0,0 +1,64 @@
+def load_model(model_path: str, device: str = "cuda", **kwargs):
+ if kwargs.get("force_ori_type", False):
+ # for hubert, landmark, retinaface, mediapipe
+ model = load_force_ori_type(model_path, device, **kwargs)
+ return model, "ori"
+
+ if model_path.endswith(".onnx"):
+ # onnx
+ import onnxruntime
+
+ if device == "cuda":
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ else:
+ providers = ["CPUExecutionProvider"]
+ model = onnxruntime.InferenceSession(model_path, providers=providers)
+ return model, "onnx"
+
+ elif model_path.endswith(".engine") or model_path.endswith(".trt"):
+ # tensorRT
+ from .tensorrt_utils import TRTWrapper
+
+ model = TRTWrapper(model_path)
+ return model, "tensorrt"
+
+ elif model_path.endswith(".pt") or model_path.endswith(".pth"):
+ # pytorch
+ model = create_model(model_path, device, **kwargs)
+ return model, "pytorch"
+
+ else:
+ raise ValueError(f"Unsupported model file type: {model_path}")
+
+
+def create_model(
+ model_path: str,
+ device: str = "cuda",
+ module_name="",
+ package_name="..models.modules",
+ **kwargs,
+):
+ import importlib
+
+    module = getattr(importlib.import_module(package_name, __package__), module_name)
+
+ model = module(**kwargs)
+ model.load_model(model_path).to(device)
+ return model
+
+
+def load_force_ori_type(
+ model_path: str,
+ device: str = "cuda",
+ module_name="",
+ package_name="..aux_models.modules",
+ force_ori_type=False,
+ **kwargs,
+):
+ import importlib
+
+ module = getattr(importlib.import_module(package_name, __package__), module_name)
+ model = module(**kwargs)
+ return model
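+
+
+# Usage sketch: dispatch is purely by file extension (paths are placeholders):
+#   model, mtype = load_model("warp_network_fp16.engine")         # TRTWrapper, "tensorrt"
+#   model, mtype = load_model("warp_network.onnx", device="cpu")  # onnxruntime session, "onnx"
+#   model, mtype = load_model("warp_network.pth", module_name="WarpingNetwork")  # nn.Module, "pytorch"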
diff --git a/core/utils/tensorrt_utils.py b/core/utils/tensorrt_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff233b9c54ac81e25722e36afe0b8a70def27844
--- /dev/null
+++ b/core/utils/tensorrt_utils.py
@@ -0,0 +1,273 @@
+import ctypes
+import os
+from collections import OrderedDict
+
+from cuda import cuda, cudart, nvrtc
+import numpy as np
+import torch
+
+try:
+ import tensorrt as trt
+except ImportError:
+ import tensorrt_libs
+
+ trt_libs_path = tensorrt_libs.__path__[0]
+ ctypes.CDLL(os.path.join(trt_libs_path, "libnvinfer.so.8"))
+ ctypes.CDLL(os.path.join(trt_libs_path, "libnvinfer_plugin.so.8"))
+ ctypes.CDLL(os.path.join(trt_libs_path, "libnvonnxparser.so.8"))
+ ctypes.CDLL(os.path.join(trt_libs_path, "libnvparsers.so.8"))
+ ctypes.CDLL(os.path.join(trt_libs_path, "libnvinfer_builder_resource.so.8.6.1"))
+ import tensorrt as trt
+
+logger = trt.Logger(trt.Logger.ERROR)
+trt.init_libnvinfer_plugins(logger, "")
+
+
+def _cudaGetErrorEnum(error):
+ if isinstance(error, cuda.CUresult):
+ err, name = cuda.cuGetErrorName(error)
+ return name if err == cuda.CUresult.CUDA_SUCCESS else ""
+ elif isinstance(error, cudart.cudaError_t):
+ return cudart.cudaGetErrorName(error)[1]
+ elif isinstance(error, nvrtc.nvrtcResult):
+ return nvrtc.nvrtcGetErrorString(error)[1]
+ else:
+ raise RuntimeError("Unknown error type: {}".format(error))
+
+
+def checkCudaErrors(result):
+ if result[0].value:
+ raise RuntimeError(
+ "CUDA error code={}({})".format(
+ result[0].value, _cudaGetErrorEnum(result[0])
+ )
+ )
+ if len(result) == 1:
+ return None
+ elif len(result) == 2:
+ return result[1]
+ else:
+ return result[1:]
+
+
+class MyOutputAllocator(trt.IOutputAllocator):
+ def __init__(self) -> None:
+ super().__init__()
+ # members for outside use
+ self.shape = None
+ self.n_bytes = 0
+ self.address = 0
+
+ def reallocate_output(self, tensor_name, old_address, size, alignment) -> int:
+ return self.reallocate_common(tensor_name, old_address, size, alignment)
+
+ def reallocate_output_async(
+ self, tensor_name, old_address, size, alignment, stream
+ ) -> int:
+ return self.reallocate_common(tensor_name, old_address, size, alignment, stream)
+
+ def notify_shape(self, tensor_name, shape):
+ self.shape = shape
+ return
+
+ def reallocate_common(
+ self, tensor_name, old_address, size, alignment, stream=-1
+    ):  # shared helper, not part of the trt.IOutputAllocator API
+ if size <= self.n_bytes:
+ return old_address
+ if old_address != 0:
+ checkCudaErrors(cudart.cudaFree(old_address))
+ if stream == -1:
+ address = checkCudaErrors(cudart.cudaMalloc(size))
+ else:
+ address = checkCudaErrors(cudart.cudaMallocAsync(size, stream))
+ self.n_bytes = size
+ self.address = address
+ return address
+
+
+class TRTWrapper:
+ def __init__(
+ self,
+ trt_file: str,
+ plugin_file_list: list = [],
+ ) -> None:
+ # Load custom plugins
+ for plugin_file in plugin_file_list:
+ ctypes.cdll.LoadLibrary(plugin_file)
+
+ # Load engine bytes from file
+ self.model = trt_file
+ with open(trt_file, "rb") as f, trt.Runtime(logger) as runtime:
+ assert runtime
+ self.engine = runtime.deserialize_cuda_engine(f.read())
+ assert self.engine
+ self.buffer = OrderedDict()
+ self.output_allocator_map = OrderedDict()
+ self.context = self.engine.create_execution_context()
+ return
+
+ def setup(self, input_data: dict = {}) -> None:
+ for name, value in self.buffer.items():
+ _, device_buffer, _ = value
+ if (
+ device_buffer is not None
+ and device_buffer != 0
+ and name not in self.output_allocator_map
+ ):
+ checkCudaErrors(cudart.cudaFree(device_buffer))
+ self.buffer[name][1] = None
+ self.buffer[name][2] = 0
+ self.tensor_name_list = [
+ self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)
+ ]
+ self.n_input = sum(
+ [
+ self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
+ for name in self.tensor_name_list
+ ]
+ )
+ self.n_output = self.engine.num_io_tensors - self.n_input
+
+ for name, data in input_data.items():
+ if self.engine.get_tensor_location(name) == trt.TensorLocation.DEVICE:
+ self.context.set_input_shape(name, data.shape)
+ else:
+ self.context.set_tensor_address(name, data.ctypes.data)
+
+ # Prepare work before inference
+ for name in self.tensor_name_list:
+ data_type = self.engine.get_tensor_dtype(name)
+ runtime_shape = self.context.get_tensor_shape(name)
+ if name not in self.output_allocator_map:
+ if -1 in runtime_shape:
+ # for Data-Dependent-Shape (DDS) output, "else" branch for normal output
+ n_byte = 0 # self.context.get_max_output_size(name)
+ self.output_allocator_map[name] = MyOutputAllocator()
+ self.context.set_output_allocator(
+ name, self.output_allocator_map[name]
+ )
+ host_buffer = np.empty(0, dtype=trt.nptype(data_type))
+ device_buffer = None
+ else:
+ n_byte = trt.volume(runtime_shape) * data_type.itemsize
+ host_buffer = np.empty(runtime_shape, dtype=trt.nptype(data_type))
+ if (
+ self.engine.get_tensor_location(name)
+ == trt.TensorLocation.DEVICE
+ ):
+ device_buffer = checkCudaErrors(cudart.cudaMalloc(n_byte))
+ else:
+ device_buffer = None
+ self.buffer[name] = [host_buffer, device_buffer, n_byte]
+ else:
+ # for DDS output, don't need to reallocate
+ pass
+
+ for name, data in input_data.items():
+ self.buffer[name][0] = np.ascontiguousarray(data)
+
+ for name in self.tensor_name_list:
+ if self.engine.get_tensor_location(name) == trt.TensorLocation.DEVICE:
+ if self.buffer[name][1] is not None:
+ self.context.set_tensor_address(name, self.buffer[name][1])
+ elif self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
+ self.context.set_tensor_address(name, self.buffer[name][0].ctypes.data)
+
+ return
+
+ def infer(self, stream=0) -> None:
+ # Do inference and print output
+ for name in self.tensor_name_list:
+ if (
+ self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
+ and self.engine.get_tensor_location(name) == trt.TensorLocation.DEVICE
+ ):
+ cudart.cudaMemcpy(
+ self.buffer[name][1],
+ self.buffer[name][0].ctypes.data,
+ self.buffer[name][2],
+ cudart.cudaMemcpyKind.cudaMemcpyHostToDevice,
+ )
+
+ self.context.execute_async_v3(stream)
+
+ for name in self.output_allocator_map:
+ myOutputAllocator = self.context.get_output_allocator(name)
+ runtime_shape = myOutputAllocator.shape
+ data_type = self.engine.get_tensor_dtype(name)
+ host_buffer = np.empty(runtime_shape, dtype=trt.nptype(data_type))
+ device_buffer = myOutputAllocator.address
+ n_bytes = trt.volume(runtime_shape) * data_type.itemsize
+ self.buffer[name] = [host_buffer, device_buffer, n_bytes]
+
+ for name in self.tensor_name_list:
+ if (
+ self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT
+ and self.engine.get_tensor_location(name) == trt.TensorLocation.DEVICE
+ ):
+ cudart.cudaMemcpy(
+ self.buffer[name][0].ctypes.data,
+ self.buffer[name][1],
+ self.buffer[name][2],
+ cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
+ )
+
+ return
+
+ def infer_async(self, stream=0) -> None:
+ # Do inference and print output
+ for name in self.tensor_name_list:
+ if (
+ self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
+ and self.engine.get_tensor_location(name) == trt.TensorLocation.DEVICE
+ ):
+ cudart.cudaMemcpyAsync(
+ self.buffer[name][1],
+ self.buffer[name][0].ctypes.data,
+ self.buffer[name][2],
+ cudart.cudaMemcpyKind.cudaMemcpyHostToDevice,
+ stream=stream,
+ )
+
+ self.context.execute_async_v3(stream)
+
+ for name in self.output_allocator_map:
+ myOutputAllocator = self.context.get_output_allocator(name)
+ runtime_shape = myOutputAllocator.shape
+ data_type = self.engine.get_tensor_dtype(name)
+ host_buffer = np.empty(runtime_shape, dtype=trt.nptype(data_type))
+ device_buffer = myOutputAllocator.address
+ n_bytes = trt.volume(runtime_shape) * data_type.itemsize
+ self.buffer[name] = [host_buffer, device_buffer, n_bytes]
+
+ for name in self.tensor_name_list:
+ if (
+ self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT
+ and self.engine.get_tensor_location(name) == trt.TensorLocation.DEVICE
+ ):
+ cudart.cudaMemcpyAsync(
+ self.buffer[name][0].ctypes.data,
+ self.buffer[name][1],
+ self.buffer[name][2],
+ cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
+ stream=stream,
+ )
+
+ return
+
+ def __del__(self):
+ if hasattr(self, "buffer") and self.buffer is not None:
+ for _, device_buffer, _ in self.buffer.values():
+ if (
+ device_buffer is not None
+ and device_buffer != 0
+ and cudart is not None
+ ):
+ try:
+ checkCudaErrors(cudart.cudaFree(device_buffer))
+ except TypeError:
+ pass
+ return
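+
+
+# Usage sketch mirroring how the model wrappers above drive this class
+# (engine path and tensor names are placeholders):
+#   trt_model = TRTWrapper("model.engine")
+#   trt_model.setup({"input": np_input})     # bind shapes and (re)allocate buffers
+#   trt_model.infer()                        # H2D copies, execute, D2H copies
+#   out = trt_model.buffer["out"][0].copy()  # host buffer of output tensor "out"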
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bae39453d68edbb7af3f88de11bb3b89220d51e5
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,142 @@
+name: ditto
+channels:
+ - pytorch
+ - nvidia
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - blas=1.0=mkl
+ - brotli-python=1.0.9=py310h6a678d5_9
+ - bzip2=1.0.8=h5eee18b_6
+ - ca-certificates=2024.11.26=h06a4308_0
+ - certifi=2024.12.14=py310h06a4308_0
+ - charset-normalizer=3.3.2=pyhd3eb1b0_0
+ - cuda-cudart=12.1.105=0
+ - cuda-cupti=12.1.105=0
+ - cuda-libraries=12.1.0=0
+ - cuda-nvrtc=12.1.105=0
+ - cuda-nvtx=12.1.105=0
+ - cuda-opencl=12.6.77=0
+ - cuda-runtime=12.1.0=0
+ - cuda-version=12.6=3
+ - ffmpeg=4.3=hf484d3e_0
+ - filelock=3.13.1=py310h06a4308_0
+ - freetype=2.12.1=h4a9f257_0
+ - giflib=5.2.2=h5eee18b_0
+ - gmp=6.2.1=h295c915_3
+ - gmpy2=2.1.2=py310heeb90bb_0
+ - gnutls=3.6.15=he1e5248_0
+ - idna=3.7=py310h06a4308_0
+ - intel-openmp=2023.1.0=hdb19cb5_46306
+ - jinja2=3.1.4=py310h06a4308_1
+ - jpeg=9e=h5eee18b_3
+ - lame=3.100=h7b6447c_0
+ - lcms2=2.16=hb9589c4_0
+ - ld_impl_linux-64=2.40=h12ee557_0
+ - lerc=4.0.0=h6a678d5_0
+ - libcublas=12.1.0.26=0
+ - libcufft=11.0.2.4=0
+ - libcufile=1.11.1.6=0
+ - libcurand=10.3.7.77=0
+ - libcusolver=11.4.4.55=0
+ - libcusparse=12.0.2.55=0
+ - libdeflate=1.22=h5eee18b_0
+ - libffi=3.4.4=h6a678d5_1
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libiconv=1.16=h5eee18b_3
+ - libidn2=2.3.4=h5eee18b_0
+ - libjpeg-turbo=2.0.0=h9bf148f_0
+ - libnpp=12.0.2.50=0
+ - libnvjitlink=12.1.105=0
+ - libnvjpeg=12.1.1.14=0
+ - libpng=1.6.39=h5eee18b_0
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libtasn1=4.19.0=h5eee18b_0
+ - libtiff=4.5.1=hffd6297_1
+ - libunistring=0.9.10=h27cfd23_0
+ - libuuid=1.41.5=h5eee18b_0
+ - libwebp=1.3.2=h11a3e52_0
+ - libwebp-base=1.3.2=h5eee18b_1
+ - llvm-openmp=14.0.6=h9e868ea_0
+ - lz4-c=1.9.4=h6a678d5_1
+ - markupsafe=2.1.3=py310h5eee18b_0
+ - mkl=2023.1.0=h213fc3f_46344
+ - mkl-service=2.4.0=py310h5eee18b_1
+ - mkl_fft=1.3.11=py310h5eee18b_0
+ - mkl_random=1.2.8=py310h1128e8f_0
+ - mpc=1.1.0=h10f8cd9_1
+ - mpfr=4.0.2=hb69a4c5_1
+ - mpmath=1.3.0=py310h06a4308_0
+ - ncurses=6.4=h6a678d5_0
+ - nettle=3.7.3=hbbd107a_1
+ - networkx=3.2.1=py310h06a4308_0
+ - numpy=2.0.1=py310h5f9d8c6_1
+ - numpy-base=2.0.1=py310hb5e798b_1
+ - openh264=2.1.1=h4ff587b_0
+ - openjpeg=2.5.2=he7f1fd0_0
+ - openssl=3.0.15=h5eee18b_0
+ - pillow=11.0.0=py310hcea889d_1
+ - pip=24.2=py310h06a4308_0
+ - pysocks=1.7.1=py310h06a4308_0
+ - python=3.10.16=he870216_1
+ - pytorch=2.5.1=py3.10_cuda12.1_cudnn9.1.0_0
+ - pytorch-cuda=12.1=ha16c6d3_6
+ - pytorch-mutex=1.0=cuda
+ - pyyaml=6.0.2=py310h5eee18b_0
+ - readline=8.2=h5eee18b_0
+ - requests=2.32.3=py310h06a4308_1
+ - setuptools=75.1.0=py310h06a4308_0
+ - sqlite=3.45.3=h5eee18b_0
+ - sympy=1.13.3=py310h06a4308_0
+ - tbb=2021.8.0=hdb19cb5_0
+ - tk=8.6.14=h39e8969_0
+ - torchaudio=2.5.1=py310_cu121
+ - torchtriton=3.1.0=py310
+ - torchvision=0.20.1=py310_cu121
+ - typing_extensions=4.12.2=py310h06a4308_0
+ - tzdata=2024b=h04d1e81_0
+ - urllib3=2.2.3=py310h06a4308_0
+ - wheel=0.44.0=py310h06a4308_0
+ - xz=5.4.6=h5eee18b_1
+ - yaml=0.2.5=h7b6447c_0
+ - zlib=1.2.13=h5eee18b_1
+ - zstd=1.5.6=hc292b87_0
+ - pip:
+ - audioread==3.0.1
+ - cffi==1.17.1
+ - cuda-python==12.6.2.post1
+ - cython==3.0.11
+ - decorator==5.1.1
+ - filetype==1.2.0
+ - imageio==2.36.1
+ - imageio-ffmpeg==0.5.1
+ - joblib==1.4.2
+ - lazy-loader==0.4
+ - librosa==0.10.2.post1
+ - llvmlite==0.43.0
+ - msgpack==1.1.0
+ - numba==0.60.0
+ - nvidia-cublas-cu12==12.6.4.1
+ - nvidia-cuda-runtime-cu12==12.6.77
+ - nvidia-cudnn-cu12==9.6.0.74
+ - opencv-python-headless==4.10.0.84
+ - packaging==24.2
+ - platformdirs==4.3.6
+ - pooch==1.8.2
+ - pycparser==2.22
+ - scikit-image==0.25.0
+ - scikit-learn==1.6.0
+ - scipy==1.15.0
+ - soundfile==0.13.0
+ - soxr==0.5.0.post1
+ - tensorrt==8.6.1
+ - tensorrt-bindings==8.6.1
+ - tensorrt-libs==8.6.1
+ - threadpoolctl==3.5.0
+ - tifffile==2024.12.12
+ - tqdm==4.67.1
+ - polygraphy
+ - colored
+prefix: /opt/conda/envs/ditto
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..155109ce6ab260eb1fa04ea4459b9107c19e0f9f
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,89 @@
+import librosa
+import math
+import os
+import numpy as np
+import random
+import torch
+import pickle
+
+from stream_pipeline_offline import StreamSDK
+
+
+def seed_everything(seed):
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ os.environ["PL_GLOBAL_SEED"] = str(seed)
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def load_pkl(pkl):
+ with open(pkl, "rb") as f:
+ return pickle.load(f)
+
+
+def run(SDK: StreamSDK, audio_path: str, source_path: str, output_path: str, more_kwargs: str | dict = {}):
+
+ if isinstance(more_kwargs, str):
+ more_kwargs = load_pkl(more_kwargs)
+ setup_kwargs = more_kwargs.get("setup_kwargs", {})
+ run_kwargs = more_kwargs.get("run_kwargs", {})
+
+ SDK.setup(source_path, output_path, **setup_kwargs)
+
+ audio, sr = librosa.core.load(audio_path, sr=16000)
+ num_f = math.ceil(len(audio) / 16000 * 25)
+
+ fade_in = run_kwargs.get("fade_in", -1)
+ fade_out = run_kwargs.get("fade_out", -1)
+ ctrl_info = run_kwargs.get("ctrl_info", {})
+ SDK.setup_Nd(N_d=num_f, fade_in=fade_in, fade_out=fade_out, ctrl_info=ctrl_info)
+
+ online_mode = SDK.online_mode
+ if online_mode:
+        chunksize = run_kwargs.get("chunksize", (3, 5, 2))
+        # 640 samples = one 25-fps frame at 16 kHz; prepend chunksize[0] frames of silence
+        audio = np.concatenate([np.zeros((chunksize[0] * 640,), dtype=np.float32), audio], 0)
+        # each chunk covers sum(chunksize) frames (0.04 s each) plus an 80-sample margin
+        split_len = int(sum(chunksize) * 0.04 * 16000) + 80  # 6480
+ for i in range(0, len(audio), chunksize[1] * 640):
+ audio_chunk = audio[i:i + split_len]
+ if len(audio_chunk) < split_len:
+ audio_chunk = np.pad(audio_chunk, (0, split_len - len(audio_chunk)), mode="constant")
+ SDK.run_chunk(audio_chunk, chunksize)
+ else:
+ aud_feat = SDK.wav2feat.wav2feat(audio)
+ SDK.audio2motion_queue.put(aud_feat)
+ SDK.close()
+
+ cmd = f'ffmpeg -loglevel error -y -i "{SDK.tmp_output_path}" -i "{audio_path}" -map 0:v -map 1:a -c:v copy -c:a aac "{output_path}"'
+ print(cmd)
+ os.system(cmd)
+
+ print(output_path)
+
+
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data_root", type=str, default="./checkpoints/ditto_trt_Ampere_Plus", help="path to trt data_root")
+ parser.add_argument("--cfg_pkl", type=str, default="./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl", help="path to cfg_pkl")
+
+ parser.add_argument("--audio_path", type=str, help="path to input wav")
+ parser.add_argument("--source_path", type=str, help="path to input image")
+ parser.add_argument("--output_path", type=str, help="path to output mp4")
+ args = parser.parse_args()
+
+ # init sdk
+ data_root = args.data_root # model dir
+ cfg_pkl = args.cfg_pkl # cfg pkl
+ SDK = StreamSDK(cfg_pkl, data_root)
+
+ # input args
+ audio_path = args.audio_path # .wav
+ source_path = args.source_path # video|image
+ output_path = args.output_path # .mp4
+
+ # run
+ # seed_everything(1024)
+ run(SDK, audio_path, source_path, output_path)
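+
+
+# Example invocation (input/output paths are placeholders; --data_root and
+# --cfg_pkl fall back to the argparse defaults above):
+#   python inference.py \
+#       --data_root ./checkpoints/ditto_trt_Ampere_Plus \
+#       --cfg_pkl ./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl \
+#       --audio_path input.wav --source_path source.png --output_path result.mp4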
diff --git a/scripts/cvt_onnx_to_trt.py b/scripts/cvt_onnx_to_trt.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e7b1c0e2a268af31188bf6956d2e07bcdee736a
--- /dev/null
+++ b/scripts/cvt_onnx_to_trt.py
@@ -0,0 +1,164 @@
+import os
+import torch
+import argparse
+
+
+def onnx_to_trt(onnx_file, trt_file, fp16=False, more_cmd=None):
+ cap = torch.cuda.get_device_capability()
+ if cap[0] >= 8:
+        compatible = "--hardware-compatibility-level=Ampere_Plus"
+    else:
+        compatible = ""
+ cmd = [
+ "polygraphy",
+ "convert",
+ onnx_file,
+ "-o",
+ trt_file,
+        compatible,
+        "--fp16" if fp16 else "",
+        "--builder-optimization-level=5",
+ ]
+ if more_cmd:
+ cmd = cmd + more_cmd
+ print(" ".join(cmd))
+ os.system(" ".join(cmd))
+
+
+def onnx_to_trt_for_gridsample(onnx_file, trt_file, fp16=False, plugin_file="./libgrid_sample_3d_plugin.so"):
+ import tensorrt as trt
+
+ logger = trt.Logger(trt.Logger.INFO)
+ trt.init_libnvinfer_plugins(logger, "")
+ plugin_libs = [plugin_file]
+
+ onnx_path = onnx_file
+ engine_path = trt_file
+
+ builder = trt.Builder(logger)
+ for pluginlib in plugin_libs:
+ builder.get_plugin_registry().load_library(pluginlib)
+ network = builder.create_network(
+ 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+ )
+
+ parser = trt.OnnxParser(network, logger)
+ res = parser.parse_from_file(onnx_path) # parse from file
+ if not res:
+ print(f"Fail parsing {onnx_path}")
+ for i in range(parser.num_errors): # Get error information
+ error = parser.get_error(i)
+ print(error) # Print error information
+ print(
+ f"{error.code() = }\n{error.file() = }\n{error.func() = }\n{error.line() = }\n{error.local_function_stack_size() = }"
+ )
+ print(
+ f"{error.local_function_stack() = }\n{error.node_name() = }\n{error.node_operator() = }\n{error.node() = }"
+ )
+ parser.clear_errors()
+ config = builder.create_builder_config()
+ # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)
+ config.builder_optimization_level = 5
+    # Set the hardware-compatibility flag; hardware-compatible engines are only supported on Ampere and beyond
+ cap = torch.cuda.get_device_capability()
+ if cap[0] >= 8:
+ compatible = True
+ else:
+ compatible = False
+
+ if compatible:
+ config.hardware_compatibility_level = (
+ trt.HardwareCompatibilityLevel.AMPERE_PLUS
+ )
+
+ if fp16:
+ config.set_flag(trt.BuilderFlag.FP16)
+ config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)
+ config.set_preview_feature(trt.PreviewFeature.PROFILE_SHARING_0806, True)
+ exclude_list = [
+ "SHAPE",
+ "ASSERTION",
+ "SHUFFLE",
+ "IDENTITY",
+ "CONSTANT",
+ "CONCATENATION",
+ "GATHER",
+ "SLICE",
+ "CONDITION",
+ "CONDITIONAL_INPUT",
+ "CONDITIONAL_OUTPUT",
+ "FILL",
+ "NON_ZERO",
+ "ONE_HOT",
+ ]
+ for i in range(0, network.num_layers):
+ layer = network.get_layer(i)
+        # str(layer.type) looks like "LayerType.CONVOLUTION"; [10:] strips the "LayerType." prefix
+        if str(layer.type)[10:] in exclude_list:
+ continue
+ if "GridSample" in layer.name:
+ print(f"set {layer.name} to float32")
+ layer.precision = trt.float32
+ config.plugins_to_serialize = plugin_libs
+ engineString = builder.build_serialized_network(network, config)
+ if engineString is not None:
+ with open(engine_path, "wb") as f:
+ f.write(engineString)
+
+
+def main(onnx_dir, trt_dir, grid_sample_plugin_file=""):
+ names = [i[:-5] for i in os.listdir(onnx_dir) if i.endswith(".onnx")]
+ for name in names:
+ if name == "warp_network_ori":
+ continue
+
+ print("=" * 20, f"{name} start", "=" * 20)
+
+ fp16 = False if name in {"motion_extractor", "hubert", "wavlm"} or name.startswith("lmdm") else True
+
+ more_cmd = None
+ if name == "wavlm":
+ more_cmd = [
+ "--trt-min-shapes audio:[1,1000]",
+ "--trt-max-shapes audio:[1,320080]",
+ "--trt-opt-shapes audio:[1,320080]",
+ ]
+ elif name == "hubert":
+ more_cmd = [
+ "--trt-min-shapes input_values:[1,3240]",
+ "--trt-max-shapes input_values:[1,12960]",
+ "--trt-opt-shapes input_values:[1,6480]",
+ ]
+
+
+ onnx_file = f"{onnx_dir}/{name}.onnx"
+ trt_file = f"{trt_dir}/{name}_fp{16 if fp16 else 32}.engine"
+
+ if os.path.isfile(trt_file):
+ print("=" * 20, f"{name} skip", "=" * 20)
+ continue
+
+ if name == "warp_network":
+ onnx_to_trt_for_gridsample(onnx_file, trt_file, fp16, plugin_file=grid_sample_plugin_file)
+ else:
+ onnx_to_trt(onnx_file, trt_file, fp16, more_cmd=more_cmd)
+
+ print("=" * 20, f"{name} done", "=" * 20)
+
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--onnx_dir", type=str, help="input onnx dir")
+ parser.add_argument("--trt_dir", type=str, help="output trt dir")
+ args = parser.parse_args()
+
+ onnx_dir = args.onnx_dir
+ trt_dir = args.trt_dir
+
+ assert os.path.isdir(onnx_dir)
+ os.makedirs(trt_dir, exist_ok=True)
+
+ grid_sample_plugin_file = os.path.join(onnx_dir, "libgrid_sample_3d_plugin.so")
+ main(onnx_dir, trt_dir, grid_sample_plugin_file)
+
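+# Example invocation (directories are placeholders):
+#   python scripts/cvt_onnx_to_trt.py --onnx_dir ./checkpoints/ditto_onnx --trt_dir ./checkpoints/ditto_trt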
diff --git a/stream_pipeline_offline.py b/stream_pipeline_offline.py
new file mode 100644
index 0000000000000000000000000000000000000000..fea86099b0eebc881cb2be2d85246da69a265a03
--- /dev/null
+++ b/stream_pipeline_offline.py
@@ -0,0 +1,545 @@
+import threading
+import queue
+import numpy as np
+import traceback
+from tqdm import tqdm
+
+from core.atomic_components.avatar_registrar import AvatarRegistrar, smooth_x_s_info_lst
+from core.atomic_components.condition_handler import ConditionHandler, _mirror_index
+from core.atomic_components.audio2motion import Audio2Motion
+from core.atomic_components.motion_stitch import MotionStitch
+from core.atomic_components.warp_f3d import WarpF3D
+from core.atomic_components.decode_f3d import DecodeF3D
+from core.atomic_components.putback import PutBack
+from core.atomic_components.writer import VideoWriterByImageIO
+from core.atomic_components.wav2feat import Wav2Feat
+from core.atomic_components.cfg import parse_cfg, print_cfg
+
+
+class StreamSDK:
+ def __init__(self, cfg_pkl, data_root, **kwargs):
+
+ [
+ avatar_registrar_cfg,
+ condition_handler_cfg,
+ lmdm_cfg,
+ stitch_network_cfg,
+ warp_network_cfg,
+ decoder_cfg,
+ wav2feat_cfg,
+ default_kwargs,
+ ] = parse_cfg(cfg_pkl, data_root, kwargs)
+
+ self.default_kwargs = default_kwargs
+
+ self.avatar_registrar = AvatarRegistrar(**avatar_registrar_cfg)
+ self.condition_handler = ConditionHandler(**condition_handler_cfg)
+ self.audio2motion = Audio2Motion(lmdm_cfg)
+ self.motion_stitch = MotionStitch(stitch_network_cfg)
+ self.warp_f3d = WarpF3D(warp_network_cfg)
+ self.decode_f3d = DecodeF3D(decoder_cfg)
+ self.putback = PutBack()
+
+ self.wav2feat = Wav2Feat(**wav2feat_cfg)
+
+ def _merge_kwargs(self, default_kwargs, run_kwargs):
+ for k, v in default_kwargs.items():
+ if k not in run_kwargs:
+ run_kwargs[k] = v
+ return run_kwargs
+
+ def setup_Nd(self, N_d, fade_in=-1, fade_out=-1, ctrl_info=None):
+ # for eye open at video end
+ self.motion_stitch.set_Nd(N_d)
+
+ # for fade in/out alpha
+ if ctrl_info is None:
+ ctrl_info = self.ctrl_info
+ if fade_in > 0:
+ for i in range(fade_in):
+ alpha = i / fade_in
+ item = ctrl_info.get(i, {})
+ item["fade_alpha"] = alpha
+ ctrl_info[i] = item
+ if fade_out > 0:
+ ss = N_d - fade_out - 1
+ ee = N_d - 1
+ for i in range(ss, N_d):
+ alpha = max((ee - i) / (ee - ss), 0)
+ item = ctrl_info.get(i, {})
+ item["fade_alpha"] = alpha
+ ctrl_info[i] = item
+ self.ctrl_info = ctrl_info
+
+ def setup(self, source_path, output_path, **kwargs):
+
+ # ======== Prepare Options ========
+ kwargs = self._merge_kwargs(self.default_kwargs, kwargs)
+ print("=" * 20, "setup kwargs", "=" * 20)
+ print_cfg(**kwargs)
+ print("=" * 50)
+
+ # -- avatar_registrar: template cfg --
+ self.max_size = kwargs.get("max_size", 1920)
+ self.template_n_frames = kwargs.get("template_n_frames", -1)
+
+ # -- avatar_registrar: crop cfg --
+ self.crop_scale = kwargs.get("crop_scale", 2.3)
+ self.crop_vx_ratio = kwargs.get("crop_vx_ratio", 0)
+ self.crop_vy_ratio = kwargs.get("crop_vy_ratio", -0.125)
+ self.crop_flag_do_rot = kwargs.get("crop_flag_do_rot", True)
+
+ # -- avatar_registrar: smo for video --
+ self.smo_k_s = kwargs.get('smo_k_s', 13)
+
+ # -- condition_handler: ECS --
+ self.emo = kwargs.get("emo", 4) # int | [int] | [[int]] | numpy
+ self.eye_f0_mode = kwargs.get("eye_f0_mode", False) # for video
+ self.ch_info = kwargs.get("ch_info", None) # dict of np.ndarray
+
+ # -- audio2motion: setup --
+ self.overlap_v2 = kwargs.get("overlap_v2", 10)
+ self.fix_kp_cond = kwargs.get("fix_kp_cond", 0)
+ self.fix_kp_cond_dim = kwargs.get("fix_kp_cond_dim", None) # [ds,de]
+ self.sampling_timesteps = kwargs.get("sampling_timesteps", 50)
+ self.online_mode = kwargs.get("online_mode", False)
+ self.v_min_max_for_clip = kwargs.get('v_min_max_for_clip', None)
+ self.smo_k_d = kwargs.get("smo_k_d", 3)
+
+ # -- motion_stitch: setup --
+ self.N_d = kwargs.get("N_d", -1)
+ self.use_d_keys = kwargs.get("use_d_keys", None)
+ self.relative_d = kwargs.get("relative_d", True)
+ self.drive_eye = kwargs.get("drive_eye", None) # None: true4image, false4video
+ self.delta_eye_arr = kwargs.get("delta_eye_arr", None)
+ self.delta_eye_open_n = kwargs.get("delta_eye_open_n", 0)
+ self.fade_type = kwargs.get("fade_type", "") # "" | "d0" | "s"
+ self.fade_out_keys = kwargs.get("fade_out_keys", ("exp",))
+ self.flag_stitching = kwargs.get("flag_stitching", True)
+
+ self.ctrl_info = kwargs.get("ctrl_info", dict())
+ self.overall_ctrl_info = kwargs.get("overall_ctrl_info", dict())
+ """
+ ctrl_info: list or dict
+ {
+ fid: ctrl_kwargs
+ }
+
+ ctrl_kwargs (see motion_stitch.py):
+ fade_alpha
+ fade_out_keys
+
+ delta_pitch
+ delta_yaw
+ delta_roll
+ """
+
+ # only HuBERT features support online (streaming) mode
+ assert self.wav2feat.support_streaming or not self.online_mode
+
+ # ======== Register Avatar ========
+ crop_kwargs = {
+ "crop_scale": self.crop_scale,
+ "crop_vx_ratio": self.crop_vx_ratio,
+ "crop_vy_ratio": self.crop_vy_ratio,
+ "crop_flag_do_rot": self.crop_flag_do_rot,
+ }
+ n_frames = self.template_n_frames if self.template_n_frames > 0 else self.N_d
+ source_info = self.avatar_registrar(
+ source_path,
+ max_dim=self.max_size,
+ n_frames=n_frames,
+ **crop_kwargs,
+ )
+
+ if len(source_info["x_s_info_lst"]) > 1 and self.smo_k_s > 1:
+ source_info["x_s_info_lst"] = smooth_x_s_info_lst(source_info["x_s_info_lst"], smo_k=self.smo_k_s)
+
+ self.source_info = source_info
+ self.source_info_frames = len(source_info["x_s_info_lst"])
+
+ # ======== Setup Condition Handler ========
+ self.condition_handler.setup(source_info, self.emo, eye_f0_mode=self.eye_f0_mode, ch_info=self.ch_info)
+
+ # ======== Setup Audio2Motion (LMDM) ========
+ x_s_info_0 = self.condition_handler.x_s_info_0
+ self.audio2motion.setup(
+ x_s_info_0,
+ overlap_v2=self.overlap_v2,
+ fix_kp_cond=self.fix_kp_cond,
+ fix_kp_cond_dim=self.fix_kp_cond_dim,
+ sampling_timesteps=self.sampling_timesteps,
+ online_mode=self.online_mode,
+ v_min_max_for_clip=self.v_min_max_for_clip,
+ smo_k_d=self.smo_k_d,
+ )
+
+ # ======== Setup Motion Stitch ========
+ is_image_flag = source_info["is_image_flag"]
+ x_s_info = source_info['x_s_info_lst'][0]
+ self.motion_stitch.setup(
+ N_d=self.N_d,
+ use_d_keys=self.use_d_keys,
+ relative_d=self.relative_d,
+ drive_eye=self.drive_eye,
+ delta_eye_arr=self.delta_eye_arr,
+ delta_eye_open_n=self.delta_eye_open_n,
+ fade_out_keys=self.fade_out_keys,
+ fade_type=self.fade_type,
+ flag_stitching=self.flag_stitching,
+ is_image_flag=is_image_flag,
+ x_s_info=x_s_info,
+ d0=None,
+ ch_info=self.ch_info,
+ overall_ctrl_info=self.overall_ctrl_info,
+ )
+
+ # ======== Video Writer ========
+ self.output_path = output_path
+ self.tmp_output_path = output_path + ".tmp.mp4"
+ self.writer = VideoWriterByImageIO(self.tmp_output_path)
+ self.writer_pbar = tqdm(desc="writer")
+
+ # ======== Audio Feat Buffer ========
+ if self.online_mode:
+ # prime the buffer with overlap_v2 (= seq_frames - valid_clip_len) frames of silence
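+ # overlap_v2 frames * 640 samples each at 16 kHz, i.e. 25 fps features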
+ self.audio_feat = self.wav2feat.wav2feat(np.zeros((self.overlap_v2 * 640,), dtype=np.float32), sr=16000)
+ assert len(self.audio_feat) == self.overlap_v2, f"expected {self.overlap_v2} frames, got {len(self.audio_feat)}"
+ else:
+ self.audio_feat = np.zeros((0, self.wav2feat.feat_dim), dtype=np.float32)
+ self.cond_idx_start = 0 - len(self.audio_feat)
+
+ # ======== Setup Worker Threads ========
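+ # six stages, one thread each, linked by bounded queues:
+ #   audio2motion -> motion_stitch -> warp_f3d -> decode_f3d -> putback -> writer
+ # a None item is passed down the chain to signal end-of-stream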
+ QUEUE_MAX_SIZE = 100
+
+ self.worker_exception = None
+ self.stop_event = threading.Event()
+
+ self.audio2motion_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
+ self.motion_stitch_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
+ self.warp_f3d_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
+ self.decode_f3d_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
+ self.putback_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
+ self.writer_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
+
+ self.thread_list = [
+ threading.Thread(target=self.audio2motion_worker),
+ threading.Thread(target=self.motion_stitch_worker),
+ threading.Thread(target=self.warp_f3d_worker),
+ threading.Thread(target=self.decode_f3d_worker),
+ threading.Thread(target=self.putback_worker),
+ threading.Thread(target=self.writer_worker),
+ ]
+
+ for thread in self.thread_list:
+ thread.start()
+
+ def _get_ctrl_info(self, fid):
+ try:
+ if isinstance(self.ctrl_info, dict):
+ return self.ctrl_info.get(fid, {})
+ elif isinstance(self.ctrl_info, list):
+ return self.ctrl_info[fid]
+ else:
+ return {}
+ except Exception:
+ traceback.print_exc()
+ return {}
+
+ def writer_worker(self):
+ try:
+ self._writer_worker()
+ except Exception as e:
+ self.worker_exception = e
+ self.stop_event.set()
+
+ def _writer_worker(self):
+ while not self.stop_event.is_set():
+ try:
+ item = self.writer_queue.get(timeout=1)
+ except queue.Empty:
+ continue
+
+ if item is None:
+ break
+ res_frame_rgb = item
+ self.writer(res_frame_rgb, fmt="rgb")
+ self.writer_pbar.update()
+
+ def putback_worker(self):
+ try:
+ self._putback_worker()
+ except Exception as e:
+ self.worker_exception = e
+ self.stop_event.set()
+
+ def _putback_worker(self):
+ while not self.stop_event.is_set():
+ try:
+ item = self.putback_queue.get(timeout=1)
+ except queue.Empty:
+ continue
+ if item is None:
+ self.writer_queue.put(None)
+ break
+ frame_idx, render_img = item
+ frame_rgb = self.source_info["img_rgb_lst"][frame_idx]
+ M_c2o = self.source_info["M_c2o_lst"][frame_idx]
+ res_frame_rgb = self.putback(frame_rgb, render_img, M_c2o)
+ self.writer_queue.put(res_frame_rgb)
+
+ def decode_f3d_worker(self):
+ try:
+ self._decode_f3d_worker()
+ except Exception as e:
+ self.worker_exception = e
+ self.stop_event.set()
+
+ def _decode_f3d_worker(self):
+ while not self.stop_event.is_set():
+ try:
+ item = self.decode_f3d_queue.get(timeout=1)
+ except queue.Empty:
+ continue
+ if item is None:
+ self.putback_queue.put(None)
+ break
+ frame_idx, f_3d = item
+ render_img = self.decode_f3d(f_3d)
+ self.putback_queue.put([frame_idx, render_img])
+
+ def warp_f3d_worker(self):
+ try:
+ self._warp_f3d_worker()
+ except Exception as e:
+ self.worker_exception = e
+ self.stop_event.set()
+
+ def _warp_f3d_worker(self):
+ while not self.stop_event.is_set():
+ try:
+ item = self.warp_f3d_queue.get(timeout=1)
+ except queue.Empty:
+ continue
+ if item is None:
+ self.decode_f3d_queue.put(None)
+ break
+ frame_idx, x_s, x_d = item
+ f_s = self.source_info["f_s_lst"][frame_idx]
+ f_3d = self.warp_f3d(f_s, x_s, x_d)
+ self.decode_f3d_queue.put([frame_idx, f_3d])
+
+ def motion_stitch_worker(self):
+ try:
+ self._motion_stitch_worker()
+ except Exception as e:
+ self.worker_exception = e
+ self.stop_event.set()
+
+ def _motion_stitch_worker(self):
+ while not self.stop_event.is_set():
+ try:
+ item = self.motion_stitch_queue.get(timeout=1)
+ except queue.Empty:
+ continue
+ if item is None:
+ self.warp_f3d_queue.put(None)
+ break
+
+ frame_idx, x_d_info, ctrl_kwargs = item
+ x_s_info = self.source_info["x_s_info_lst"][frame_idx]
+ x_s, x_d = self.motion_stitch(x_s_info, x_d_info, **ctrl_kwargs)
+ self.warp_f3d_queue.put([frame_idx, x_s, x_d])
+
+ def audio2motion_worker(self):
+ try:
+ # offline mode: the full feature sequence arrives as a single queue item
+ # (see _audio2motion_worker for the streaming variant)
+ self._audio2motion_offline()
+ except Exception as e:
+ self.worker_exception = e
+ self.stop_event.set()
+
+ def _audio2motion_offline(self):
+
+ while not self.stop_event.is_set():
+ try:
+ item = self.audio2motion_queue.get(timeout=1) # audio feat
+ except queue.Empty:
+ continue
+
+ if item is None:
+ break
+
+ aud_feat = item
+
+ aud_cond_all = self.condition_handler(aud_feat, 0)
+ seq_frames = self.audio2motion.seq_frames
+ valid_clip_len = self.audio2motion.valid_clip_len
+ num_frames = len(aud_cond_all)
+ idx = 0
+ res_kp_seq = None
+ pbar = tqdm(desc="dit")
+ while idx < num_frames:
+ pbar.update()
+ aud_cond = aud_cond_all[idx:idx + seq_frames][None]
+ if aud_cond.shape[1] < seq_frames:
+ pad = np.stack([aud_cond[:, -1]] * (seq_frames - aud_cond.shape[1]), 1)
+ aud_cond = np.concatenate([aud_cond, pad], 1)
+ res_kp_seq = self.audio2motion(aud_cond, res_kp_seq)
+ idx += valid_clip_len
+
+ pbar.close()
+ res_kp_seq = res_kp_seq[:, :num_frames]
+ res_kp_seq = self.audio2motion._smo(res_kp_seq, 0, res_kp_seq.shape[1])
+
+ x_d_info_list = self.audio2motion.cvt_fmt(res_kp_seq)
+
+ gen_frame_idx = 0
+ for x_d_info in x_d_info_list:
+ frame_idx = _mirror_index(gen_frame_idx, self.source_info_frames)
+ ctrl_kwargs = self._get_ctrl_info(gen_frame_idx)
+
+ while not self.stop_event.is_set():
+ try:
+ self.motion_stitch_queue.put([frame_idx, x_d_info, ctrl_kwargs], timeout=1)
+ break
+ except queue.Full:
+ continue
+ gen_frame_idx += 1
+
+ # offline mode processes exactly one item (the whole clip), then stops
+ break
+
+ self.motion_stitch_queue.put(None)
+
+
+ def _audio2motion_worker(self):
+ is_end = False
+ seq_frames = self.audio2motion.seq_frames
+ valid_clip_len = self.audio2motion.valid_clip_len
+ aud_feat_dim = self.wav2feat.feat_dim
+ item_buffer = np.zeros((0, aud_feat_dim), dtype=np.float32)
+
+ res_kp_seq = None
+ res_kp_seq_valid_start = None if self.online_mode else 0
+
+ global_idx = 0 # absolute frame index, for template/condition lookup
+ local_idx = 0 # read position within the current self.audio_feat buffer
+ gen_frame_idx = 0
+ while not self.stop_event.is_set():
+ try:
+ item = self.audio2motion_queue.get(timeout=1) # audio feat
+ except queue.Empty:
+ continue
+ if item is None:
+ is_end = True
+ else:
+ item_buffer = np.concatenate([item_buffer, item], 0)
+
+ if not is_end and item_buffer.shape[0] < valid_clip_len:
+ # wait until at least valid_clip_len new frames have accumulated
+ continue
+ else:
+ self.audio_feat = np.concatenate([self.audio_feat, item_buffer], 0)
+ item_buffer = np.zeros((0, aud_feat_dim), dtype=np.float32)
+
+ while True:
+ # print("self.audio_feat.shape:", self.audio_feat.shape, "local_idx:", local_idx, "global_idx:", global_idx)
+ aud_feat = self.audio_feat[local_idx: local_idx+seq_frames]
+ real_valid_len = valid_clip_len
+ if len(aud_feat) == 0:
+ break
+ elif len(aud_feat) < seq_frames:
+ if not is_end:
+ # wait next chunk
+ break
+ else:
+ # final clip: pad to seq_frames
+ real_valid_len = len(aud_feat)
+ pad = np.stack([aud_feat[-1]] * (seq_frames - len(aud_feat)), 0)
+ aud_feat = np.concatenate([aud_feat, pad], 0)
+
+ aud_cond = self.condition_handler(aud_feat, global_idx + self.cond_idx_start)[None] # add batch dim
+ res_kp_seq = self.audio2motion(aud_cond, res_kp_seq)
+ if res_kp_seq_valid_start is None:
+ # online mode, first chunk
+ res_kp_seq_valid_start = res_kp_seq.shape[1] - self.audio2motion.fuse_length
+ d0 = self.audio2motion.cvt_fmt(res_kp_seq[0:1])[0]
+ self.motion_stitch.d0 = d0
+
+ local_idx += real_valid_len
+ global_idx += real_valid_len
+ continue
+ else:
+ valid_res_kp_seq = res_kp_seq[:, res_kp_seq_valid_start: res_kp_seq_valid_start + real_valid_len]
+ x_d_info_list = self.audio2motion.cvt_fmt(valid_res_kp_seq)
+
+ for x_d_info in x_d_info_list:
+ frame_idx = _mirror_index(gen_frame_idx, self.source_info_frames)
+ ctrl_kwargs = self._get_ctrl_info(gen_frame_idx)
+
+ while not self.stop_event.is_set():
+ try:
+ self.motion_stitch_queue.put([frame_idx, x_d_info, ctrl_kwargs], timeout=1)
+ break
+ except queue.Full:
+ continue
+
+ gen_frame_idx += 1
+
+ res_kp_seq_valid_start += real_valid_len
+
+ local_idx += real_valid_len
+ global_idx += real_valid_len
+
+ L = res_kp_seq.shape[1]
+ if L > seq_frames * 2:
+ cut_L = L - seq_frames * 2
+ res_kp_seq = res_kp_seq[:, cut_L:]
+ res_kp_seq_valid_start -= cut_L
+
+ if local_idx >= len(self.audio_feat):
+ break
+
+ L = len(self.audio_feat)
+ if L > seq_frames * 2:
+ cut_L = L - seq_frames * 2
+ self.audio_feat = self.audio_feat[cut_L:]
+ local_idx -= cut_L
+
+ if is_end:
+ break
+
+ self.motion_stitch_queue.put(None)
+
+ def close(self):
+ # flush frames
+ self.audio2motion_queue.put(None)
+ # Wait for worker threads to finish
+ for thread in self.thread_list:
+ thread.join()
+
+ try:
+ self.writer.close()
+ self.writer_pbar.close()
+ except Exception:
+ traceback.print_exc()
+
+ # Check if any worker encountered an exception
+ if self.worker_exception is not None:
+ raise self.worker_exception
+
+ def run_chunk(self, audio_chunk, chunksize=(3, 5, 2)):
+ # streaming input; requires a wav2feat with support_streaming (HuBERT)
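+ # chunksize is forwarded to Wav2Feat; assumed (past, current, future) context frames per chunk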
+ aud_feat = self.wav2feat(audio_chunk, chunksize=chunksize)
+ while not self.stop_event.is_set():
+ try:
+ self.audio2motion_queue.put(aud_feat, timeout=1)
+ break
+ except queue.Full:
+ continue
+
+
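+ # Minimal usage sketch (illustrative paths; assumes 16 kHz mono float32 audio
+ # and valid cfg/checkpoint assets):
+ #
+ #   sdk = StreamSDK("cfg.pkl", "./checkpoints")
+ #   sdk.setup("source.png", "result.mp4")
+ #   feat = sdk.wav2feat.wav2feat(audio, sr=16000)  # features for the whole clip
+ #   sdk.audio2motion_queue.put(feat)               # single item, see _audio2motion_offline
+ #   sdk.close()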
diff --git a/stream_pipeline_online.py b/stream_pipeline_online.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fbfd755af7c134066c98a5945031a632e7ac4da
--- /dev/null
+++ b/stream_pipeline_online.py
@@ -0,0 +1,511 @@
+import threading
+import queue
+import numpy as np
+import traceback
+from tqdm import tqdm
+
+from core.atomic_components.avatar_registrar import AvatarRegistrar, smooth_x_s_info_lst
+from core.atomic_components.condition_handler import ConditionHandler, _mirror_index
+from core.atomic_components.audio2motion import Audio2Motion
+from core.atomic_components.motion_stitch import MotionStitch
+from core.atomic_components.warp_f3d import WarpF3D
+from core.atomic_components.decode_f3d import DecodeF3D
+from core.atomic_components.putback import PutBack
+from core.atomic_components.writer import VideoWriterByImageIO
+from core.atomic_components.wav2feat import Wav2Feat
+from core.atomic_components.cfg import parse_cfg, print_cfg
+
+
+"""
+avatar_registrar_cfg:
+ insightface_det_cfg,
+ landmark106_cfg,
+ landmark203_cfg,
+ landmark478_cfg,
+ appearance_extractor_cfg,
+ motion_extractor_cfg,
+
+condition_handler_cfg:
+ use_emo=True,
+ use_sc=True,
+ use_eye_open=True,
+ use_eye_ball=True,
+ seq_frames=80,
+
+wav2feat_cfg:
+ w2f_cfg,
+ w2f_type
+"""
+
+
+class StreamSDK:
+ def __init__(self, cfg_pkl, data_root, **kwargs):
+
+ [
+ avatar_registrar_cfg,
+ condition_handler_cfg,
+ lmdm_cfg,
+ stitch_network_cfg,
+ warp_network_cfg,
+ decoder_cfg,
+ wav2feat_cfg,
+ default_kwargs,
+ ] = parse_cfg(cfg_pkl, data_root, kwargs)
+
+ self.default_kwargs = default_kwargs
+
+ self.avatar_registrar = AvatarRegistrar(**avatar_registrar_cfg)
+ self.condition_handler = ConditionHandler(**condition_handler_cfg)
+ self.audio2motion = Audio2Motion(lmdm_cfg)
+ self.motion_stitch = MotionStitch(stitch_network_cfg)
+ self.warp_f3d = WarpF3D(warp_network_cfg)
+ self.decode_f3d = DecodeF3D(decoder_cfg)
+ self.putback = PutBack()
+
+ self.wav2feat = Wav2Feat(**wav2feat_cfg)
+
+ def _merge_kwargs(self, default_kwargs, run_kwargs):
+ for k, v in default_kwargs.items():
+ if k not in run_kwargs:
+ run_kwargs[k] = v
+ return run_kwargs
+
+ def setup_Nd(self, N_d, fade_in=-1, fade_out=-1, ctrl_info=None):
+ # pass the total frame count so the eyes can be opened at the video end
+ self.motion_stitch.set_Nd(N_d)
+
+ # for fade in/out alpha
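+ # e.g. (illustrative numbers): setup_Nd(N_d=100, fade_in=5, fade_out=10)
+ # ramps fade_alpha 0.0 -> 0.8 over frames 0..4 and 1.0 -> 0.0 over frames 89..99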
+ if ctrl_info is None:
+ ctrl_info = self.ctrl_info
+ if fade_in > 0:
+ for i in range(fade_in):
+ alpha = i / fade_in
+ item = ctrl_info.get(i, {})
+ item["fade_alpha"] = alpha
+ ctrl_info[i] = item
+ if fade_out > 0:
+ ss = N_d - fade_out - 1
+ ee = N_d - 1
+ for i in range(ss, N_d):
+ alpha = max((ee - i) / (ee - ss), 0)
+ item = ctrl_info.get(i, {})
+ item["fade_alpha"] = alpha
+ ctrl_info[i] = item
+ self.ctrl_info = ctrl_info
+
+ def setup(self, source_path, output_path, **kwargs):
+
+ # ======== Prepare Options ========
+ kwargs = self._merge_kwargs(self.default_kwargs, kwargs)
+ print("=" * 20, "setup kwargs", "=" * 20)
+ print_cfg(**kwargs)
+ print("=" * 50)
+
+ # -- avatar_registrar: template cfg --
+ self.max_size = kwargs.get("max_size", 1920)
+ self.template_n_frames = kwargs.get("template_n_frames", -1)
+
+ # -- avatar_registrar: crop cfg --
+ self.crop_scale = kwargs.get("crop_scale", 2.3)
+ self.crop_vx_ratio = kwargs.get("crop_vx_ratio", 0)
+ self.crop_vy_ratio = kwargs.get("crop_vy_ratio", -0.125)
+ self.crop_flag_do_rot = kwargs.get("crop_flag_do_rot", True)
+
+ # -- avatar_registrar: smo for video --
+ self.smo_k_s = kwargs.get('smo_k_s', 13)
+
+ # -- condition_handler: ECS --
+ self.emo = kwargs.get("emo", 4) # int | [int] | [[int]] | numpy
+ self.eye_f0_mode = kwargs.get("eye_f0_mode", False) # for video
+ self.ch_info = kwargs.get("ch_info", None) # dict of np.ndarray
+
+ # -- audio2motion: setup --
+ self.overlap_v2 = kwargs.get("overlap_v2", 10)
+ self.fix_kp_cond = kwargs.get("fix_kp_cond", 0)
+ self.fix_kp_cond_dim = kwargs.get("fix_kp_cond_dim", None) # [ds,de]
+ self.sampling_timesteps = kwargs.get("sampling_timesteps", 50)
+ self.online_mode = kwargs.get("online_mode", False)
+ self.v_min_max_for_clip = kwargs.get('v_min_max_for_clip', None)
+ self.smo_k_d = kwargs.get("smo_k_d", 3)
+
+ # -- motion_stitch: setup --
+ self.N_d = kwargs.get("N_d", -1)
+ self.use_d_keys = kwargs.get("use_d_keys", None)
+ self.relative_d = kwargs.get("relative_d", True)
+ self.drive_eye = kwargs.get("drive_eye", None) # None: true4image, false4video
+ self.delta_eye_arr = kwargs.get("delta_eye_arr", None)
+ self.delta_eye_open_n = kwargs.get("delta_eye_open_n", 0)
+ self.fade_type = kwargs.get("fade_type", "") # "" | "d0" | "s"
+ self.fade_out_keys = kwargs.get("fade_out_keys", ("exp",))
+ self.flag_stitching = kwargs.get("flag_stitching", True)
+
+ self.ctrl_info = kwargs.get("ctrl_info", dict())
+ self.overall_ctrl_info = kwargs.get("overall_ctrl_info", dict())
+ """
+ ctrl_info: list or dict
+ {
+ fid: ctrl_kwargs
+ }
+
+ ctrl_kwargs (see motion_stitch.py):
+ fade_alpha
+ fade_out_keys
+
+ delta_pitch
+ delta_yaw
+ delta_roll
+ """
+
+ # only HuBERT features support online (streaming) mode
+ assert self.wav2feat.support_streaming or not self.online_mode
+
+ # ======== Register Avatar ========
+ crop_kwargs = {
+ "crop_scale": self.crop_scale,
+ "crop_vx_ratio": self.crop_vx_ratio,
+ "crop_vy_ratio": self.crop_vy_ratio,
+ "crop_flag_do_rot": self.crop_flag_do_rot,
+ }
+ n_frames = self.template_n_frames if self.template_n_frames > 0 else self.N_d
+ source_info = self.avatar_registrar(
+ source_path,
+ max_dim=self.max_size,
+ n_frames=n_frames,
+ **crop_kwargs,
+ )
+
+ if len(source_info["x_s_info_lst"]) > 1 and self.smo_k_s > 1:
+ source_info["x_s_info_lst"] = smooth_x_s_info_lst(source_info["x_s_info_lst"], smo_k=self.smo_k_s)
+
+ self.source_info = source_info
+ self.source_info_frames = len(source_info["x_s_info_lst"])
+
+ # ======== Setup Condition Handler ========
+ self.condition_handler.setup(source_info, self.emo, eye_f0_mode=self.eye_f0_mode, ch_info=self.ch_info)
+
+ # ======== Setup Audio2Motion (LMDM) ========
+ x_s_info_0 = self.condition_handler.x_s_info_0
+ self.audio2motion.setup(
+ x_s_info_0,
+ overlap_v2=self.overlap_v2,
+ fix_kp_cond=self.fix_kp_cond,
+ fix_kp_cond_dim=self.fix_kp_cond_dim,
+ sampling_timesteps=self.sampling_timesteps,
+ online_mode=self.online_mode,
+ v_min_max_for_clip=self.v_min_max_for_clip,
+ smo_k_d=self.smo_k_d,
+ )
+
+ # ======== Setup Motion Stitch ========
+ is_image_flag = source_info["is_image_flag"]
+ x_s_info = source_info['x_s_info_lst'][0]
+ self.motion_stitch.setup(
+ N_d=self.N_d,
+ use_d_keys=self.use_d_keys,
+ relative_d=self.relative_d,
+ drive_eye=self.drive_eye,
+ delta_eye_arr=self.delta_eye_arr,
+ delta_eye_open_n=self.delta_eye_open_n,
+ fade_out_keys=self.fade_out_keys,
+ fade_type=self.fade_type,
+ flag_stitching=self.flag_stitching,
+ is_image_flag=is_image_flag,
+ x_s_info=x_s_info,
+ d0=None,
+ ch_info=self.ch_info,
+ overall_ctrl_info=self.overall_ctrl_info,
+ )
+
+ # ======== Video Writer ========
+ self.output_path = output_path
+ self.tmp_output_path = output_path + ".tmp.mp4"
+ self.writer = VideoWriterByImageIO(self.tmp_output_path)
+ self.writer_pbar = tqdm(desc="writer")
+
+ # ======== Audio Feat Buffer ========
+ if self.online_mode:
+ # prime the buffer with overlap_v2 (= seq_frames - valid_clip_len) frames of silence
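+ # overlap_v2 frames * 640 samples each at 16 kHz, i.e. 25 fps features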
+ self.audio_feat = self.wav2feat.wav2feat(np.zeros((self.overlap_v2 * 640,), dtype=np.float32), sr=16000)
+ assert len(self.audio_feat) == self.overlap_v2, f"expected {self.overlap_v2} frames, got {len(self.audio_feat)}"
+ else:
+ self.audio_feat = np.zeros((0, self.wav2feat.feat_dim), dtype=np.float32)
+ self.cond_idx_start = 0 - len(self.audio_feat)
+
+ # ======== Setup Worker Threads ========
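+ # six stages, one thread each, linked by bounded queues:
+ #   audio2motion -> motion_stitch -> warp_f3d -> decode_f3d -> putback -> writer
+ # a None item is passed down the chain to signal end-of-stream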
+ QUEUE_MAX_SIZE = 100
+
+ self.worker_exception = None
+ self.stop_event = threading.Event()
+
+ self.audio2motion_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
+ self.motion_stitch_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
+ self.warp_f3d_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
+ self.decode_f3d_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
+ self.putback_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
+ self.writer_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
+
+ self.thread_list = [
+ threading.Thread(target=self.audio2motion_worker),
+ threading.Thread(target=self.motion_stitch_worker),
+ threading.Thread(target=self.warp_f3d_worker),
+ threading.Thread(target=self.decode_f3d_worker),
+ threading.Thread(target=self.putback_worker),
+ threading.Thread(target=self.writer_worker),
+ ]
+
+ for thread in self.thread_list:
+ thread.start()
+
+ def _get_ctrl_info(self, fid):
+ try:
+ if isinstance(self.ctrl_info, dict):
+ return self.ctrl_info.get(fid, {})
+ elif isinstance(self.ctrl_info, list):
+ return self.ctrl_info[fid]
+ else:
+ return {}
+ except Exception:
+ traceback.print_exc()
+ return {}
+
+ def writer_worker(self):
+ try:
+ self._writer_worker()
+ except Exception as e:
+ self.worker_exception = e
+ self.stop_event.set()
+
+ def _writer_worker(self):
+ while not self.stop_event.is_set():
+ try:
+ item = self.writer_queue.get(timeout=1)
+ except queue.Empty:
+ continue
+
+ if item is None:
+ break
+ res_frame_rgb = item
+ self.writer(res_frame_rgb, fmt="rgb")
+ self.writer_pbar.update()
+
+ def putback_worker(self):
+ try:
+ self._putback_worker()
+ except Exception as e:
+ self.worker_exception = e
+ self.stop_event.set()
+
+ def _putback_worker(self):
+ while not self.stop_event.is_set():
+ try:
+ item = self.putback_queue.get(timeout=1)
+ except queue.Empty:
+ continue
+ if item is None:
+ self.writer_queue.put(None)
+ break
+ frame_idx, render_img = item
+ frame_rgb = self.source_info["img_rgb_lst"][frame_idx]
+ M_c2o = self.source_info["M_c2o_lst"][frame_idx]
+ res_frame_rgb = self.putback(frame_rgb, render_img, M_c2o)
+ self.writer_queue.put(res_frame_rgb)
+
+ def decode_f3d_worker(self):
+ try:
+ self._decode_f3d_worker()
+ except Exception as e:
+ self.worker_exception = e
+ self.stop_event.set()
+
+ def _decode_f3d_worker(self):
+ while not self.stop_event.is_set():
+ try:
+ item = self.decode_f3d_queue.get(timeout=1)
+ except queue.Empty:
+ continue
+ if item is None:
+ self.putback_queue.put(None)
+ break
+ frame_idx, f_3d = item
+ render_img = self.decode_f3d(f_3d)
+ self.putback_queue.put([frame_idx, render_img])
+
+ def warp_f3d_worker(self):
+ try:
+ self._warp_f3d_worker()
+ except Exception as e:
+ self.worker_exception = e
+ self.stop_event.set()
+
+ def _warp_f3d_worker(self):
+ while not self.stop_event.is_set():
+ try:
+ item = self.warp_f3d_queue.get(timeout=1)
+ except queue.Empty:
+ continue
+ if item is None:
+ self.decode_f3d_queue.put(None)
+ break
+ frame_idx, x_s, x_d = item
+ f_s = self.source_info["f_s_lst"][frame_idx]
+ f_3d = self.warp_f3d(f_s, x_s, x_d)
+ self.decode_f3d_queue.put([frame_idx, f_3d])
+
+ def motion_stitch_worker(self):
+ try:
+ self._motion_stitch_worker()
+ except Exception as e:
+ self.worker_exception = e
+ self.stop_event.set()
+
+ def _motion_stitch_worker(self):
+ while not self.stop_event.is_set():
+ try:
+ item = self.motion_stitch_queue.get(timeout=1)
+ except queue.Empty:
+ continue
+ if item is None:
+ self.warp_f3d_queue.put(None)
+ break
+
+ frame_idx, x_d_info, ctrl_kwargs = item
+ x_s_info = self.source_info["x_s_info_lst"][frame_idx]
+ x_s, x_d = self.motion_stitch(x_s_info, x_d_info, **ctrl_kwargs)
+ self.warp_f3d_queue.put([frame_idx, x_s, x_d])
+
+ def audio2motion_worker(self):
+ try:
+ self._audio2motion_worker()
+ except Exception as e:
+ self.worker_exception = e
+ self.stop_event.set()
+
+ def _audio2motion_worker(self):
+ is_end = False
+ seq_frames = self.audio2motion.seq_frames
+ valid_clip_len = self.audio2motion.valid_clip_len
+ aud_feat_dim = self.wav2feat.feat_dim
+ item_buffer = np.zeros((0, aud_feat_dim), dtype=np.float32)
+
+ res_kp_seq = None
+ res_kp_seq_valid_start = None if self.online_mode else 0
+
+ global_idx = 0 # absolute frame index, for template/condition lookup
+ local_idx = 0 # read position within the current self.audio_feat buffer
+ gen_frame_idx = 0
+ while not self.stop_event.is_set():
+ try:
+ item = self.audio2motion_queue.get(timeout=1) # audio feat
+ except queue.Empty:
+ continue
+ if item is None:
+ is_end = True
+ else:
+ item_buffer = np.concatenate([item_buffer, item], 0)
+
+ if not is_end and item_buffer.shape[0] < valid_clip_len:
+ # wait until at least valid_clip_len new frames have accumulated
+ continue
+ else:
+ self.audio_feat = np.concatenate([self.audio_feat, item_buffer], 0)
+ item_buffer = np.zeros((0, aud_feat_dim), dtype=np.float32)
+
+ while True:
+ # print("self.audio_feat.shape:", self.audio_feat.shape, "local_idx:", local_idx, "global_idx:", global_idx)
+ aud_feat = self.audio_feat[local_idx: local_idx+seq_frames]
+ real_valid_len = valid_clip_len
+ if len(aud_feat) == 0:
+ break
+ elif len(aud_feat) < seq_frames:
+ if not is_end:
+ # wait next chunk
+ break
+ else:
+ # final clip: pad to seq_frames
+ real_valid_len = len(aud_feat)
+ pad = np.stack([aud_feat[-1]] * (seq_frames - len(aud_feat)), 0)
+ aud_feat = np.concatenate([aud_feat, pad], 0)
+
+ aud_cond = self.condition_handler(aud_feat, global_idx + self.cond_idx_start)[None]
+ res_kp_seq = self.audio2motion(aud_cond, res_kp_seq)
+ if res_kp_seq_valid_start is None:
+ # online mode, first chunk
+ res_kp_seq_valid_start = res_kp_seq.shape[1] - self.audio2motion.fuse_length
+ d0 = self.audio2motion.cvt_fmt(res_kp_seq[0:1])[0]
+ self.motion_stitch.d0 = d0
+
+ local_idx += real_valid_len
+ global_idx += real_valid_len
+ continue
+ else:
+ valid_res_kp_seq = res_kp_seq[:, res_kp_seq_valid_start: res_kp_seq_valid_start + real_valid_len]
+ x_d_info_list = self.audio2motion.cvt_fmt(valid_res_kp_seq)
+
+ for x_d_info in x_d_info_list:
+ frame_idx = _mirror_index(gen_frame_idx, self.source_info_frames)
+ ctrl_kwargs = self._get_ctrl_info(gen_frame_idx)
+
+ while not self.stop_event.is_set():
+ try:
+ self.motion_stitch_queue.put([frame_idx, x_d_info, ctrl_kwargs], timeout=1)
+ break
+ except queue.Full:
+ continue
+
+ gen_frame_idx += 1
+
+ res_kp_seq_valid_start += real_valid_len
+
+ local_idx += real_valid_len
+ global_idx += real_valid_len
+
+ L = res_kp_seq.shape[1]
+ if L > seq_frames * 2:
+ cut_L = L - seq_frames * 2
+ res_kp_seq = res_kp_seq[:, cut_L:]
+ res_kp_seq_valid_start -= cut_L
+
+ if local_idx >= len(self.audio_feat):
+ break
+
+ L = len(self.audio_feat)
+ if L > seq_frames * 2:
+ cut_L = L - seq_frames * 2
+ self.audio_feat = self.audio_feat[cut_L:]
+ local_idx -= cut_L
+
+ if is_end:
+ break
+
+ self.motion_stitch_queue.put(None)
+
+ def close(self):
+ # flush frames
+ self.audio2motion_queue.put(None)
+ # Wait for worker threads to finish
+ for thread in self.thread_list:
+ thread.join()
+
+ try:
+ self.writer.close()
+ self.writer_pbar.close()
+ except Exception:
+ traceback.print_exc()
+
+ # Check if any worker encountered an exception
+ if self.worker_exception is not None:
+ raise self.worker_exception
+
+ def run_chunk(self, audio_chunk, chunksize=(3, 5, 2)):
+ # streaming input; requires a wav2feat with support_streaming (HuBERT)
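+ # chunksize is forwarded to Wav2Feat; assumed (past, current, future) context frames per chunk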
+ aud_feat = self.wav2feat(audio_chunk, chunksize=chunksize)
+ while not self.stop_event.is_set():
+ try:
+ self.audio2motion_queue.put(aud_feat, timeout=1)
+ break
+ except queue.Full:
+ continue
+
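+ # Minimal streaming usage sketch (illustrative; assumes a HuBERT wav2feat and
+ # 16 kHz mono float32 audio chunks):
+ #
+ #   sdk = StreamSDK("cfg.pkl", "./checkpoints")
+ #   sdk.setup("source.png", "result.mp4", online_mode=True)
+ #   for chunk in audio_chunks:   # e.g. 0.2 s of audio per chunk
+ #       sdk.run_chunk(chunk)
+ #   sdk.close()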