second half

This view is limited to 50 files because it contains too many changes. See the raw diff for the rest.
- .idea/TalkSHOW.iml +8 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- .idea/workspace.xml +55 -0
- config/LS3DCG.json +64 -0
- config/body_pixel.json +63 -0
- config/body_vq.json +62 -0
- config/face.json +59 -0
- data_utils/__init__.py +3 -0
- data_utils/__pycache__/__init__.cpython-37.pyc +0 -0
- data_utils/__pycache__/consts.cpython-37.pyc +0 -0
- data_utils/__pycache__/dataloader_torch.cpython-37.pyc +0 -0
- data_utils/__pycache__/lower_body.cpython-37.pyc +0 -0
- data_utils/__pycache__/mesh_dataset.cpython-37.pyc +0 -0
- data_utils/__pycache__/rotation_conversion.cpython-37.pyc +0 -0
- data_utils/__pycache__/utils.cpython-37.pyc +0 -0
- data_utils/apply_split.py +51 -0
- data_utils/axis2matrix.py +29 -0
- data_utils/consts.py +0 -0
- data_utils/dataloader_torch.py +279 -0
- data_utils/dataset_preprocess.py +170 -0
- data_utils/get_j.py +51 -0
- data_utils/hand_component.json +0 -0
- data_utils/lower_body.py +143 -0
- data_utils/mesh_dataset.py +348 -0
- data_utils/rotation_conversion.py +551 -0
- data_utils/split_more_than_2s.pkl +3 -0
- data_utils/split_train_val_test.py +27 -0
- data_utils/train_val_test.json +0 -0
- data_utils/utils.py +318 -0
- evaluation/FGD.py +199 -0
- evaluation/__init__.py +0 -0
- evaluation/__pycache__/__init__.cpython-37.pyc +0 -0
- evaluation/__pycache__/metrics.cpython-37.pyc +0 -0
- evaluation/diversity_LVD.py +64 -0
- evaluation/get_quality_samples.py +62 -0
- evaluation/metrics.py +109 -0
- evaluation/mode_transition.py +60 -0
- evaluation/peak_velocity.py +65 -0
- evaluation/util.py +148 -0
- losses/__init__.py +1 -0
- losses/__pycache__/__init__.cpython-37.pyc +0 -0
- losses/__pycache__/losses.cpython-37.pyc +0 -0
- losses/losses.py +91 -0
- nets/LS3DCG.py +414 -0
- nets/__init__.py +8 -0
- nets/__pycache__/__init__.cpython-37.pyc +0 -0
- nets/__pycache__/base.cpython-37.pyc +0 -0
- nets/__pycache__/init_model.cpython-37.pyc +0 -0
.idea/TalkSHOW.iml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/TalkSHOW.iml" filepath="$PROJECT_DIR$/.idea/TalkSHOW.iml" />
+    </modules>
+  </component>
+</project>
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
.idea/workspace.xml
ADDED
@@ -0,0 +1,55 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ChangeListManager">
+    <list default="true" id="210ec380-7aeb-4dbc-977d-b398d2f30cdf" name="Changes" comment="" />
+    <option name="SHOW_DIALOG" value="false" />
+    <option name="HIGHLIGHT_CONFLICTS" value="true" />
+    <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
+    <option name="LAST_RESOLUTION" value="IGNORE" />
+  </component>
+  <component name="Git.Settings">
+    <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
+  </component>
+  <component name="GitHubPullRequestSearchHistory"><![CDATA[{
+  "lastFilter": {
+    "state": "OPEN",
+    "assignee": "khaitha"
+  }
+}]]></component>
+  <component name="ProjectColorInfo"><![CDATA[{
+  "associatedIndex": 5
+}]]></component>
+  <component name="ProjectId" id="2nX5jW7aFzy8GxYLTCwAZXC0ETa" />
+  <component name="ProjectViewState">
+    <option name="hideEmptyMiddlePackages" value="true" />
+    <option name="showLibraryContents" value="true" />
+  </component>
+  <component name="PropertiesComponent"><![CDATA[{
+  "keyToString": {
+    "RunOnceActivity.ShowReadmeOnStart": "true",
+    "git-widget-placeholder": "main",
+    "nodejs_package_manager_path": "npm",
+    "vue.rearranger.settings.migration": "true"
+  }
+}]]></component>
+  <component name="SharedIndexes">
+    <attachedChunks>
+      <set>
+        <option value="bundled-js-predefined-d6986cc7102b-5c90d61e3bab-JavaScript-PY-242.23339.19" />
+        <option value="bundled-python-sdk-0029f7779945-399fe30bd8c1-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-242.23339.19" />
+      </set>
+    </attachedChunks>
+  </component>
+  <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
+  <component name="TaskManager">
+    <task active="true" id="Default" summary="Default task">
+      <changelist id="210ec380-7aeb-4dbc-977d-b398d2f30cdf" name="Changes" comment="" />
+      <created>1729106717261</created>
+      <option name="number" value="Default" />
+      <option name="presentableId" value="Default" />
+      <updated>1729106717261</updated>
+      <workItem from="1729106719428" duration="5000" />
+    </task>
+    <servers />
+  </component>
+</project>
config/LS3DCG.json
ADDED
@@ -0,0 +1,64 @@
+{
+  "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
+  "dataset_load_mode": "pickle",
+  "store_file_path": "store.pkl",
+  "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
+  "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
+  "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
+  "param": {
+    "w_j": 1,
+    "w_b": 1,
+    "w_h": 1
+  },
+  "Data": {
+    "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
+    "pklname": "_3d_mfcc.pkl",
+    "whole_video": false,
+    "pose": {
+      "normalization": false,
+      "convert_to_6d": false,
+      "norm_method": "all",
+      "augmentation": false,
+      "generate_length": 88,
+      "pre_pose_length": 0,
+      "pose_dim": 99,
+      "expression": true
+    },
+    "aud": {
+      "feat_method": "mfcc",
+      "aud_feat_dim": 64,
+      "aud_feat_win_size": null,
+      "context_info": false
+    }
+  },
+  "Model": {
+    "model_type": "body",
+    "model_name": "s2g_LS3DCG",
+    "code_num": 2048,
+    "AudioOpt": "Adam",
+    "encoder_choice": "mfcc",
+    "gan": false
+  },
+  "DataLoader": {
+    "batch_size": 128,
+    "num_workers": 0
+  },
+  "Train": {
+    "epochs": 100,
+    "max_gradient_norm": 5,
+    "learning_rate": {
+      "generator_learning_rate": 1e-4,
+      "discriminator_learning_rate": 1e-4
+    },
+    "weights": {
+      "keypoint_loss_weight": 1.0,
+      "gan_loss_weight": 1.0
+    }
+  },
+  "Log": {
+    "save_every": 50,
+    "print_every": 200,
+    "name": "LS3DCG"
+  }
+}
+
config/body_pixel.json
ADDED
@@ -0,0 +1,63 @@
+{
+  "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
+  "dataset_load_mode": "json",
+  "store_file_path": "store.pkl",
+  "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
+  "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
+  "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
+  "param": {
+    "w_j": 1,
+    "w_b": 1,
+    "w_h": 1
+  },
+  "Data": {
+    "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
+    "pklname": "_3d_mfcc.pkl",
+    "whole_video": false,
+    "pose": {
+      "normalization": false,
+      "convert_to_6d": false,
+      "norm_method": "all",
+      "augmentation": false,
+      "generate_length": 88,
+      "pre_pose_length": 0,
+      "pose_dim": 99,
+      "expression": true
+    },
+    "aud": {
+      "feat_method": "mfcc",
+      "aud_feat_dim": 64,
+      "aud_feat_win_size": null,
+      "context_info": false
+    }
+  },
+  "Model": {
+    "model_type": "body",
+    "model_name": "s2g_body_pixel",
+    "composition": true,
+    "code_num": 2048,
+    "bh_model": true,
+    "AudioOpt": "Adam",
+    "encoder_choice": "mfcc",
+    "gan": false,
+    "vq_path": "./experiments/2022-10-31-smplx_S2G-body-vq-3d/ckpt-99.pth"
+  },
+  "DataLoader": {
+    "batch_size": 128,
+    "num_workers": 0
+  },
+  "Train": {
+    "epochs": 100,
+    "max_gradient_norm": 5,
+    "learning_rate": {
+      "generator_learning_rate": 1e-4,
+      "discriminator_learning_rate": 1e-4
+    }
+  },
+  "Log": {
+    "save_every": 50,
+    "print_every": 200,
+    "name": "body-pixel2"
+  }
+}
+
config/body_vq.json
ADDED
@@ -0,0 +1,62 @@
+{
+  "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
+  "dataset_load_mode": "json",
+  "store_file_path": "store.pkl",
+  "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
+  "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
+  "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
+  "param": {
+    "w_j": 1,
+    "w_b": 1,
+    "w_h": 1
+  },
+  "Data": {
+    "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
+    "pklname": "_3d_mfcc.pkl",
+    "whole_video": false,
+    "pose": {
+      "normalization": false,
+      "convert_to_6d": false,
+      "norm_method": "all",
+      "augmentation": false,
+      "generate_length": 88,
+      "pre_pose_length": 0,
+      "pose_dim": 99,
+      "expression": true
+    },
+    "aud": {
+      "feat_method": "mfcc",
+      "aud_feat_dim": 64,
+      "aud_feat_win_size": null,
+      "context_info": false
+    }
+  },
+  "Model": {
+    "model_type": "body",
+    "model_name": "s2g_body_vq",
+    "composition": true,
+    "code_num": 2048,
+    "bh_model": true,
+    "AudioOpt": "Adam",
+    "encoder_choice": "mfcc",
+    "gan": false
+  },
+  "DataLoader": {
+    "batch_size": 128,
+    "num_workers": 0
+  },
+  "Train": {
+    "epochs": 100,
+    "max_gradient_norm": 5,
+    "learning_rate": {
+      "generator_learning_rate": 1e-4,
+      "discriminator_learning_rate": 1e-4
+    }
+  },
+  "Log": {
+    "save_every": 50,
+    "print_every": 200,
+    "name": "body-vq"
+  }
+}
+
config/face.json
ADDED
@@ -0,0 +1,59 @@
+{
+  "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
+  "dataset_load_mode": "json",
+  "store_file_path": "store.pkl",
+  "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
+  "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
+  "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
+  "param": {
+    "w_j": 1,
+    "w_b": 1,
+    "w_h": 1
+  },
+  "Data": {
+    "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
+    "pklname": "_3d_wv2.pkl",
+    "whole_video": true,
+    "pose": {
+      "normalization": false,
+      "convert_to_6d": false,
+      "norm_method": "all",
+      "augmentation": false,
+      "generate_length": 88,
+      "pre_pose_length": 0,
+      "pose_dim": 99,
+      "expression": true
+    },
+    "aud": {
+      "feat_method": "mfcc",
+      "aud_feat_dim": 64,
+      "aud_feat_win_size": null,
+      "context_info": false
+    }
+  },
+  "Model": {
+    "model_type": "face",
+    "model_name": "s2g_face",
+    "AudioOpt": "SGD",
+    "encoder_choice": "faceformer",
+    "gan": false
+  },
+  "DataLoader": {
+    "batch_size": 1,
+    "num_workers": 0
+  },
+  "Train": {
+    "epochs": 100,
+    "max_gradient_norm": 5,
+    "learning_rate": {
+      "generator_learning_rate": 1e-4,
+      "discriminator_learning_rate": 1e-4
+    }
+  },
+  "Log": {
+    "save_every": 50,
+    "print_every": 1000,
+    "name": "face"
+  }
+}
+
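Note: the four config files above share one nested layout, and the loaders below read them through attribute access (e.g. config.dataset_load_mode, config.Data.pklname, config.Model.encoder_choice). A minimal sketch of such a wrapper is shown for orientation; it is an assumption for illustration, not the repository's own config loader, and the DictToObj name is made up here.

    # Minimal sketch (assumption, not the repo's loader): wrap one of the JSON
    # configs above so nested keys can be read with attribute access.
    import json

    class DictToObj:
        def __init__(self, d):
            for k, v in d.items():
                setattr(self, k, DictToObj(v) if isinstance(v, dict) else v)

    with open('config/body_pixel.json') as f:
        config = DictToObj(json.load(f))

    print(config.dataset_load_mode)   # "json"
    print(config.Data.pose.pose_dim)  # 99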
data_utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
+# from .dataloader_csv import MultiVidData as csv_data
+from .dataloader_torch import MultiVidData as torch_data
+from .utils import get_melspec, get_mfcc, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
data_utils/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (375 Bytes).

data_utils/__pycache__/consts.cpython-37.pyc
ADDED
Binary file (92.7 kB).

data_utils/__pycache__/dataloader_torch.cpython-37.pyc
ADDED
Binary file (5.31 kB).

data_utils/__pycache__/lower_body.cpython-37.pyc
ADDED
Binary file (3.91 kB).

data_utils/__pycache__/mesh_dataset.cpython-37.pyc
ADDED
Binary file (7.9 kB).

data_utils/__pycache__/rotation_conversion.cpython-37.pyc
ADDED
Binary file (16.4 kB).

data_utils/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (7.42 kB).
data_utils/apply_split.py
ADDED
@@ -0,0 +1,51 @@
+import os
+from tqdm import tqdm
+import pickle
+import shutil
+
+speakers = ['seth', 'oliver', 'conan', 'chemistry']
+source_data_root = "../expressive_body-V0.7"
+data_root = "D:/Downloads/SHOW_dataset_v1.0/ExpressiveWholeBodyDatasetReleaseV1.0"
+
+f_read = open('split_more_than_2s.pkl', 'rb')
+f_save = open('none.pkl', 'wb')
+data_split = pickle.load(f_read)
+none_split = []
+
+train = val = test = 0
+
+for speaker_name in speakers:
+    speaker_root = os.path.join(data_root, speaker_name)
+
+    videos = [v for v in data_split[speaker_name]]
+
+    for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
+        for split in data_split[speaker_name][vid]:
+            for seq in data_split[speaker_name][vid][split]:
+
+                seq = seq.replace('\\', '/')
+                old_file_path = os.path.join(data_root, speaker_name, vid, seq.split('/')[-1])
+                old_file_path = old_file_path.replace('\\', '/')
+                new_file_path = seq.replace(source_data_root.split('/')[-1], data_root.split('/')[-1])
+                try:
+                    shutil.move(old_file_path, new_file_path)
+                    if split == 'train':
+                        train = train + 1
+                    elif split == 'test':
+                        test = test + 1
+                    elif split == 'val':
+                        val = val + 1
+                except FileNotFoundError:
+                    none_split.append(old_file_path)
+                    print(f"The file {old_file_path} does not exists.")
+                except shutil.Error:
+                    none_split.append(old_file_path)
+                    print(f"The file {old_file_path} does not exists.")
+
+print(none_split.__len__())
+pickle.dump(none_split, f_save)
+f_save.close()
+
+print(train, val, test)
+
data_utils/axis2matrix.py
ADDED
@@ -0,0 +1,29 @@
+import numpy as np
+import math
+import scipy.linalg as linalg
+
+
+def rotate_mat(axis, radian):
+
+    a = np.cross(np.eye(3), axis / linalg.norm(axis) * radian)
+
+    rot_matrix = linalg.expm(a)
+
+    return rot_matrix
+
+def aaa2mat(axis, sin, cos):
+    i = np.eye(3)
+    nnt = np.dot(axis.T, axis)
+    s = np.asarray([[0, -axis[0,2], axis[0,1]],
+                    [axis[0,2], 0, -axis[0,0]],
+                    [-axis[0,1], axis[0,0], 0]])
+    r = cos * i + (1-cos)*nnt + sin * s
+    return r
+
+rand_axis = np.asarray([[1,0,0]])
+# rotation angle
+r = math.pi/2
+# return the rotation matrix
+rot_matrix = rotate_mat(rand_axis, r)
+r2 = aaa2mat(rand_axis, np.sin(r), np.cos(r))
+print(rot_matrix)
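Note: rotate_mat builds the rotation by exponentiating the skew-symmetric cross-product matrix, while aaa2mat writes out Rodrigues' formula term by term, so the two constructions should agree for the same axis and angle. A small self-check under that assumption (it only assumes the file above is importable as data_utils.axis2matrix):

    # Illustrative consistency check, not part of the repository.
    import numpy as np
    from data_utils.axis2matrix import rotate_mat, aaa2mat

    axis, angle = np.asarray([[0.0, 1.0, 0.0]]), np.pi / 3
    R1 = rotate_mat(axis, angle)                       # matrix exponential of angle * skew(axis)
    R2 = aaa2mat(axis, np.sin(angle), np.cos(angle))   # explicit Rodrigues formula
    assert np.allclose(R1, R2, atol=1e-6)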
data_utils/consts.py
ADDED
The diff for this file is too large to render.
data_utils/dataloader_torch.py
ADDED
@@ -0,0 +1,279 @@
+import sys
+import os
+sys.path.append(os.getcwd())
+import os
+from tqdm import tqdm
+from data_utils.utils import *
+import torch.utils.data as data
+from data_utils.mesh_dataset import SmplxDataset
+from transformers import Wav2Vec2Processor
+
+
+class MultiVidData():
+    def __init__(self,
+                 data_root,
+                 speakers,
+                 split='train',
+                 limbscaling=False,
+                 normalization=False,
+                 norm_method='new',
+                 split_trans_zero=False,
+                 num_frames=25,
+                 num_pre_frames=25,
+                 num_generate_length=None,
+                 aud_feat_win_size=None,
+                 aud_feat_dim=64,
+                 feat_method='mel_spec',
+                 context_info=False,
+                 smplx=False,
+                 audio_sr=16000,
+                 convert_to_6d=False,
+                 expression=False,
+                 config=None
+                 ):
+        self.data_root = data_root
+        self.speakers = speakers
+        self.split = split
+        if split == 'pre':
+            self.split = 'train'
+        self.norm_method=norm_method
+        self.normalization = normalization
+        self.limbscaling = limbscaling
+        self.convert_to_6d = convert_to_6d
+        self.num_frames=num_frames
+        self.num_pre_frames=num_pre_frames
+        if num_generate_length is None:
+            self.num_generate_length = num_frames
+        else:
+            self.num_generate_length = num_generate_length
+        self.split_trans_zero=split_trans_zero
+
+        dataset = SmplxDataset
+
+        if self.split_trans_zero:
+            self.trans_dataset_list = []
+            self.zero_dataset_list = []
+        else:
+            self.all_dataset_list = []
+        self.dataset={}
+        self.complete_data=[]
+        self.config=config
+        load_mode=self.config.dataset_load_mode
+
+        ######################load with pickle file
+        if load_mode=='pickle':
+            import pickle
+            import subprocess
+
+            # store_file_path='/tmp/store.pkl'
+            # cp /is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts/store.pkl /tmp/store.pkl
+            # subprocess.run(f'cp /is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts/store.pkl {store_file_path}',shell=True)
+
+            # f = open(self.config.store_file_path, 'rb+')
+            f = open(self.split+config.Data.pklname, 'rb+')
+            self.dataset=pickle.load(f)
+            f.close()
+            for key in self.dataset:
+                self.complete_data.append(self.dataset[key].complete_data)
+        ######################load with pickle file
+
+        ######################load with a csv file
+        elif load_mode=='csv':
+
+            # imported here from one of my other code folders; to be integrated properly later
+            try:
+                sys.path.append(self.config.config_root_path)
+                from config import config_path
+                from csv_parser import csv_parse
+
+            except ImportError as e:
+                print(f'err: {e}')
+                raise ImportError('config root path error...')
+
+            for speaker_name in self.speakers:
+                # df_intervals=pd.read_csv(self.config.voca_csv_file_path)
+                df_intervals=None
+                df_intervals=df_intervals[df_intervals['speaker']==speaker_name]
+                df_intervals = df_intervals[df_intervals['dataset'] == self.split]
+
+                print(f'speaker {speaker_name} train interval length: {len(df_intervals)}')
+                for iter_index, (_, interval) in tqdm(
+                        (enumerate(df_intervals.iterrows())), desc=f'load {speaker_name}'
+                ):
+
+                    (
+                        interval_index,
+                        interval_speaker,
+                        interval_video_fn,
+                        interval_id,
+
+                        start_time,
+                        end_time,
+                        duration_time,
+                        start_time_10,
+                        over_flow_flag,
+                        short_dur_flag,
+
+                        big_video_dir,
+                        small_video_dir_name,
+                        speaker_video_path,
+
+                        voca_basename,
+                        json_basename,
+                        wav_basename,
+                        voca_top_clip_path,
+                        voca_json_clip_path,
+                        voca_wav_clip_path,
+
+                        audio_output_fn,
+                        image_output_path,
+                        pifpaf_output_path,
+                        mp_output_path,
+                        op_output_path,
+                        deca_output_path,
+                        pixie_output_path,
+                        cam_output_path,
+                        ours_output_path,
+                        merge_output_path,
+                        multi_output_path,
+                        gt_output_path,
+                        ours_images_path,
+                        pkl_fil_path,
+                    ) = csv_parse(interval)
+
+                    if not os.path.exists(pkl_fil_path) or not os.path.exists(audio_output_fn):
+                        continue
+
+                    key=f'{interval_video_fn}/{small_video_dir_name}'
+                    self.dataset[key] = dataset(
+                        data_root=pkl_fil_path,
+                        speaker=speaker_name,
+                        audio_fn=audio_output_fn,
+                        audio_sr=audio_sr,
+                        fps=num_frames,
+                        feat_method=feat_method,
+                        audio_feat_dim=aud_feat_dim,
+                        train=(self.split == 'train'),
+                        load_all=True,
+                        split_trans_zero=self.split_trans_zero,
+                        limbscaling=self.limbscaling,
+                        num_frames=self.num_frames,
+                        num_pre_frames=self.num_pre_frames,
+                        num_generate_length=self.num_generate_length,
+                        audio_feat_win_size=aud_feat_win_size,
+                        context_info=context_info,
+                        convert_to_6d=convert_to_6d,
+                        expression=expression,
+                        config=self.config
+                    )
+                    self.complete_data.append(self.dataset[key].complete_data)
+        ######################load with a csv file
+
+        ######################origin load method
+        elif load_mode=='json':
+
+            # if self.split == 'train':
+            #     import pickle
+            #     f = open('store.pkl', 'rb+')
+            #     self.dataset=pickle.load(f)
+            #     f.close()
+            #     for key in self.dataset:
+            #         self.complete_data.append(self.dataset[key].complete_data)
+            # else:https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav
+            # if config.Model.model_type == 'face':
+            am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
+            am_sr = 16000
+            # else:
+            #     am, am_sr = None, None
+            for speaker_name in self.speakers:
+                speaker_root = os.path.join(self.data_root, speaker_name)
+
+                videos=[v for v in os.listdir(speaker_root)]
+                print(videos)
+
+                haode = huaide = 0
+
+                for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
+                    source_vid=vid
+                    # vid_pth=os.path.join(speaker_root, source_vid, 'images/half', self.split)
+                    vid_pth = os.path.join(speaker_root, source_vid, self.split)
+                    if smplx == 'pose':
+                        seqs = [s for s in os.listdir(vid_pth) if (s.startswith('clip'))]
+                    else:
+                        try:
+                            seqs = [s for s in os.listdir(vid_pth)]
+                        except:
+                            continue
+
+                    for s in seqs:
+                        seq_root=os.path.join(vid_pth, s)
+                        key = seq_root  # correspond to clip******
+                        audio_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.wav' % (s))
+                        motion_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.pkl' % (s))
+                        if not os.path.isfile(audio_fname) or not os.path.isfile(motion_fname):
+                            huaide = huaide + 1
+                            continue
+
+                        self.dataset[key]=dataset(
+                            data_root=seq_root,
+                            speaker=speaker_name,
+                            motion_fn=motion_fname,
+                            audio_fn=audio_fname,
+                            audio_sr=audio_sr,
+                            fps=num_frames,
+                            feat_method=feat_method,
+                            audio_feat_dim=aud_feat_dim,
+                            train=(self.split=='train'),
+                            load_all=True,
+                            split_trans_zero=self.split_trans_zero,
+                            limbscaling=self.limbscaling,
+                            num_frames=self.num_frames,
+                            num_pre_frames=self.num_pre_frames,
+                            num_generate_length=self.num_generate_length,
+                            audio_feat_win_size=aud_feat_win_size,
+                            context_info=context_info,
+                            convert_to_6d=convert_to_6d,
+                            expression=expression,
+                            config=self.config,
+                            am=am,
+                            am_sr=am_sr,
+                            whole_video=config.Data.whole_video
+                        )
+                        self.complete_data.append(self.dataset[key].complete_data)
+                        haode = haode + 1
+                print("huaide:{}, haode:{}".format(huaide, haode))
+            import pickle
+
+            f = open(self.split+config.Data.pklname, 'wb')
+            pickle.dump(self.dataset, f)
+            f.close()
+        ######################origin load method
+
+        self.complete_data=np.concatenate(self.complete_data, axis=0)
+
+        # assert self.complete_data.shape[-1] == (12+21+21)*2
+        self.normalize_stats = {}
+
+        self.data_mean = None
+        self.data_std = None
+
+    def get_dataset(self):
+        self.normalize_stats['mean'] = self.data_mean
+        self.normalize_stats['std'] = self.data_std
+
+        for key in list(self.dataset.keys()):
+            if self.dataset[key].complete_data.shape[0] < self.num_generate_length:
+                continue
+            self.dataset[key].num_generate_length = self.num_generate_length
+            self.dataset[key].get_dataset(self.normalization, self.normalize_stats, self.split)
+            self.all_dataset_list.append(self.dataset[key].all_dataset)
+
+        if self.split_trans_zero:
+            self.trans_dataset = data.ConcatDataset(self.trans_dataset_list)
+            self.zero_dataset = data.ConcatDataset(self.zero_dataset_list)
+        else:
+            self.all_dataset = data.ConcatDataset(self.all_dataset_list)
+
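Note: a usage sketch for MultiVidData (exported as torch_data in data_utils/__init__.py above). The argument values are assumptions chosen to match the config files earlier in this diff, and config is taken to be an attribute-access view of one of those JSON configs; this is not a script from the repository.

    from torch.utils import data
    from data_utils import torch_data

    # Assumption: `config` is the attribute-access wrapper over config/body_pixel.json
    # from the sketch shown after the config files above.
    train_set = torch_data(data_root=config.Data.data_root,
                           speakers=['seth', 'oliver', 'conan', 'chemistry'],
                           split='train',
                           num_generate_length=config.Data.pose.generate_length,
                           aud_feat_dim=config.Data.aud.aud_feat_dim,
                           feat_method=config.Data.aud.feat_method,
                           convert_to_6d=config.Data.pose.convert_to_6d,
                           expression=config.Data.pose.expression,
                           config=config)
    train_set.get_dataset()  # concatenates the per-clip SmplxDataset objects
    train_loader = data.DataLoader(train_set.all_dataset,
                                   batch_size=config.DataLoader.batch_size,
                                   shuffle=True)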
data_utils/dataset_preprocess.py
ADDED
@@ -0,0 +1,170 @@
+import os
+import pickle
+from tqdm import tqdm
+import shutil
+import torch
+import numpy as np
+import librosa
+import random
+
+speakers = ['seth', 'conan', 'oliver', 'chemistry']
+data_root = "../ExpressiveWholeBodyDatasetv1.0/"
+split = 'train'
+
+
+
+def split_list(full_list, shuffle=False, ratio=0.2):
+    n_total = len(full_list)
+    offset_0 = int(n_total * ratio)
+    offset_1 = int(n_total * ratio * 2)
+    if n_total==0 or offset_1<1:
+        return [], full_list
+    if shuffle:
+        random.shuffle(full_list)
+    sublist_0 = full_list[:offset_0]
+    sublist_1 = full_list[offset_0:offset_1]
+    sublist_2 = full_list[offset_1:]
+    return sublist_0, sublist_1, sublist_2
+
+
+def moveto(list, file):
+    for f in list:
+        before, after = '/'.join(f.split('/')[:-1]), f.split('/')[-1]
+        new_path = os.path.join(before, file)
+        new_path = os.path.join(new_path, after)
+        # os.makedirs(new_path)
+        # os.path.isdir(new_path)
+        # shutil.move(f, new_path)
+
+        # copy to the new split directory
+        shutil.copytree(f, new_path)
+        # delete the original files under train
+        shutil.rmtree(f)
+    return None
+
+
+def read_pkl(data):
+    betas = np.array(data['betas'])
+
+    jaw_pose = np.array(data['jaw_pose'])
+    leye_pose = np.array(data['leye_pose'])
+    reye_pose = np.array(data['reye_pose'])
+    global_orient = np.array(data['global_orient']).squeeze()
+    body_pose = np.array(data['body_pose_axis'])
+    left_hand_pose = np.array(data['left_hand_pose'])
+    right_hand_pose = np.array(data['right_hand_pose'])
+
+    full_body = np.concatenate(
+        (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
+
+    expression = np.array(data['expression'])
+    full_body = np.concatenate((full_body, expression), axis=1)
+
+    if (full_body.shape[0] < 90) or (torch.isnan(torch.from_numpy(full_body)).sum() > 0):
+        return 1
+    else:
+        return 0
+
+
+for speaker_name in speakers:
+    speaker_root = os.path.join(data_root, speaker_name)
+
+    videos = [v for v in os.listdir(speaker_root)]
+    print(videos)
+
+    haode = huaide = 0
+    total_seqs = []
+
+    for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
+        # for vid in videos:
+        source_vid = vid
+        vid_pth = os.path.join(speaker_root, source_vid)
+        # vid_pth = os.path.join(speaker_root, source_vid, 'images/half', split)
+        t = os.path.join(speaker_root, source_vid, 'test')
+        v = os.path.join(speaker_root, source_vid, 'val')
+
+        # if os.path.exists(t):
+        #     shutil.rmtree(t)
+        # if os.path.exists(v):
+        #     shutil.rmtree(v)
+        try:
+            seqs = [s for s in os.listdir(vid_pth)]
+        except:
+            continue
+        # if len(seqs) == 0:
+        #     shutil.rmtree(os.path.join(speaker_root, source_vid))
+        #     None
+        for s in seqs:
+            quality = 0
+            total_seqs.append(os.path.join(vid_pth, s))
+            seq_root = os.path.join(vid_pth, s)
+            key = seq_root  # correspond to clip******
+            audio_fname = os.path.join(speaker_root, source_vid, s, '%s.wav' % (s))
+
+            # delete the data without audio or the audio file could not be read
+            if os.path.isfile(audio_fname):
+                try:
+                    audio = librosa.load(audio_fname)
+                except:
+                    # print(key)
+                    shutil.rmtree(key)
+                    huaide = huaide + 1
+                    continue
+            else:
+                huaide = huaide + 1
+                # print(key)
+                shutil.rmtree(key)
+                continue
+
+            # check motion file
+            motion_fname = os.path.join(speaker_root, source_vid, s, '%s.pkl' % (s))
+            try:
+                f = open(motion_fname, 'rb+')
+            except:
+                shutil.rmtree(key)
+                huaide = huaide + 1
+                continue
+
+            data = pickle.load(f)
+            w = read_pkl(data)
+            f.close()
+            quality = quality + w
+
+            if w == 1:
+                shutil.rmtree(key)
+                # print(key)
+                huaide = huaide + 1
+                continue
+
+            haode = haode + 1
+
+    print("huaide:{}, haode:{}, total_seqs:{}".format(huaide, haode, total_seqs.__len__()))
+
+for speaker_name in speakers:
+    speaker_root = os.path.join(data_root, speaker_name)
+
+    videos = [v for v in os.listdir(speaker_root)]
+    print(videos)
+
+    haode = huaide = 0
+    total_seqs = []
+
+    for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
+        # for vid in videos:
+        source_vid = vid
+        vid_pth = os.path.join(speaker_root, source_vid)
+        try:
+            seqs = [s for s in os.listdir(vid_pth)]
+        except:
+            continue
+        for s in seqs:
+            quality = 0
+            total_seqs.append(os.path.join(vid_pth, s))
+    print("total_seqs:{}".format(total_seqs.__len__()))
+    # split the dataset
+    test_list, val_list, train_list = split_list(total_seqs, True, 0.1)
+    print(len(test_list), len(val_list), len(train_list))
+    moveto(train_list, 'train')
+    moveto(test_list, 'test')
+    moveto(val_list, 'val')
+
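Note: split_list above partitions a sequence into three sublists; with ratio=0.1, as in the call at the bottom of the file, the first ~10% of entries become the test split, the next ~10% the val split, and the remaining ~80% the train split. A small illustrative call of that function (not part of the file):

    test, val, train = split_list(list(range(100)), shuffle=False, ratio=0.1)
    print(len(test), len(val), len(train))  # 10 10 80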
data_utils/get_j.py
ADDED
@@ -0,0 +1,51 @@
+import torch
+
+
+def to3d(poses, config):
+    if config.Data.pose.convert_to_6d:
+        if config.Data.pose.expression:
+            poses_exp = poses[:, -100:]
+            poses = poses[:, :-100]
+
+        poses = poses.reshape(poses.shape[0], -1, 5)
+        sin, cos = poses[:, :, 3], poses[:, :, 4]
+        pose_angle = torch.atan2(sin, cos)
+        poses = (poses[:, :, :3] * pose_angle.unsqueeze(dim=-1)).reshape(poses.shape[0], -1)
+
+        if config.Data.pose.expression:
+            poses = torch.cat([poses, poses_exp], dim=-1)
+    return poses
+
+
+def get_joint(smplx_model, betas, pred):
+    joint = smplx_model(betas=betas.repeat(pred.shape[0], 1),
+                        expression=pred[:, 165:265],
+                        jaw_pose=pred[:, 0:3],
+                        leye_pose=pred[:, 3:6],
+                        reye_pose=pred[:, 6:9],
+                        global_orient=pred[:, 9:12],
+                        body_pose=pred[:, 12:75],
+                        left_hand_pose=pred[:, 75:120],
+                        right_hand_pose=pred[:, 120:165],
+                        return_verts=True)['joints']
+    return joint
+
+
+def get_joints(smplx_model, betas, pred):
+    if len(pred.shape) == 3:
+        B = pred.shape[0]
+        x = 4 if B >= 4 else B
+        T = pred.shape[1]
+        pred = pred.reshape(-1, 265)
+        smplx_model.batch_size = L = T * x
+
+        times = pred.shape[0] // smplx_model.batch_size
+        joints = []
+        for i in range(times):
+            joints.append(get_joint(smplx_model, betas, pred[i*L:(i+1)*L]))
+        joints = torch.cat(joints, dim=0)
+        joints = joints.reshape(B, T, -1, 3)
+    else:
+        smplx_model.batch_size = pred.shape[0]
+        joints = get_joint(smplx_model, betas, pred)
+    return joints
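Note: get_joints expects the subject's betas plus a pose tensor whose last dimension is 265 (165 axis-angle values followed by 100 expression coefficients, matching the slicing in get_joint). A usage sketch; the model path mirrors the smplx_npz_path in the configs above, and the parameter choices are assumptions rather than code from the repository.

    import torch
    import smplx
    from data_utils.get_j import get_joints

    # Assumption: an SMPL-X model with 100 expression coefficients and full (non-PCA)
    # hand poses, so the slices used inside get_joint line up with the model inputs.
    smplx_model = smplx.create('visualise/smplx_model/SMPLX_NEUTRAL_2020.npz',
                               model_type='smplx', use_pca=False,
                               num_expression_coeffs=100)
    betas = torch.zeros(1, 10)
    pred = torch.zeros(2, 88, 265)                 # (batch, frames, pose + expression)
    joints = get_joints(smplx_model, betas, pred)  # -> (2, 88, num_joints, 3)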
data_utils/hand_component.json
ADDED
The diff for this file is too large to render.
data_utils/lower_body.py
ADDED
@@ -0,0 +1,143 @@
+import numpy as np
+import torch
+
+lower_pose = torch.tensor(
+    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0747, -0.0158, -0.0152, -1.1826512813568115, 0.23866955935955048,
+     0.15146760642528534, -1.2604516744613647, -0.3160211145877838,
+     -0.1603458970785141, 1.1654603481292725, 0.0, 0.0, 1.2521806955337524, 0.041598282754421234, -0.06312154978513718,
+     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
+lower_pose_stand = torch.tensor([
+    8.9759e-04, 7.1074e-04, -5.9163e-06, 8.9759e-04, 7.1074e-04, -5.9163e-06,
+    3.0747, -0.0158, -0.0152,
+    -3.6665e-01, -8.8455e-03, 1.6113e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
+    -3.9716e-01, -4.0229e-02, -1.2637e-01,
+    7.9163e-01, 6.8519e-02, -1.5091e-01, 7.9163e-01, 6.8519e-02, -1.5091e-01,
+    7.8632e-01, -4.3810e-02, 1.4375e-02,
+    -1.0675e-01, 1.2635e-01, 1.6711e-02, -1.0675e-01, 1.2635e-01, 1.6711e-02, ])
+# lower_pose_stand = torch.tensor(
+#     [6.4919e-02, 3.3018e-02, 1.7485e-02, 8.9759e-04, 7.1074e-04, -5.9163e-06,
+#      3.0747, -0.0158, -0.0152,
+#      -3.3633e+00, -9.3915e-02, 3.0996e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
+#      1.1654603481292725, 0.0, 0.0,
+#      4.4167e-01, 6.7183e-03, -3.6379e-03, 7.9163e-01, 6.8519e-02, -1.5091e-01,
+#      0.0, 0.0, 0.0,
+#      2.2910e-02, -2.4797e-02, -5.5657e-03, -1.0675e-01, 1.2635e-01, 1.6711e-02,])
+lower_body = [0, 1, 3, 4, 6, 7, 9, 10]
+count_part = [6, 9, 12, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31, 32, 33, 34, 35, 36, 37,
+              38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54]
+fix_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+             29,
+             35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
+             50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
+             65, 66, 67, 68, 69, 70, 71, 72, 73, 74]
+all_index = np.ones(275)
+all_index[fix_index] = 0
+c_index = []
+i = 0
+for num in all_index:
+    if num == 1:
+        c_index.append(i)
+    i = i + 1
+c_index = np.asarray(c_index)
+
+fix_index_3d = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
+                21, 22, 23, 24, 25, 26,
+                30, 31, 32, 33, 34, 35,
+                45, 46, 47, 48, 49, 50]
+all_index_3d = np.ones(165)
+all_index_3d[fix_index_3d] = 0
+c_index_3d = []
+i = 0
+for num in all_index_3d:
+    if num == 1:
+        c_index_3d.append(i)
+    i = i + 1
+c_index_3d = np.asarray(c_index_3d)
+
+c_index_6d = []
+i = 0
+for num in all_index_3d:
+    if num == 1:
+        c_index_6d.append(2*i)
+        c_index_6d.append(2 * i + 1)
+    i = i + 1
+c_index_6d = np.asarray(c_index_6d)
+
+
+def part2full(input, stand=False):
+    if stand:
+        # lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
+        lp = torch.zeros_like(lower_pose)
+        lp[6:9] = torch.tensor([3.0747, -0.0158, -0.0152])
+        lp = lp.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
+    else:
+        lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
+
+    input = torch.cat([input[:, :3],
+                       lp[:, :15],
+                       input[:, 3:6],
+                       lp[:, 15:21],
+                       input[:, 6:9],
+                       lp[:, 21:27],
+                       input[:, 9:12],
+                       lp[:, 27:],
+                       input[:, 12:]]
+                      , dim=1)
+    return input
+
+
+def pred2poses(input, gt):
+    input = torch.cat([input[:, :3],
+                       gt[0:1, 3:18].repeat(input.shape[0], 1),
+                       input[:, 3:6],
+                       gt[0:1, 21:27].repeat(input.shape[0], 1),
+                       input[:, 6:9],
+                       gt[0:1, 30:36].repeat(input.shape[0], 1),
+                       input[:, 9:12],
+                       gt[0:1, 39:45].repeat(input.shape[0], 1),
+                       input[:, 12:]]
+                      , dim=1)
+    return input
+
+
+def poses2poses(input, gt):
+    input = torch.cat([input[:, :3],
+                       gt[0:1, 3:18].repeat(input.shape[0], 1),
+                       input[:, 18:21],
+                       gt[0:1, 21:27].repeat(input.shape[0], 1),
+                       input[:, 27:30],
+                       gt[0:1, 30:36].repeat(input.shape[0], 1),
+                       input[:, 36:39],
+                       gt[0:1, 39:45].repeat(input.shape[0], 1),
+                       input[:, 45:]]
+                      , dim=1)
+    return input
+
+def poses2pred(input, stand=False):
+    if stand:
+        lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
+        # lp = torch.zeros_like(lower_pose).unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
+    else:
+        lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
+    input = torch.cat([input[:, :3],
+                       lp[:, :15],
+                       input[:, 18:21],
+                       lp[:, 15:21],
+                       input[:, 27:30],
+                       lp[:, 21:27],
+                       input[:, 36:39],
+                       lp[:, 27:],
+                       input[:, 45:]]
+                      , dim=1)
+    return input
+
+
+rearrange = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]\
+    # ,22, 23, 24, 25, 40, 26, 41,
+    # 27, 42, 28, 43, 29, 44, 30, 45, 31, 46, 32, 47, 33, 48, 34, 49, 35, 50, 36, 51, 37, 52, 38, 53, 39, 54, 55,
+    # 57, 56, 59, 58, 60, 63, 61, 64, 62, 65, 66, 71, 67, 72, 68, 73, 69, 74, 70, 75]
+
+symmetry = [0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1]  # , 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+# 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+# 1, 1, 1, 1, 1, 1]
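Note: in this file the lower-body dimensions listed in fix_index / fix_index_3d are held at the constant values in lower_pose (or its standing variant), while c_index, c_index_3d and c_index_6d pick out the remaining trainable dimensions (129 of the 165 axis-angle values in the 3d case). part2full splices the constant lower body back into a reduced upper-body pose to recover a full-length vector. A shape-only sketch under that reading; the input width and values are assumptions for illustration:

    import torch
    from data_utils.lower_body import part2full, c_index_3d

    print(len(c_index_3d))             # 129 trainable axis-angle dimensions
    upper = torch.zeros(88, 129)       # assumed reduced pose, no expression appended
    full = part2full(upper, stand=True)
    print(full.shape)                  # torch.Size([88, 165]) after re-inserting 36 fixed dims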
data_utils/mesh_dataset.py
ADDED
@@ -0,0 +1,348 @@
+import pickle
+import sys
+import os
+
+sys.path.append(os.getcwd())
+
+import json
+from glob import glob
+from data_utils.utils import *
+import torch.utils.data as data
+from data_utils.consts import speaker_id
+from data_utils.lower_body import count_part
+import random
+from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
+
+with open('data_utils/hand_component.json') as file_obj:
+    comp = json.load(file_obj)
+    left_hand_c = np.asarray(comp['left'])
+    right_hand_c = np.asarray(comp['right'])
+
+
+def to3d(data):
+    left_hand_pose = np.einsum('bi,ij->bj', data[:, 75:87], left_hand_c[:12, :])
+    right_hand_pose = np.einsum('bi,ij->bj', data[:, 87:99], right_hand_c[:12, :])
+    data = np.concatenate((data[:, :75], left_hand_pose, right_hand_pose), axis=-1)
+    return data
+
+
+class SmplxDataset():
+    '''
+    creat a dataset for every segment and concat.
+    '''
+
+    def __init__(self,
+                 data_root,
+                 speaker,
+                 motion_fn,
+                 audio_fn,
+                 audio_sr,
+                 fps,
+                 feat_method='mel_spec',
+                 audio_feat_dim=64,
+                 audio_feat_win_size=None,
+
+                 train=True,
+                 load_all=False,
+                 split_trans_zero=False,
+                 limbscaling=False,
+                 num_frames=25,
+                 num_pre_frames=25,
+                 num_generate_length=25,
+                 context_info=False,
+                 convert_to_6d=False,
+                 expression=False,
+                 config=None,
+                 am=None,
+                 am_sr=None,
+                 whole_video=False
+                 ):
+
+        self.data_root = data_root
+        self.speaker = speaker
+
+        self.feat_method = feat_method
+        self.audio_fn = audio_fn
+        self.audio_sr = audio_sr
+        self.fps = fps
+        self.audio_feat_dim = audio_feat_dim
+        self.audio_feat_win_size = audio_feat_win_size
+        self.context_info = context_info  # for aud feat
+        self.convert_to_6d = convert_to_6d
+        self.expression = expression
+
+        self.train = train
+        self.load_all = load_all
+        self.split_trans_zero = split_trans_zero
+        self.limbscaling = limbscaling
+        self.num_frames = num_frames
+        self.num_pre_frames = num_pre_frames
+        self.num_generate_length = num_generate_length
+        # print('num_generate_length ', self.num_generate_length)
+
+        self.config = config
+        self.am_sr = am_sr
+        self.whole_video = whole_video
+        load_mode = self.config.dataset_load_mode
+
+        if load_mode == 'pickle':
+            raise NotImplementedError
+
+        elif load_mode == 'csv':
+            import pickle
+            with open(data_root, 'rb') as f:
+                u = pickle._Unpickler(f)
+                data = u.load()
+                self.data = data[0]
+            if self.load_all:
+                self._load_npz_all()
+
+        elif load_mode == 'json':
+            self.annotations = glob(data_root + '/*pkl')
+            if len(self.annotations) == 0:
+                raise FileNotFoundError(data_root + ' are empty')
+            self.annotations = sorted(self.annotations)
+            self.img_name_list = self.annotations
+
+            if self.load_all:
+                self._load_them_all(am, am_sr, motion_fn)
+
+    def _load_npz_all(self):
+        self.loaded_data = {}
+        self.complete_data = []
+        data = self.data
+        shape = data['body_pose_axis'].shape[0]
+        self.betas = data['betas']
+        self.img_name_list = []
+        for index in range(shape):
+            img_name = f'{index:6d}'
+            self.img_name_list.append(img_name)
+
+            jaw_pose = data['jaw_pose'][index]
+            leye_pose = data['leye_pose'][index]
+            reye_pose = data['reye_pose'][index]
+            global_orient = data['global_orient'][index]
+            body_pose = data['body_pose_axis'][index]
+            left_hand_pose = data['left_hand_pose'][index]
+            right_hand_pose = data['right_hand_pose'][index]
+
+            full_body = np.concatenate(
+                (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose))
+            assert full_body.shape[0] == 99
+            if self.convert_to_6d:
+                full_body = to3d(full_body)
+                full_body = torch.from_numpy(full_body)
+                full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body))
+                full_body = np.asarray(full_body)
+                if self.expression:
+                    expression = data['expression'][index]
+                    full_body = np.concatenate((full_body, expression))
+                    # full_body = np.concatenate((full_body, non_zero))
+            else:
+                full_body = to3d(full_body)
+                if self.expression:
+                    expression = data['expression'][index]
+                    full_body = np.concatenate((full_body, expression))
+
+            self.loaded_data[img_name] = full_body.reshape(-1)
+            self.complete_data.append(full_body.reshape(-1))
+
+        self.complete_data = np.array(self.complete_data)
+
+        if self.audio_feat_win_size is not None:
+            self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
+            # print(self.audio_feat.shape)
+        else:
+            if self.feat_method == 'mel_spec':
+                self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
+            elif self.feat_method == 'mfcc':
+                self.audio_feat = get_mfcc(self.audio_fn,
+                                           smlpx=True,
+                                           sr=self.audio_sr,
+                                           n_mfcc=self.audio_feat_dim,
+                                           win_size=self.audio_feat_win_size
+                                           )
+
+    def _load_them_all(self, am, am_sr, motion_fn):
+        self.loaded_data = {}
+        self.complete_data = []
+        f = open(motion_fn, 'rb+')
+        data = pickle.load(f)
+
+        self.betas = np.array(data['betas'])
+
+        jaw_pose = np.array(data['jaw_pose'])
+        leye_pose = np.array(data['leye_pose'])
+        reye_pose = np.array(data['reye_pose'])
+        global_orient = np.array(data['global_orient']).squeeze()
+        body_pose = np.array(data['body_pose_axis'])
+        left_hand_pose = np.array(data['left_hand_pose'])
+        right_hand_pose = np.array(data['right_hand_pose'])
+
+        full_body = np.concatenate(
+            (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
+        assert full_body.shape[1] == 99
+
+        if self.convert_to_6d:
+            full_body = to3d(full_body)
+            full_body = torch.from_numpy(full_body)
+            full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body.reshape(-1, 55, 3))).reshape(-1, 330)
+            full_body = np.asarray(full_body)
+            if self.expression:
+                expression = np.array(data['expression'])
+                full_body = np.concatenate((full_body, expression), axis=1)
+
+        else:
+            full_body = to3d(full_body)
+            expression = np.array(data['expression'])
+            full_body = np.concatenate((full_body, expression), axis=1)
+
+        self.complete_data = full_body
+        self.complete_data = np.array(self.complete_data)
+
+        if self.audio_feat_win_size is not None:
+            self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
+        else:
+            # if self.feat_method == 'mel_spec':
+            #     self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
+            # elif self.feat_method == 'mfcc':
+            self.audio_feat = get_mfcc_ta(self.audio_fn,
+                                          smlpx=True,
+                                          fps=30,
+                                          sr=self.audio_sr,
+                                          n_mfcc=self.audio_feat_dim,
+                                          win_size=self.audio_feat_win_size,
+                                          type=self.feat_method,
+                                          am=am,
+                                          am_sr=am_sr,
+                                          encoder_choice=self.config.Model.encoder_choice,
+                                          )
+            # with open(audio_file, 'w', encoding='utf-8') as file:
+            #     file.write(json.dumps(self.audio_feat.__array__().tolist(), indent=0, ensure_ascii=False))
+
+    def get_dataset(self, normalization=False, normalize_stats=None, split='train'):
|
| 225 |
+
|
| 226 |
+
class __Worker__(data.Dataset):
|
| 227 |
+
def __init__(child, index_list, normalization, normalize_stats, split='train') -> None:
|
| 228 |
+
super().__init__()
|
| 229 |
+
child.index_list = index_list
|
| 230 |
+
child.normalization = normalization
|
| 231 |
+
child.normalize_stats = normalize_stats
|
| 232 |
+
child.split = split
|
| 233 |
+
|
| 234 |
+
def __getitem__(child, index):
|
| 235 |
+
num_generate_length = self.num_generate_length
|
| 236 |
+
num_pre_frames = self.num_pre_frames
|
| 237 |
+
seq_len = num_generate_length + num_pre_frames
|
| 238 |
+
# print(num_generate_length)
|
| 239 |
+
|
| 240 |
+
index = child.index_list[index]
|
| 241 |
+
index_new = index + random.randrange(0, 5, 3)
|
| 242 |
+
if index_new + seq_len > self.complete_data.shape[0]:
|
| 243 |
+
index_new = index
|
| 244 |
+
index = index_new
|
| 245 |
+
|
| 246 |
+
if child.split in ['val', 'pre', 'test'] or self.whole_video:
|
| 247 |
+
index = 0
|
| 248 |
+
seq_len = self.complete_data.shape[0]
|
| 249 |
+
seq_data = []
|
| 250 |
+
assert index + seq_len <= self.complete_data.shape[0]
|
| 251 |
+
# print(seq_len)
|
| 252 |
+
seq_data = self.complete_data[index:(index + seq_len), :]
|
| 253 |
+
seq_data = np.array(seq_data)
|
| 254 |
+
|
| 255 |
+
'''
|
| 256 |
+
audio feature,
|
| 257 |
+
'''
|
| 258 |
+
if not self.context_info:
|
| 259 |
+
if not self.whole_video:
|
| 260 |
+
audio_feat = self.audio_feat[index:index + seq_len, ...]
|
| 261 |
+
if audio_feat.shape[0] < seq_len:
|
| 262 |
+
audio_feat = np.pad(audio_feat, [[0, seq_len - audio_feat.shape[0]], [0, 0]],
|
| 263 |
+
mode='reflect')
|
| 264 |
+
|
| 265 |
+
assert audio_feat.shape[0] == seq_len and audio_feat.shape[1] == self.audio_feat_dim
|
| 266 |
+
else:
|
| 267 |
+
audio_feat = self.audio_feat
|
| 268 |
+
|
| 269 |
+
else: # including feature and history
|
| 270 |
+
if self.audio_feat_win_size is None:
|
| 271 |
+
audio_feat = self.audio_feat[index:index + seq_len + num_pre_frames, ...]
|
| 272 |
+
if audio_feat.shape[0] < seq_len + num_pre_frames:
|
| 273 |
+
audio_feat = np.pad(audio_feat,
|
| 274 |
+
[[0, seq_len + self.num_frames - audio_feat.shape[0]], [0, 0]],
|
| 275 |
+
mode='constant')
|
| 276 |
+
|
| 277 |
+
assert audio_feat.shape[0] == self.num_frames + seq_len and audio_feat.shape[
|
| 278 |
+
1] == self.audio_feat_dim
|
| 279 |
+
|
| 280 |
+
if child.normalization:
|
| 281 |
+
data_mean = child.normalize_stats['mean'].reshape(1, -1)
|
| 282 |
+
data_std = child.normalize_stats['std'].reshape(1, -1)
|
| 283 |
+
seq_data[:, :330] = (seq_data[:, :330] - data_mean) / data_std
|
| 284 |
+
if child.split in ['train', 'test']:
|
| 285 |
+
if self.convert_to_6d:
|
| 286 |
+
if self.expression:
|
| 287 |
+
data_sample = {
|
| 288 |
+
'poses': seq_data[:, :330].astype(np.float).transpose(1, 0),
|
| 289 |
+
'expression': seq_data[:, 330:].astype(np.float).transpose(1, 0),
|
| 290 |
+
# 'nzero': seq_data[:, 375:].astype(np.float).transpose(1, 0),
|
| 291 |
+
'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
|
| 292 |
+
'speaker': speaker_id[self.speaker],
|
| 293 |
+
'betas': self.betas,
|
| 294 |
+
'aud_file': self.audio_fn,
|
| 295 |
+
}
|
| 296 |
+
else:
|
| 297 |
+
data_sample = {
|
| 298 |
+
'poses': seq_data[:, :330].astype(np.float).transpose(1, 0),
|
| 299 |
+
'nzero': seq_data[:, 330:].astype(np.float).transpose(1, 0),
|
| 300 |
+
'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
|
| 301 |
+
'speaker': speaker_id[self.speaker],
|
| 302 |
+
'betas': self.betas
|
| 303 |
+
}
|
| 304 |
+
else:
|
| 305 |
+
if self.expression:
|
| 306 |
+
data_sample = {
|
| 307 |
+
'poses': seq_data[:, :165].astype(np.float).transpose(1, 0),
|
| 308 |
+
'expression': seq_data[:, 165:].astype(np.float).transpose(1, 0),
|
| 309 |
+
'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
|
| 310 |
+
# 'wv2_feat': wv2_feat.astype(np.float).transpose(1, 0),
|
| 311 |
+
'speaker': speaker_id[self.speaker],
|
| 312 |
+
'aud_file': self.audio_fn,
|
| 313 |
+
'betas': self.betas
|
| 314 |
+
}
|
| 315 |
+
else:
|
| 316 |
+
data_sample = {
|
| 317 |
+
'poses': seq_data.astype(np.float).transpose(1, 0),
|
| 318 |
+
'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
|
| 319 |
+
'speaker': speaker_id[self.speaker],
|
| 320 |
+
'betas': self.betas
|
| 321 |
+
}
|
| 322 |
+
return data_sample
|
| 323 |
+
else:
|
| 324 |
+
data_sample = {
|
| 325 |
+
'poses': seq_data[:, :330].astype(np.float).transpose(1, 0),
|
| 326 |
+
'expression': seq_data[:, 330:].astype(np.float).transpose(1, 0),
|
| 327 |
+
# 'nzero': seq_data[:, 325:].astype(np.float).transpose(1, 0),
|
| 328 |
+
'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
|
| 329 |
+
'aud_file': self.audio_fn,
|
| 330 |
+
'speaker': speaker_id[self.speaker],
|
| 331 |
+
'betas': self.betas
|
| 332 |
+
}
|
| 333 |
+
return data_sample
|
| 334 |
+
def __len__(child):
|
| 335 |
+
return len(child.index_list)
|
| 336 |
+
|
| 337 |
+
if split == 'train':
|
| 338 |
+
index_list = list(
|
| 339 |
+
range(0, min(self.complete_data.shape[0], self.audio_feat.shape[0]) - self.num_generate_length - self.num_pre_frames,
|
| 340 |
+
6))
|
| 341 |
+
elif split in ['val', 'test']:
|
| 342 |
+
index_list = list([0])
|
| 343 |
+
if self.whole_video:
|
| 344 |
+
index_list = list([0])
|
| 345 |
+
self.all_dataset = __Worker__(index_list, normalization, normalize_stats, split)
|
| 346 |
+
|
| 347 |
+
def __len__(self):
|
| 348 |
+
return len(self.img_name_list)
|
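A short usage sketch for the dataset class defined above: a minimal, hypothetical wrapper (the helper name make_train_loader and the DataLoader settings are illustrative, not part of this repository) showing how get_dataset builds the inner __Worker__ and how that worker plugs into a standard torch DataLoader.

from torch.utils import data

def make_train_loader(dataset, batch_size=16):
    # `dataset` is assumed to be an instance of the class above, constructed with
    # load_all=True so that complete_data and audio_feat are already populated.
    dataset.get_dataset(normalization=False, normalize_stats=None, split='train')  # fills dataset.all_dataset
    return data.DataLoader(dataset.all_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Each batch element is the dict built in __Worker__.__getitem__,
# e.g. batch['poses'], batch['aud_feat'], batch['speaker'], batch['betas'].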
data_utils/rotation_conversion.py
ADDED
|
@@ -0,0 +1,551 @@
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
# Check PYTORCH3D_LICENCE before use
|
| 3 |
+
|
| 4 |
+
import functools
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
"""
|
| 12 |
+
The transformation matrices returned from the functions in this file assume
|
| 13 |
+
the points on which the transformation will be applied are column vectors.
|
| 14 |
+
i.e. the R matrix is structured as
|
| 15 |
+
|
| 16 |
+
R = [
|
| 17 |
+
[Rxx, Rxy, Rxz],
|
| 18 |
+
[Ryx, Ryy, Ryz],
|
| 19 |
+
[Rzx, Rzy, Rzz],
|
| 20 |
+
] # (3, 3)
|
| 21 |
+
|
| 22 |
+
This matrix can be applied to column vectors by post multiplication
|
| 23 |
+
by the points e.g.
|
| 24 |
+
|
| 25 |
+
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
|
| 26 |
+
transformed_points = R * points
|
| 27 |
+
|
| 28 |
+
To apply the same matrix to points which are row vectors, the R matrix
|
| 29 |
+
can be transposed and pre multiplied by the points:
|
| 30 |
+
|
| 31 |
+
e.g.
|
| 32 |
+
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
|
| 33 |
+
transformed_points = points * R.transpose(1, 0)
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def quaternion_to_matrix(quaternions):
|
| 38 |
+
"""
|
| 39 |
+
Convert rotations given as quaternions to rotation matrices.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
quaternions: quaternions with real part first,
|
| 43 |
+
as tensor of shape (..., 4).
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 47 |
+
"""
|
| 48 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
| 49 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
| 50 |
+
|
| 51 |
+
o = torch.stack(
|
| 52 |
+
(
|
| 53 |
+
1 - two_s * (j * j + k * k),
|
| 54 |
+
two_s * (i * j - k * r),
|
| 55 |
+
two_s * (i * k + j * r),
|
| 56 |
+
two_s * (i * j + k * r),
|
| 57 |
+
1 - two_s * (i * i + k * k),
|
| 58 |
+
two_s * (j * k - i * r),
|
| 59 |
+
two_s * (i * k - j * r),
|
| 60 |
+
two_s * (j * k + i * r),
|
| 61 |
+
1 - two_s * (i * i + j * j),
|
| 62 |
+
),
|
| 63 |
+
-1,
|
| 64 |
+
)
|
| 65 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _copysign(a, b):
|
| 69 |
+
"""
|
| 70 |
+
Return a tensor where each element has the absolute value taken from the
|
| 71 |
+
corresponding element of a, with sign taken from the corresponding
|
| 72 |
+
element of b. This is like the standard copysign floating-point operation,
|
| 73 |
+
but is not careful about negative 0 and NaN.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
a: source tensor.
|
| 77 |
+
b: tensor whose signs will be used, of the same shape as a.
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
Tensor of the same shape as a with the signs of b.
|
| 81 |
+
"""
|
| 82 |
+
signs_differ = (a < 0) != (b < 0)
|
| 83 |
+
return torch.where(signs_differ, -a, a)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _sqrt_positive_part(x):
|
| 87 |
+
"""
|
| 88 |
+
Returns torch.sqrt(torch.max(0, x))
|
| 89 |
+
but with a zero subgradient where x is 0.
|
| 90 |
+
"""
|
| 91 |
+
ret = torch.zeros_like(x)
|
| 92 |
+
positive_mask = x > 0
|
| 93 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 94 |
+
return ret
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def matrix_to_quaternion(matrix):
|
| 98 |
+
"""
|
| 99 |
+
Convert rotations given as rotation matrices to quaternions.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 103 |
+
|
| 104 |
+
Returns:
|
| 105 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
| 106 |
+
"""
|
| 107 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
| 108 |
+
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
| 109 |
+
m00 = matrix[..., 0, 0]
|
| 110 |
+
m11 = matrix[..., 1, 1]
|
| 111 |
+
m22 = matrix[..., 2, 2]
|
| 112 |
+
o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
|
| 113 |
+
x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
|
| 114 |
+
y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
|
| 115 |
+
z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
|
| 116 |
+
o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
|
| 117 |
+
o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
|
| 118 |
+
o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
|
| 119 |
+
return torch.stack((o0, o1, o2, o3), -1)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _axis_angle_rotation(axis: str, angle):
|
| 123 |
+
"""
|
| 124 |
+
Return the rotation matrices for one of the rotations about an axis
|
| 125 |
+
of which Euler angles describe, for each value of the angle given.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
axis: Axis label "X" or "Y" or "Z".
|
| 129 |
+
angle: any shape tensor of Euler angles in radians
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
cos = torch.cos(angle)
|
| 136 |
+
sin = torch.sin(angle)
|
| 137 |
+
one = torch.ones_like(angle)
|
| 138 |
+
zero = torch.zeros_like(angle)
|
| 139 |
+
|
| 140 |
+
if axis == "X":
|
| 141 |
+
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
|
| 142 |
+
if axis == "Y":
|
| 143 |
+
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
|
| 144 |
+
if axis == "Z":
|
| 145 |
+
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
|
| 146 |
+
|
| 147 |
+
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def euler_angles_to_matrix(euler_angles, convention: str):
|
| 151 |
+
"""
|
| 152 |
+
Convert rotations given as Euler angles in radians to rotation matrices.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
euler_angles: Euler angles in radians as tensor of shape (..., 3).
|
| 156 |
+
convention: Convention string of three uppercase letters from
|
| 157 |
+
{"X", "Y", and "Z"}.
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 161 |
+
"""
|
| 162 |
+
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
|
| 163 |
+
raise ValueError("Invalid input euler angles.")
|
| 164 |
+
if len(convention) != 3:
|
| 165 |
+
raise ValueError("Convention must have 3 letters.")
|
| 166 |
+
if convention[1] in (convention[0], convention[2]):
|
| 167 |
+
raise ValueError(f"Invalid convention {convention}.")
|
| 168 |
+
for letter in convention:
|
| 169 |
+
if letter not in ("X", "Y", "Z"):
|
| 170 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
| 171 |
+
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
|
| 172 |
+
return functools.reduce(torch.matmul, matrices)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _angle_from_tan(
|
| 176 |
+
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
| 177 |
+
):
|
| 178 |
+
"""
|
| 179 |
+
Extract the first or third Euler angle from the two members of
|
| 180 |
+
the matrix which are positive constant times its sine and cosine.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
axis: Axis label "X" or "Y" or "Z" for the angle we are finding.
|
| 184 |
+
other_axis: Axis label "X" or "Y" or "Z" for the middle axis in the
|
| 185 |
+
convention.
|
| 186 |
+
data: Rotation matrices as tensor of shape (..., 3, 3).
|
| 187 |
+
horizontal: Whether we are looking for the angle for the third axis,
|
| 188 |
+
which means the relevant entries are in the same row of the
|
| 189 |
+
rotation matrix. If not, they are in the same column.
|
| 190 |
+
tait_bryan: Whether the first and third axes in the convention differ.
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
Euler Angles in radians for each matrix in data as a tensor
|
| 194 |
+
of shape (...).
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
| 198 |
+
if horizontal:
|
| 199 |
+
i2, i1 = i1, i2
|
| 200 |
+
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
|
| 201 |
+
if horizontal == even:
|
| 202 |
+
return torch.atan2(data[..., i1], data[..., i2])
|
| 203 |
+
if tait_bryan:
|
| 204 |
+
return torch.atan2(-data[..., i2], data[..., i1])
|
| 205 |
+
return torch.atan2(data[..., i2], -data[..., i1])
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _index_from_letter(letter: str):
|
| 209 |
+
if letter == "X":
|
| 210 |
+
return 0
|
| 211 |
+
if letter == "Y":
|
| 212 |
+
return 1
|
| 213 |
+
if letter == "Z":
|
| 214 |
+
return 2
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def matrix_to_euler_angles(matrix, convention: str):
|
| 218 |
+
"""
|
| 219 |
+
Convert rotations given as rotation matrices to Euler angles in radians.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 223 |
+
convention: Convention string of three uppercase letters.
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
Euler angles in radians as tensor of shape (..., 3).
|
| 227 |
+
"""
|
| 228 |
+
if len(convention) != 3:
|
| 229 |
+
raise ValueError("Convention must have 3 letters.")
|
| 230 |
+
if convention[1] in (convention[0], convention[2]):
|
| 231 |
+
raise ValueError(f"Invalid convention {convention}.")
|
| 232 |
+
for letter in convention:
|
| 233 |
+
if letter not in ("X", "Y", "Z"):
|
| 234 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
| 235 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
| 236 |
+
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
| 237 |
+
i0 = _index_from_letter(convention[0])
|
| 238 |
+
i2 = _index_from_letter(convention[2])
|
| 239 |
+
tait_bryan = i0 != i2
|
| 240 |
+
if tait_bryan:
|
| 241 |
+
central_angle = torch.asin(
|
| 242 |
+
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
|
| 243 |
+
)
|
| 244 |
+
else:
|
| 245 |
+
central_angle = torch.acos(matrix[..., i0, i0])
|
| 246 |
+
|
| 247 |
+
o = (
|
| 248 |
+
_angle_from_tan(
|
| 249 |
+
convention[0], convention[1], matrix[..., i2], False, tait_bryan
|
| 250 |
+
),
|
| 251 |
+
central_angle,
|
| 252 |
+
_angle_from_tan(
|
| 253 |
+
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
|
| 254 |
+
),
|
| 255 |
+
)
|
| 256 |
+
return torch.stack(o, -1)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def random_quaternions(
|
| 260 |
+
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
| 261 |
+
):
|
| 262 |
+
"""
|
| 263 |
+
Generate random quaternions representing rotations,
|
| 264 |
+
i.e. versors with nonnegative real part.
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
n: Number of quaternions in a batch to return.
|
| 268 |
+
dtype: Type to return.
|
| 269 |
+
device: Desired device of returned tensor. Default:
|
| 270 |
+
uses the current device for the default tensor type.
|
| 271 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
| 272 |
+
flag set.
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
Quaternions as tensor of shape (N, 4).
|
| 276 |
+
"""
|
| 277 |
+
o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
|
| 278 |
+
s = (o * o).sum(1)
|
| 279 |
+
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
| 280 |
+
return o
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def random_rotations(
|
| 284 |
+
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
| 285 |
+
):
|
| 286 |
+
"""
|
| 287 |
+
Generate random rotations as 3x3 rotation matrices.
|
| 288 |
+
|
| 289 |
+
Args:
|
| 290 |
+
n: Number of rotation matrices in a batch to return.
|
| 291 |
+
dtype: Type to return.
|
| 292 |
+
device: Device of returned tensor. Default: if None,
|
| 293 |
+
uses the current device for the default tensor type.
|
| 294 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
| 295 |
+
flag set.
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
Rotation matrices as tensor of shape (n, 3, 3).
|
| 299 |
+
"""
|
| 300 |
+
quaternions = random_quaternions(
|
| 301 |
+
n, dtype=dtype, device=device, requires_grad=requires_grad
|
| 302 |
+
)
|
| 303 |
+
return quaternion_to_matrix(quaternions)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def random_rotation(
|
| 307 |
+
dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
| 308 |
+
):
|
| 309 |
+
"""
|
| 310 |
+
Generate a single random 3x3 rotation matrix.
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
dtype: Type to return
|
| 314 |
+
device: Device of returned tensor. Default: if None,
|
| 315 |
+
uses the current device for the default tensor type
|
| 316 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
| 317 |
+
flag set
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
Rotation matrix as tensor of shape (3, 3).
|
| 321 |
+
"""
|
| 322 |
+
return random_rotations(1, dtype, device, requires_grad)[0]
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def standardize_quaternion(quaternions):
|
| 326 |
+
"""
|
| 327 |
+
Convert a unit quaternion to a standard form: one in which the real
|
| 328 |
+
part is non negative.
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
quaternions: Quaternions with real part first,
|
| 332 |
+
as tensor of shape (..., 4).
|
| 333 |
+
|
| 334 |
+
Returns:
|
| 335 |
+
Standardized quaternions as tensor of shape (..., 4).
|
| 336 |
+
"""
|
| 337 |
+
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def quaternion_raw_multiply(a, b):
|
| 341 |
+
"""
|
| 342 |
+
Multiply two quaternions.
|
| 343 |
+
Usual torch rules for broadcasting apply.
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
| 347 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
| 348 |
+
|
| 349 |
+
Returns:
|
| 350 |
+
The product of a and b, a tensor of quaternions shape (..., 4).
|
| 351 |
+
"""
|
| 352 |
+
aw, ax, ay, az = torch.unbind(a, -1)
|
| 353 |
+
bw, bx, by, bz = torch.unbind(b, -1)
|
| 354 |
+
ow = aw * bw - ax * bx - ay * by - az * bz
|
| 355 |
+
ox = aw * bx + ax * bw + ay * bz - az * by
|
| 356 |
+
oy = aw * by - ax * bz + ay * bw + az * bx
|
| 357 |
+
oz = aw * bz + ax * by - ay * bx + az * bw
|
| 358 |
+
return torch.stack((ow, ox, oy, oz), -1)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def quaternion_multiply(a, b):
|
| 362 |
+
"""
|
| 363 |
+
Multiply two quaternions representing rotations, returning the quaternion
|
| 364 |
+
representing their composition, i.e. the versor with nonnegative real part.
|
| 365 |
+
Usual torch rules for broadcasting apply.
|
| 366 |
+
|
| 367 |
+
Args:
|
| 368 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
| 369 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
| 370 |
+
|
| 371 |
+
Returns:
|
| 372 |
+
The product of a and b, a tensor of quaternions of shape (..., 4).
|
| 373 |
+
"""
|
| 374 |
+
ab = quaternion_raw_multiply(a, b)
|
| 375 |
+
return standardize_quaternion(ab)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def quaternion_invert(quaternion):
|
| 379 |
+
"""
|
| 380 |
+
Given a quaternion representing rotation, get the quaternion representing
|
| 381 |
+
its inverse.
|
| 382 |
+
|
| 383 |
+
Args:
|
| 384 |
+
quaternion: Quaternions as tensor of shape (..., 4), with real part
|
| 385 |
+
first, which must be versors (unit quaternions).
|
| 386 |
+
|
| 387 |
+
Returns:
|
| 388 |
+
The inverse, a tensor of quaternions of shape (..., 4).
|
| 389 |
+
"""
|
| 390 |
+
|
| 391 |
+
return quaternion * quaternion.new_tensor([1, -1, -1, -1])
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def quaternion_apply(quaternion, point):
|
| 395 |
+
"""
|
| 396 |
+
Apply the rotation given by a quaternion to a 3D point.
|
| 397 |
+
Usual torch rules for broadcasting apply.
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
quaternion: Tensor of quaternions, real part first, of shape (..., 4).
|
| 401 |
+
point: Tensor of 3D points of shape (..., 3).
|
| 402 |
+
|
| 403 |
+
Returns:
|
| 404 |
+
Tensor of rotated points of shape (..., 3).
|
| 405 |
+
"""
|
| 406 |
+
if point.size(-1) != 3:
|
| 407 |
+
raise ValueError(f"Points are not in 3D, f{point.shape}.")
|
| 408 |
+
real_parts = point.new_zeros(point.shape[:-1] + (1,))
|
| 409 |
+
point_as_quaternion = torch.cat((real_parts, point), -1)
|
| 410 |
+
out = quaternion_raw_multiply(
|
| 411 |
+
quaternion_raw_multiply(quaternion, point_as_quaternion),
|
| 412 |
+
quaternion_invert(quaternion),
|
| 413 |
+
)
|
| 414 |
+
return out[..., 1:]
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def axis_angle_to_matrix(axis_angle):
|
| 418 |
+
"""
|
| 419 |
+
Convert rotations given as axis/angle to rotation matrices.
|
| 420 |
+
|
| 421 |
+
Args:
|
| 422 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
| 423 |
+
as a tensor of shape (..., 3), where the magnitude is
|
| 424 |
+
the angle turned anticlockwise in radians around the
|
| 425 |
+
vector's direction.
|
| 426 |
+
|
| 427 |
+
Returns:
|
| 428 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 429 |
+
"""
|
| 430 |
+
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def matrix_to_axis_angle(matrix):
|
| 434 |
+
"""
|
| 435 |
+
Convert rotations given as rotation matrices to axis/angle.
|
| 436 |
+
|
| 437 |
+
Args:
|
| 438 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 439 |
+
|
| 440 |
+
Returns:
|
| 441 |
+
Rotations given as a vector in axis angle form, as a tensor
|
| 442 |
+
of shape (..., 3), where the magnitude is the angle
|
| 443 |
+
turned anticlockwise in radians around the vector's
|
| 444 |
+
direction.
|
| 445 |
+
"""
|
| 446 |
+
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def axis_angle_to_quaternion(axis_angle):
|
| 450 |
+
"""
|
| 451 |
+
Convert rotations given as axis/angle to quaternions.
|
| 452 |
+
|
| 453 |
+
Args:
|
| 454 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
| 455 |
+
as a tensor of shape (..., 3), where the magnitude is
|
| 456 |
+
the angle turned anticlockwise in radians around the
|
| 457 |
+
vector's direction.
|
| 458 |
+
|
| 459 |
+
Returns:
|
| 460 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
| 461 |
+
"""
|
| 462 |
+
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
| 463 |
+
half_angles = 0.5 * angles
|
| 464 |
+
eps = 1e-6
|
| 465 |
+
small_angles = angles.abs() < eps
|
| 466 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
| 467 |
+
sin_half_angles_over_angles[~small_angles] = (
|
| 468 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
| 469 |
+
)
|
| 470 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
| 471 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
| 472 |
+
sin_half_angles_over_angles[small_angles] = (
|
| 473 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
| 474 |
+
)
|
| 475 |
+
quaternions = torch.cat(
|
| 476 |
+
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
|
| 477 |
+
)
|
| 478 |
+
return quaternions
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def quaternion_to_axis_angle(quaternions):
|
| 482 |
+
"""
|
| 483 |
+
Convert rotations given as quaternions to axis/angle.
|
| 484 |
+
|
| 485 |
+
Args:
|
| 486 |
+
quaternions: quaternions with real part first,
|
| 487 |
+
as tensor of shape (..., 4).
|
| 488 |
+
|
| 489 |
+
Returns:
|
| 490 |
+
Rotations given as a vector in axis angle form, as a tensor
|
| 491 |
+
of shape (..., 3), where the magnitude is the angle
|
| 492 |
+
turned anticlockwise in radians around the vector's
|
| 493 |
+
direction.
|
| 494 |
+
"""
|
| 495 |
+
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
|
| 496 |
+
half_angles = torch.atan2(norms, quaternions[..., :1])
|
| 497 |
+
angles = 2 * half_angles
|
| 498 |
+
eps = 1e-6
|
| 499 |
+
small_angles = angles.abs() < eps
|
| 500 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
| 501 |
+
sin_half_angles_over_angles[~small_angles] = (
|
| 502 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
| 503 |
+
)
|
| 504 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
| 505 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
| 506 |
+
sin_half_angles_over_angles[small_angles] = (
|
| 507 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
| 508 |
+
)
|
| 509 |
+
return quaternions[..., 1:] / sin_half_angles_over_angles
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
|
| 513 |
+
"""
|
| 514 |
+
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
|
| 515 |
+
using Gram--Schmidt orthogonalisation per Section B of [1].
|
| 516 |
+
Args:
|
| 517 |
+
d6: 6D rotation representation, of size (*, 6)
|
| 518 |
+
|
| 519 |
+
Returns:
|
| 520 |
+
batch of rotation matrices of size (*, 3, 3)
|
| 521 |
+
|
| 522 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
| 523 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
| 524 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
| 525 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
| 526 |
+
"""
|
| 527 |
+
|
| 528 |
+
a1, a2 = d6[..., :3], d6[..., 3:]
|
| 529 |
+
b1 = F.normalize(a1, dim=-1)
|
| 530 |
+
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
|
| 531 |
+
b2 = F.normalize(b2, dim=-1)
|
| 532 |
+
b3 = torch.cross(b1, b2, dim=-1)
|
| 533 |
+
return torch.stack((b1, b2, b3), dim=-2)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
|
| 537 |
+
"""
|
| 538 |
+
Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
|
| 539 |
+
by dropping the last row. Note that 6D representation is not unique.
|
| 540 |
+
Args:
|
| 541 |
+
matrix: batch of rotation matrices of size (*, 3, 3)
|
| 542 |
+
|
| 543 |
+
Returns:
|
| 544 |
+
6D rotation representation, of size (*, 6)
|
| 545 |
+
|
| 546 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
| 547 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
| 548 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
| 549 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
| 550 |
+
"""
|
| 551 |
+
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
|
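For reference, a small self-contained round trip through the conversions above, matching how the dataloader turns the 55 SMPL-X axis-angle joints into the continuous 6D representation (the tensor sizes and tolerances are illustrative):

import torch
from data_utils.rotation_conversion import (
    axis_angle_to_matrix, matrix_to_axis_angle,
    matrix_to_rotation_6d, rotation_6d_to_matrix,
)

aa = 0.1 * torch.randn(55, 3)            # one small axis-angle rotation per joint
R = axis_angle_to_matrix(aa)             # (55, 3, 3)
d6 = matrix_to_rotation_6d(R)            # (55, 6) continuous 6D representation
R_back = rotation_6d_to_matrix(d6)       # (55, 3, 3), rebuilt by Gram-Schmidt
aa_back = matrix_to_axis_angle(R_back)   # (55, 3)

# For valid rotations the 6D round trip is lossless up to floating-point error.
assert torch.allclose(R, R_back, atol=1e-5)
assert torch.allclose(aa, aa_back, atol=1e-4)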
data_utils/split_more_than_2s.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2df6e745cdf7473f13ce3ae2ed759c3cceb60c9197e7f3fd65110e7bc20b6f2d
|
| 3 |
+
size 2398875
|
data_utils/split_train_val_test.py
ADDED
|
@@ -0,0 +1,27 @@
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import shutil
|
| 4 |
+
|
| 5 |
+
if __name__ == '__main__':
|
| 6 |
+
id_list = "chemistry conan oliver seth"
|
| 7 |
+
id_list = id_list.split(' ')
|
| 8 |
+
|
| 9 |
+
old_root = '/home/usename/talkshow_data/ExpressiveWholeBodyDatasetReleaseV1.0'
|
| 10 |
+
new_root = '/home/usename/talkshow_data/ExpressiveWholeBodyDatasetReleaseV1.0/talkshow_data_splited'
|
| 11 |
+
|
| 12 |
+
with open('train_val_test.json') as f:
|
| 13 |
+
split_info = json.load(f)
|
| 14 |
+
phase_list = ['train', 'val', 'test']
|
| 15 |
+
for phase in phase_list:
|
| 16 |
+
phase_path_list = split_info[phase]
|
| 17 |
+
for p in phase_path_list:
|
| 18 |
+
old_path = os.path.join(old_root, p)
|
| 19 |
+
if not os.path.exists(old_path):
|
| 20 |
+
print(f'{old_path} not found, continue')
|
| 21 |
+
continue
|
| 22 |
+
new_path = os.path.join(new_root, phase, p)
|
| 23 |
+
dir_name = os.path.dirname(new_path)
|
| 24 |
+
if not os.path.isdir(dir_name):
|
| 25 |
+
os.makedirs(dir_name, exist_ok=True)
|
| 26 |
+
shutil.move(old_path, new_path)
|
| 27 |
+
|
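The script above only assumes that train_val_test.json maps each phase name to a list of relative clip paths under the dataset root; a minimal placeholder sketch of that shape (the paths below are invented for illustration, not taken from the released split file):

# Placeholder layout of the split file consumed by the loop above.
split_info = {
    "train": ["speaker_a/video_0/clip_0"],
    "val": ["speaker_b/video_1/clip_1"],
    "test": ["speaker_c/video_2/clip_2"],
}
for phase in ("train", "val", "test"):
    print(phase, len(split_info[phase]))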
data_utils/train_val_test.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data_utils/utils.py
ADDED
|
@@ -0,0 +1,318 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
# import librosa #has to do this cause librosa is not supported on my server
|
| 3 |
+
import python_speech_features
|
| 4 |
+
from scipy.io import wavfile
|
| 5 |
+
from scipy import signal
|
| 6 |
+
import librosa
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio as ta
|
| 9 |
+
import torchaudio.functional as ta_F
|
| 10 |
+
import torchaudio.transforms as ta_T
|
| 11 |
+
# import pyloudnorm as pyln
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_wav_old(audio_fn, sr = 16000):
|
| 15 |
+
sample_rate, sig = wavfile.read(audio_fn)
|
| 16 |
+
if sample_rate != sr:
|
| 17 |
+
result = int((sig.shape[0]) / sample_rate * sr)
|
| 18 |
+
x_resampled = signal.resample(sig, result)
|
| 19 |
+
x_resampled = x_resampled.astype(np.float64)
|
| 20 |
+
return x_resampled, sr
|
| 21 |
+
|
| 22 |
+
sig = sig / (2**15)
|
| 23 |
+
return sig, sample_rate
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_mfcc(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
|
| 27 |
+
|
| 28 |
+
y, sr = librosa.load(audio_fn, sr=sr, mono=True)
|
| 29 |
+
|
| 30 |
+
if win_size is None:
|
| 31 |
+
hop_len=int(sr / fps)
|
| 32 |
+
else:
|
| 33 |
+
hop_len=int(sr / win_size)
|
| 34 |
+
|
| 35 |
+
n_fft=2048
|
| 36 |
+
|
| 37 |
+
C = librosa.feature.mfcc(
|
| 38 |
+
y = y,
|
| 39 |
+
sr = sr,
|
| 40 |
+
n_mfcc = n_mfcc,
|
| 41 |
+
hop_length = hop_len,
|
| 42 |
+
n_fft = n_fft
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
if C.shape[0] == n_mfcc:
|
| 46 |
+
C = C.transpose(1, 0)
|
| 47 |
+
|
| 48 |
+
return C
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_melspec(audio_fn, eps=1e-6, fps = 25, sr=16000, n_mels=64):
|
| 52 |
+
raise NotImplementedError
|
| 53 |
+
'''
|
| 54 |
+
# y, sr = load_wav(audio_fn=audio_fn, sr=sr)
|
| 55 |
+
|
| 56 |
+
# hop_len = int(sr / fps)
|
| 57 |
+
# n_fft = 2048
|
| 58 |
+
|
| 59 |
+
# C = librosa.feature.melspectrogram(
|
| 60 |
+
# y = y,
|
| 61 |
+
# sr = sr,
|
| 62 |
+
# n_fft=n_fft,
|
| 63 |
+
# hop_length=hop_len,
|
| 64 |
+
# n_mels = n_mels,
|
| 65 |
+
# fmin=0,
|
| 66 |
+
# fmax=8000)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# mask = (C == 0).astype(np.float)
|
| 70 |
+
# C = mask * eps + (1-mask) * C
|
| 71 |
+
|
| 72 |
+
# C = np.log(C)
|
| 73 |
+
# #wierd error may occur here
|
| 74 |
+
# assert not (np.isnan(C).any()), audio_fn
|
| 75 |
+
# if C.shape[0] == n_mels:
|
| 76 |
+
# C = C.transpose(1, 0)
|
| 77 |
+
|
| 78 |
+
# return C
|
| 79 |
+
'''
|
| 80 |
+
|
| 81 |
+
def extract_mfcc(audio,sample_rate=16000):
|
| 82 |
+
mfcc = zip(*python_speech_features.mfcc(audio,sample_rate, numcep=64, nfilt=64, nfft=2048, winstep=0.04))
|
| 83 |
+
mfcc = np.stack([np.array(i) for i in mfcc])
|
| 84 |
+
return mfcc
|
| 85 |
+
|
| 86 |
+
def get_mfcc_psf(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
|
| 87 |
+
y, sr = load_wav_old(audio_fn, sr=sr)
|
| 88 |
+
|
| 89 |
+
if y.shape.__len__() > 1:
|
| 90 |
+
y = (y[:,0]+y[:,1])/2
|
| 91 |
+
|
| 92 |
+
if win_size is None:
|
| 93 |
+
hop_len=int(sr / fps)
|
| 94 |
+
else:
|
| 95 |
+
hop_len=int(sr/ win_size)
|
| 96 |
+
|
| 97 |
+
n_fft=2048
|
| 98 |
+
|
| 99 |
+
#hard coded for 25 fps
|
| 100 |
+
if not smlpx:
|
| 101 |
+
C = python_speech_features.mfcc(y, sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=0.04)
|
| 102 |
+
else:
|
| 103 |
+
C = python_speech_features.mfcc(y, sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01/15)
|
| 104 |
+
# if C.shape[0] == n_mfcc:
|
| 105 |
+
# C = C.transpose(1, 0)
|
| 106 |
+
|
| 107 |
+
return C
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_mfcc_psf_min(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
|
| 111 |
+
y, sr = load_wav_old(audio_fn, sr=sr)
|
| 112 |
+
|
| 113 |
+
if y.shape.__len__() > 1:
|
| 114 |
+
y = (y[:, 0] + y[:, 1]) / 2
|
| 115 |
+
n_fft = 2048
|
| 116 |
+
|
| 117 |
+
slice_len = 22000 * 5
|
| 118 |
+
slice = y.size // slice_len
|
| 119 |
+
|
| 120 |
+
C = []
|
| 121 |
+
|
| 122 |
+
for i in range(slice):
|
| 123 |
+
if i != (slice - 1):
|
| 124 |
+
feat = python_speech_features.mfcc(y[i*slice_len:(i+1)*slice_len], sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01 / 15)
|
| 125 |
+
else:
|
| 126 |
+
feat = python_speech_features.mfcc(y[i * slice_len:], sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01 / 15)
|
| 127 |
+
|
| 128 |
+
C.append(feat)
|
| 129 |
+
|
| 130 |
+
return C
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
|
| 134 |
+
"""
|
| 135 |
+
:param audio: 1 x T tensor containing a 16kHz audio signal
|
| 136 |
+
:param frame_rate: frame rate for video (we need one audio chunk per video frame)
|
| 137 |
+
:param chunk_size: number of audio samples per chunk
|
| 138 |
+
:return: num_chunks x chunk_size tensor containing sliced audio
|
| 139 |
+
"""
|
| 140 |
+
samples_per_frame = chunk_size // frame_rate
|
| 141 |
+
padding = (chunk_size - samples_per_frame) // 2
|
| 142 |
+
audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
|
| 143 |
+
anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
|
| 144 |
+
audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
|
| 145 |
+
return audio
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def get_mfcc_ta(audio_fn, eps=1e-6, fps=15, smlpx=False, sr=16000, n_mfcc=64, win_size=None, type='mfcc', am=None, am_sr=None, encoder_choice='mfcc'):
|
| 149 |
+
if am is None:
|
| 150 |
+
audio, sr_0 = ta.load(audio_fn)
|
| 151 |
+
if sr != sr_0:
|
| 152 |
+
audio = ta.transforms.Resample(sr_0, sr)(audio)
|
| 153 |
+
if audio.shape[0] > 1:
|
| 154 |
+
audio = torch.mean(audio, dim=0, keepdim=True)
|
| 155 |
+
|
| 156 |
+
n_fft = 2048
|
| 157 |
+
if fps == 15:
|
| 158 |
+
hop_length = 1467
|
| 159 |
+
elif fps == 30:
|
| 160 |
+
hop_length = 734
|
| 161 |
+
win_length = hop_length * 2
|
| 162 |
+
n_mels = 256
|
| 163 |
+
n_mfcc = 64
|
| 164 |
+
|
| 165 |
+
if type == 'mfcc':
|
| 166 |
+
mfcc_transform = ta_T.MFCC(
|
| 167 |
+
sample_rate=sr,
|
| 168 |
+
n_mfcc=n_mfcc,
|
| 169 |
+
melkwargs={
|
| 170 |
+
"n_fft": n_fft,
|
| 171 |
+
"n_mels": n_mels,
|
| 172 |
+
# "win_length": win_length,
|
| 173 |
+
"hop_length": hop_length,
|
| 174 |
+
"mel_scale": "htk",
|
| 175 |
+
},
|
| 176 |
+
)
|
| 177 |
+
audio_ft = mfcc_transform(audio).squeeze(dim=0).transpose(0,1).numpy()
|
| 178 |
+
elif type == 'mel':
|
| 179 |
+
# audio = 0.01 * audio / torch.mean(torch.abs(audio))
|
| 180 |
+
mel_transform = ta_T.MelSpectrogram(
|
| 181 |
+
sample_rate=sr, n_fft=n_fft, win_length=None, hop_length=hop_length, n_mels=n_mels
|
| 182 |
+
)
|
| 183 |
+
audio_ft = mel_transform(audio).squeeze(0).transpose(0,1).numpy()
|
| 184 |
+
# audio_ft = torch.log(audio_ft.clamp(min=1e-10, max=None)).transpose(0,1).numpy()
|
| 185 |
+
elif type == 'mel_mul':
|
| 186 |
+
audio = 0.01 * audio / torch.mean(torch.abs(audio))
|
| 187 |
+
audio = audio_chunking(audio, frame_rate=fps, chunk_size=sr)
|
| 188 |
+
mel_transform = ta_T.MelSpectrogram(
|
| 189 |
+
sample_rate=sr, n_fft=n_fft, win_length=int(sr/20), hop_length=int(sr/100), n_mels=n_mels
|
| 190 |
+
)
|
| 191 |
+
audio_ft = mel_transform(audio).squeeze(1)
|
| 192 |
+
audio_ft = torch.log(audio_ft.clamp(min=1e-10, max=None)).numpy()
|
| 193 |
+
else:
|
| 194 |
+
speech_array, sampling_rate = librosa.load(audio_fn, sr=16000)
|
| 195 |
+
|
| 196 |
+
if encoder_choice == 'faceformer':
|
| 197 |
+
# audio_ft = np.squeeze(am(speech_array, sampling_rate=16000).input_values).reshape(-1, 1)
|
| 198 |
+
audio_ft = speech_array.reshape(-1, 1)
|
| 199 |
+
elif encoder_choice == 'meshtalk':
|
| 200 |
+
audio_ft = 0.01 * speech_array / np.mean(np.abs(speech_array))
|
| 201 |
+
elif encoder_choice == 'onset':
|
| 202 |
+
audio_ft = librosa.onset.onset_detect(y=speech_array, sr=16000, units='time').reshape(-1, 1)
|
| 203 |
+
else:
|
| 204 |
+
audio, sr_0 = ta.load(audio_fn)
|
| 205 |
+
if sr != sr_0:
|
| 206 |
+
audio = ta.transforms.Resample(sr_0, sr)(audio)
|
| 207 |
+
if audio.shape[0] > 1:
|
| 208 |
+
audio = torch.mean(audio, dim=0, keepdim=True)
|
| 209 |
+
|
| 210 |
+
n_fft = 2048
|
| 211 |
+
if fps == 15:
|
| 212 |
+
hop_length = 1467
|
| 213 |
+
elif fps == 30:
|
| 214 |
+
hop_length = 734
|
| 215 |
+
win_length = hop_length * 2
|
| 216 |
+
n_mels = 256
|
| 217 |
+
n_mfcc = 64
|
| 218 |
+
|
| 219 |
+
mfcc_transform = ta_T.MFCC(
|
| 220 |
+
sample_rate=sr,
|
| 221 |
+
n_mfcc=n_mfcc,
|
| 222 |
+
melkwargs={
|
| 223 |
+
"n_fft": n_fft,
|
| 224 |
+
"n_mels": n_mels,
|
| 225 |
+
# "win_length": win_length,
|
| 226 |
+
"hop_length": hop_length,
|
| 227 |
+
"mel_scale": "htk",
|
| 228 |
+
},
|
| 229 |
+
)
|
| 230 |
+
audio_ft = mfcc_transform(audio).squeeze(dim=0).transpose(0, 1).numpy()
|
| 231 |
+
return audio_ft
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def get_mfcc_sepa(audio_fn, fps=15, sr=16000):
|
| 235 |
+
audio, sr_0 = ta.load(audio_fn)
|
| 236 |
+
if sr != sr_0:
|
| 237 |
+
audio = ta.transforms.Resample(sr_0, sr)(audio)
|
| 238 |
+
if audio.shape[0] > 1:
|
| 239 |
+
audio = torch.mean(audio, dim=0, keepdim=True)
|
| 240 |
+
|
| 241 |
+
n_fft = 2048
|
| 242 |
+
if fps == 15:
|
| 243 |
+
hop_length = 1467
|
| 244 |
+
elif fps == 30:
|
| 245 |
+
hop_length = 734
|
| 246 |
+
n_mels = 256
|
| 247 |
+
n_mfcc = 64
|
| 248 |
+
|
| 249 |
+
mfcc_transform = ta_T.MFCC(
|
| 250 |
+
sample_rate=sr,
|
| 251 |
+
n_mfcc=n_mfcc,
|
| 252 |
+
melkwargs={
|
| 253 |
+
"n_fft": n_fft,
|
| 254 |
+
"n_mels": n_mels,
|
| 255 |
+
# "win_length": win_length,
|
| 256 |
+
"hop_length": hop_length,
|
| 257 |
+
"mel_scale": "htk",
|
| 258 |
+
},
|
| 259 |
+
)
|
| 260 |
+
audio_ft_0 = mfcc_transform(audio[0, :sr*2]).squeeze(dim=0).transpose(0,1).numpy()
|
| 261 |
+
audio_ft_1 = mfcc_transform(audio[0, sr*2:]).squeeze(dim=0).transpose(0,1).numpy()
|
| 262 |
+
audio_ft = np.concatenate((audio_ft_0, audio_ft_1), axis=0)
|
| 263 |
+
return audio_ft, audio_ft_0.shape[0]
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def get_mfcc_old(wav_file):
|
| 267 |
+
sig, sample_rate = load_wav_old(wav_file)
|
| 268 |
+
mfcc = extract_mfcc(sig)
|
| 269 |
+
return mfcc
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def smooth_geom(geom, mask: torch.Tensor = None, filter_size: int = 9, sigma: float = 2.0):
|
| 273 |
+
"""
|
| 274 |
+
:param geom: T x V x 3 tensor containing a temporal sequence of length T with V vertices in each frame
|
| 275 |
+
:param mask: V-dimensional Tensor containing a mask with vertices to be smoothed
|
| 276 |
+
:param filter_size: size of the Gaussian filter
|
| 277 |
+
:param sigma: standard deviation of the Gaussian filter
|
| 278 |
+
:return: T x V x 3 tensor containing smoothed geometry (i.e., smoothed in the area indicated by the mask)
|
| 279 |
+
"""
|
| 280 |
+
assert filter_size % 2 == 1, f"filter size must be odd but is {filter_size}"
|
| 281 |
+
# Gaussian smoothing (low-pass filtering)
|
| 282 |
+
fltr = np.arange(-(filter_size // 2), filter_size // 2 + 1)
|
| 283 |
+
fltr = np.exp(-0.5 * fltr ** 2 / sigma ** 2)
|
| 284 |
+
fltr = torch.Tensor(fltr) / np.sum(fltr)
|
| 285 |
+
# apply fltr
|
| 286 |
+
fltr = fltr.view(1, 1, -1).to(device=geom.device)
|
| 287 |
+
T, V = geom.shape[1], geom.shape[2]
|
| 288 |
+
g = torch.nn.functional.pad(
|
| 289 |
+
geom.permute(2, 0, 1).view(V, 1, T),
|
| 290 |
+
pad=[filter_size // 2, filter_size // 2], mode='replicate'
|
| 291 |
+
)
|
| 292 |
+
g = torch.nn.functional.conv1d(g, fltr).view(V, 1, T)
|
| 293 |
+
smoothed = g.permute(1, 2, 0).contiguous()
|
| 294 |
+
# blend smoothed signal with original signal
|
| 295 |
+
if mask is None:
|
| 296 |
+
return smoothed
|
| 297 |
+
else:
|
| 298 |
+
return smoothed * mask[None, :, None] + geom * (-mask[None, :, None] + 1)
|
| 299 |
+
|
| 300 |
+
if __name__ == '__main__':
|
| 301 |
+
audio_fn = '../sample_audio/clip000028_tCAkv4ggPgI.wav'
|
| 302 |
+
|
| 303 |
+
C = get_mfcc_psf(audio_fn)
|
| 304 |
+
print(C.shape)
|
| 305 |
+
|
| 306 |
+
C_2 = get_mfcc(audio_fn)
|
| 307 |
+
print(C.shape)
|
| 308 |
+
|
| 309 |
+
print(C)
|
| 310 |
+
print(C_2)
|
| 311 |
+
print((C == C_2).all())
|
| 312 |
+
# print(y.shape, sr)
|
| 313 |
+
# mel_spec = get_melspec(audio_fn)
|
| 314 |
+
# print(mel_spec.shape)
|
| 315 |
+
# mfcc = get_mfcc(audio_fn, sr = 16000)
|
| 316 |
+
# print(mfcc.shape)
|
| 317 |
+
# print(mel_spec.max(), mel_spec.min())
|
| 318 |
+
# print(mfcc.max(), mfcc.min())
|
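A quick usage sketch for the torchaudio-based extractor above, using the same sample clip referenced in this file's __main__ block; with am=None and type='mfcc' the audio is resampled to 16 kHz and one 64-dimensional MFCC row is produced per hop (hop_length=734 samples when fps=30):

from data_utils.utils import get_mfcc_ta

feat = get_mfcc_ta('../sample_audio/clip000028_tCAkv4ggPgI.wav',
                   fps=30, sr=16000, n_mfcc=64, type='mfcc')
print(feat.shape)  # (num_hops, 64) numpy array of MFCC features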
evaluation/FGD.py
ADDED
|
@@ -0,0 +1,199 @@
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from scipy import linalg
|
| 7 |
+
import math
|
| 8 |
+
from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
|
| 9 |
+
|
| 10 |
+
import warnings
|
| 11 |
+
warnings.filterwarnings("ignore", category=RuntimeWarning) # ignore warnings
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
change_angle = torch.tensor([6.0181e-05, 5.1597e-05, 2.1344e-04, 2.1899e-04])
|
| 15 |
+
class EmbeddingSpaceEvaluator:
|
| 16 |
+
def __init__(self, ae, vae, device):
|
| 17 |
+
|
| 18 |
+
# init embed net
|
| 19 |
+
self.ae = ae
|
| 20 |
+
# self.vae = vae
|
| 21 |
+
|
| 22 |
+
# storage
|
| 23 |
+
self.real_feat_list = []
|
| 24 |
+
self.generated_feat_list = []
|
| 25 |
+
self.real_joints_list = []
|
| 26 |
+
self.generated_joints_list = []
|
| 27 |
+
self.real_6d_list = []
|
| 28 |
+
self.generated_6d_list = []
|
| 29 |
+
self.audio_beat_list = []
|
| 30 |
+
|
| 31 |
+
def reset(self):
|
| 32 |
+
self.real_feat_list = []
|
| 33 |
+
self.generated_feat_list = []
|
| 34 |
+
|
| 35 |
+
def get_no_of_samples(self):
|
| 36 |
+
return len(self.real_feat_list)
|
| 37 |
+
|
| 38 |
+
def push_samples(self, generated_poses, real_poses):
|
| 39 |
+
# self.net.eval()
|
| 40 |
+
# convert poses to latent features
|
| 41 |
+
real_feat, real_poses = self.ae.extract(real_poses)
|
| 42 |
+
        generated_feat, generated_poses = self.ae.extract(generated_poses)

        num_joints = real_poses.shape[2] // 3

        real_feat = real_feat.squeeze()
        generated_feat = generated_feat.reshape(generated_feat.shape[0]*generated_feat.shape[1], -1)

        self.real_feat_list.append(real_feat.data.cpu().numpy())
        self.generated_feat_list.append(generated_feat.data.cpu().numpy())

        # real_poses = matrix_to_rotation_6d(axis_angle_to_matrix(real_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
        # generated_poses = matrix_to_rotation_6d(axis_angle_to_matrix(generated_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
        #
        # self.real_feat_list.append(real_poses.data.cpu().numpy())
        # self.generated_feat_list.append(generated_poses.data.cpu().numpy())

    def push_joints(self, generated_poses, real_poses):
        self.real_joints_list.append(real_poses.data.cpu())
        self.generated_joints_list.append(generated_poses.squeeze().data.cpu())

    def push_aud(self, aud):
        self.audio_beat_list.append(aud.squeeze().data.cpu())

    def get_MAAC(self):
        ang_vel_list = []
        for real_joints in self.real_joints_list:
            real_joints[:, 15:21] = real_joints[:, 16:22]
            vec = real_joints[:, 15:21] - real_joints[:, 13:19]
            inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
            inner_product = torch.clamp(inner_product, -1, 1, out=None)
            angle = torch.acos(inner_product) / math.pi
            ang_vel = (angle[1:] - angle[:-1]).abs().mean(dim=0)
            ang_vel_list.append(ang_vel.unsqueeze(dim=0))
        all_vel = torch.cat(ang_vel_list, dim=0)
        MAAC = all_vel.mean(dim=0)
        return MAAC

    def get_BCscore(self):
        thres = 0.01
        sigma = 0.1
        sum_1 = 0
        total_beat = 0
        for joints, audio_beat_time in zip(self.generated_joints_list, self.audio_beat_list):
            motion_beat_time = []
            if joints.dim() == 4:
                joints = joints[0]
            joints[:, 15:21] = joints[:, 16:22]
            vec = joints[:, 15:21] - joints[:, 13:19]
            inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
            inner_product = torch.clamp(inner_product, -1, 1, out=None)
            angle = torch.acos(inner_product) / math.pi
            ang_vel = (angle[1:] - angle[:-1]).abs() / change_angle / len(change_angle)

            angle_diff = torch.cat((torch.zeros(1, 4), ang_vel), dim=0)

            sum_2 = 0
            for i in range(angle_diff.shape[1]):
                motion_beat_time = []
                for t in range(1, joints.shape[0]-1):
                    if (angle_diff[t][i] < angle_diff[t - 1][i] and angle_diff[t][i] < angle_diff[t + 1][i]):
                        if (angle_diff[t - 1][i] - angle_diff[t][i] >= thres or angle_diff[t + 1][i] - angle_diff[t][i] >= thres):
                            motion_beat_time.append(float(t) / 30.0)
                if (len(motion_beat_time) == 0):
                    continue
                motion_beat_time = torch.tensor(motion_beat_time)
                sum = 0
                for audio in audio_beat_time:
                    sum += np.power(math.e, -(np.power((audio.item() - motion_beat_time), 2)).min() / (2 * sigma * sigma))
                sum_2 = sum_2 + sum
                total_beat = total_beat + len(audio_beat_time)
            sum_1 = sum_1 + sum_2
        return sum_1/total_beat


    def get_scores(self):
        generated_feats = np.vstack(self.generated_feat_list)
        real_feats = np.vstack(self.real_feat_list)

        def frechet_distance(samples_A, samples_B):
            A_mu = np.mean(samples_A, axis=0)
            A_sigma = np.cov(samples_A, rowvar=False)
            B_mu = np.mean(samples_B, axis=0)
            B_sigma = np.cov(samples_B, rowvar=False)
            try:
                frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma)
            except ValueError:
                frechet_dist = 1e+10
            return frechet_dist

        ####################################################################
        # frechet distance
        frechet_dist = frechet_distance(generated_feats, real_feats)

        ####################################################################
        # distance between real and generated samples on the latent feature space
        dists = []
        for i in range(real_feats.shape[0]):
            d = np.sum(np.absolute(real_feats[i] - generated_feats[i]))  # MAE
            dists.append(d)
        feat_dist = np.mean(dists)

        return frechet_dist, feat_dist

    @staticmethod
    def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
        """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """
        """Numpy implementation of the Frechet Distance.
        The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
        and X_2 ~ N(mu_2, C_2) is
                d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
        Stable version by Dougal J. Sutherland.
        Params:
        -- mu1   : Numpy array containing the activations of a layer of the
                   inception net (like returned by the function 'get_predictions')
                   for generated samples.
        -- mu2   : The sample mean over activations, precalculated on a
                   representative data set.
        -- sigma1: The covariance matrix over activations for generated samples.
        -- sigma2: The covariance matrix over activations, precalculated on a
                   representative data set.
        Returns:
        --   : The Frechet Distance.
        """

        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)

        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        assert mu1.shape == mu2.shape, \
            'Training and test mean vectors have different lengths'
        assert sigma1.shape == sigma2.shape, \
            'Training and test covariances have different dimensions'

        diff = mu1 - mu2

        # Product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                   'adding %s to diagonal of cov estimates') % eps
            print(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)

        return (diff.dot(diff) + np.trace(sigma1) +
                np.trace(sigma2) - 2 * tr_covmean)
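As a quick sanity check on the formula quoted in the docstring above, d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)), here is a minimal standalone sketch evaluated on two toy Gaussian samples. It is not part of this commit, and the helper name toy_frechet is made up for illustration.

import numpy as np
from scipy import linalg

def toy_frechet(samples_a, samples_b):
    # same formula as calculate_frechet_distance above, without the numerical safeguards
    mu1, mu2 = samples_a.mean(axis=0), samples_b.mean(axis=0)
    s1 = np.cov(samples_a, rowvar=False)
    s2 = np.cov(samples_b, rowvar=False)
    covmean, _ = linalg.sqrtm(s1.dot(s2), disp=False)
    diff = mu1 - mu2
    return diff.dot(diff) + np.trace(s1) + np.trace(s2) - 2 * np.trace(covmean.real)

rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=(500, 8))
b = rng.normal(0.5, 1.0, size=(500, 8))
print(toy_frechet(a, a[::-1]))  # ~0: identical sample statistics
print(toy_frechet(a, b))        # grows with the mean shift (roughly 8 * 0.5**2 = 2 here)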
evaluation/__init__.py
ADDED
File without changes
evaluation/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (181 Bytes).
evaluation/__pycache__/metrics.cpython-37.pyc
ADDED
Binary file (3.81 kB).
evaluation/diversity_LVD.py
ADDED
@@ -0,0 +1,64 @@
'''
LVD: different initial pose
diversity: same initial pose
'''
import os
import sys
sys.path.append(os.getcwd())

from glob import glob

from argparse import ArgumentParser
import json

from evaluation.util import *
from evaluation.metrics import *
from tqdm import tqdm

parser = ArgumentParser()
parser.add_argument('--speaker', required=True, type=str)
parser.add_argument('--post_fix', nargs='+', default=['base'], type=str)
args = parser.parse_args()

speaker = args.speaker
test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))

LVD_list = []
diversity_list = []

for aud in tqdm(test_audios):
    base_name = os.path.splitext(aud)[0]
    gt_path = get_full_path(aud, speaker, 'val')
    _, gt_poses, _ = get_gts(gt_path)
    gt_poses = gt_poses[np.newaxis, ...]
    # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
    for post_fix in args.post_fix:
        pred_path = base_name + '_' + post_fix + '.json'
        pred_poses = np.array(json.load(open(pred_path)))
        # print(pred_poses.shape)#(B, seq_len, 108)
        pred_poses = cvt25(pred_poses, gt_poses)
        # print(pred_poses.shape)#(B, seq, pose_dim)

        gt_valid_points = hand_points(gt_poses)
        pred_valid_points = hand_points(pred_poses)

        lvd = LVD(gt_valid_points, pred_valid_points)
        # div = diversity(pred_valid_points)

        LVD_list.append(lvd)
        # diversity_list.append(div)

        # gt_velocity = peak_velocity(gt_valid_points, order=2)
        # pred_velocity = peak_velocity(pred_valid_points, order=2)

        # gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
        # pred_consistency = velocity_consistency(pred_velocity, gt_velocity)

        # gt_consistency_list.append(gt_consistency)
        # pred_consistency_list.append(pred_consistency)

lvd = np.mean(LVD_list)
# diversity_list = np.mean(diversity_list)

print('LVD:', lvd)
# print("diversity:", diversity_list)
evaluation/get_quality_samples.py
ADDED
@@ -0,0 +1,62 @@
'''
'''
import os
import sys
sys.path.append(os.getcwd())

from glob import glob

from argparse import ArgumentParser
import json

from evaluation.util import *
from evaluation.metrics import *
from tqdm import tqdm

parser = ArgumentParser()
parser.add_argument('--speaker', required=True, type=str)
parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
args = parser.parse_args()

speaker = args.speaker
test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))

quality_samples = {'gt': []}
for post_fix in args.post_fix:
    quality_samples[post_fix] = []

for aud in tqdm(test_audios):
    base_name = os.path.splitext(aud)[0]
    gt_path = get_full_path(aud, speaker, 'val')
    _, gt_poses, _ = get_gts(gt_path)
    gt_poses = gt_poses[np.newaxis, ...]
    gt_valid_points = valid_points(gt_poses)
    # print(gt_valid_points.shape)
    quality_samples['gt'].append(gt_valid_points)

    for post_fix in args.post_fix:
        pred_path = base_name + '_' + post_fix + '.json'
        pred_poses = np.array(json.load(open(pred_path)))
        # print(pred_poses.shape)#(B, seq_len, 108)
        pred_poses = cvt25(pred_poses, gt_poses)
        # print(pred_poses.shape)#(B, seq, pose_dim)

        pred_valid_points = valid_points(pred_poses)[0:1]
        quality_samples[post_fix].append(pred_valid_points)

quality_samples['gt'] = np.concatenate(quality_samples['gt'], axis=1)
for post_fix in args.post_fix:
    quality_samples[post_fix] = np.concatenate(quality_samples[post_fix], axis=1)

print('gt:', quality_samples['gt'].shape)
quality_samples['gt'] = quality_samples['gt'].tolist()
for post_fix in args.post_fix:
    print(post_fix, ':', quality_samples[post_fix].shape)
    quality_samples[post_fix] = quality_samples[post_fix].tolist()

save_dir = '../../experiments/'
os.makedirs(save_dir, exist_ok=True)
save_name = os.path.join(save_dir, 'quality_samples_%s.json'%(speaker))
with open(save_name, 'w') as f:
    json.dump(quality_samples, f)
evaluation/metrics.py
ADDED
@@ -0,0 +1,109 @@
'''
Warning: metrics are for reference only, may have limited significance
'''
import os
import sys
sys.path.append(os.getcwd())
import numpy as np
import torch

from data_utils.lower_body import rearrange, symmetry
import torch.nn.functional as F

def data_driven_baselines(gt_kps):
    '''
    gt_kps: T, D
    '''
    gt_velocity = np.abs(gt_kps[1:] - gt_kps[:-1])

    mean = np.mean(gt_velocity, axis=0)[np.newaxis]  # (1, D)
    mean = np.mean(np.abs(gt_velocity - mean))
    last_step = gt_kps[1] - gt_kps[0]
    last_step = last_step[np.newaxis]  # (1, D)
    last_step = np.mean(np.abs(gt_velocity - last_step))
    return last_step, mean

def Batch_LVD(gt_kps, pr_kps, symmetrical, weight):
    if gt_kps.shape[0] > pr_kps.shape[1]:
        length = pr_kps.shape[1]
    else:
        length = gt_kps.shape[0]
    gt_kps = gt_kps[:length]
    pr_kps = pr_kps[:, :length]
    global symmetry
    symmetry = torch.tensor(symmetry).bool()

    if symmetrical:
        # rearrange for compute symmetric. ns means non-symmetrical joints, ys means symmetrical joints.
        gt_kps = gt_kps[:, rearrange]
        ns_gt_kps = gt_kps[:, ~symmetry]
        ys_gt_kps = gt_kps[:, symmetry]
        ys_gt_kps = ys_gt_kps.reshape(ys_gt_kps.shape[0], -1, 2, 3)
        ns_gt_velocity = (ns_gt_kps[1:] - ns_gt_kps[:-1]).norm(p=2, dim=-1)
        ys_gt_velocity = (ys_gt_kps[1:] - ys_gt_kps[:-1]).norm(p=2, dim=-1)
        left_gt_vel = ys_gt_velocity[:, :, 0].sum(dim=-1)
        right_gt_vel = ys_gt_velocity[:, :, 1].sum(dim=-1)
        move_side = torch.where(left_gt_vel > right_gt_vel, torch.ones(left_gt_vel.shape).cuda(), torch.zeros(left_gt_vel.shape).cuda())
        ys_gt_velocity = torch.mul(ys_gt_velocity[:, :, 0].transpose(0, 1), move_side) + torch.mul(ys_gt_velocity[:, :, 1].transpose(0, 1), ~move_side.bool())
        ys_gt_velocity = ys_gt_velocity.transpose(0, 1)
        gt_velocity = torch.cat([ns_gt_velocity, ys_gt_velocity], dim=1)

        pr_kps = pr_kps[:, :, rearrange]
        ns_pr_kps = pr_kps[:, :, ~symmetry]
        ys_pr_kps = pr_kps[:, :, symmetry]
        ys_pr_kps = ys_pr_kps.reshape(ys_pr_kps.shape[0], ys_pr_kps.shape[1], -1, 2, 3)
        ns_pr_velocity = (ns_pr_kps[:, 1:] - ns_pr_kps[:, :-1]).norm(p=2, dim=-1)
        ys_pr_velocity = (ys_pr_kps[:, 1:] - ys_pr_kps[:, :-1]).norm(p=2, dim=-1)
        left_pr_vel = ys_pr_velocity[:, :, :, 0].sum(dim=-1)
        right_pr_vel = ys_pr_velocity[:, :, :, 1].sum(dim=-1)
        move_side = torch.where(left_pr_vel > right_pr_vel, torch.ones(left_pr_vel.shape).cuda(),
                                torch.zeros(left_pr_vel.shape).cuda())
        ys_pr_velocity = torch.mul(ys_pr_velocity[..., 0].permute(2, 0, 1), move_side) + torch.mul(
            ys_pr_velocity[..., 1].permute(2, 0, 1), ~move_side.long())
        ys_pr_velocity = ys_pr_velocity.permute(1, 2, 0)
        pr_velocity = torch.cat([ns_pr_velocity, ys_pr_velocity], dim=2)
    else:
        gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
        pr_velocity = (pr_kps[:, 1:] - pr_kps[:, :-1]).norm(p=2, dim=-1)

    if weight:
        w = F.softmax(gt_velocity.sum(dim=1).normal_(), dim=0)
    else:
        w = 1 / gt_velocity.shape[0]

    v_diff = ((pr_velocity - gt_velocity).abs().sum(dim=-1) * w).sum(dim=-1).mean()

    return v_diff


def LVD(gt_kps, pr_kps, symmetrical=False, weight=False):
    gt_kps = gt_kps.squeeze()
    pr_kps = pr_kps.squeeze()
    if len(pr_kps.shape) == 4:
        return Batch_LVD(gt_kps, pr_kps, symmetrical, weight)
    # length = np.minimum(gt_kps.shape[0], pr_kps.shape[0])
    length = gt_kps.shape[0] - 10
    # gt_kps = gt_kps[25:length]
    # pr_kps = pr_kps[25:length] #(T, D)
    # if pr_kps.shape[0] < gt_kps.shape[0]:
    #     pr_kps = np.pad(pr_kps, [[0, int(gt_kps.shape[0]-pr_kps.shape[0])], [0, 0]], mode='constant')

    gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
    pr_velocity = (pr_kps[1:] - pr_kps[:-1]).norm(p=2, dim=-1)

    return (pr_velocity - gt_velocity).abs().sum(dim=-1).mean()

def diversity(kps):
    '''
    kps: bs, seq, dim
    '''
    dis_list = []
    # the distance between each pair
    for i in range(kps.shape[0]):
        for j in range(i + 1, kps.shape[0]):
            seq_i = kps[i]
            seq_j = kps[j]

            dis = np.mean(np.abs(seq_i - seq_j))
            dis_list.append(dis)
    return np.mean(dis_list)
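For reference, the pairwise mean-absolute-difference that diversity() above computes can be exercised on random keypoints with a few self-contained NumPy lines; the shapes below are illustrative only and not taken from the repository.

import numpy as np

kps = np.random.randn(4, 88, 108)  # toy (bs, seq, dim) keypoint batch
dis_list = [np.mean(np.abs(kps[i] - kps[j]))   # same pairwise statistic as diversity()
            for i in range(kps.shape[0])
            for j in range(i + 1, kps.shape[0])]
print('diversity over 6 pairs:', np.mean(dis_list))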
evaluation/mode_transition.py
ADDED
@@ -0,0 +1,60 @@
import os
import sys
sys.path.append(os.getcwd())

from glob import glob

from argparse import ArgumentParser
import json

from evaluation.util import *
from evaluation.metrics import *
from tqdm import tqdm

parser = ArgumentParser()
parser.add_argument('--speaker', required=True, type=str)
parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
args = parser.parse_args()

speaker = args.speaker
test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))

precision_list = []
recall_list = []
accuracy_list = []

for aud in tqdm(test_audios):
    base_name = os.path.splitext(aud)[0]
    gt_path = get_full_path(aud, speaker, 'val')
    _, gt_poses, _ = get_gts(gt_path)
    if gt_poses.shape[0] < 50:
        continue
    gt_poses = gt_poses[np.newaxis, ...]
    # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
    for post_fix in args.post_fix:
        pred_path = base_name + '_' + post_fix + '.json'
        pred_poses = np.array(json.load(open(pred_path)))
        # print(pred_poses.shape)#(B, seq_len, 108)
        pred_poses = cvt25(pred_poses, gt_poses)
        # print(pred_poses.shape)#(B, seq, pose_dim)

        gt_valid_points = valid_points(gt_poses)
        pred_valid_points = valid_points(pred_poses)

        # print(gt_valid_points.shape, pred_valid_points.shape)

        gt_mode_transition_seq = mode_transition_seq(gt_valid_points, speaker)  # (B, N)
        pred_mode_transition_seq = mode_transition_seq(pred_valid_points, speaker)  # (B, N)

        # baseline = np.random.randint(0, 2, size=pred_mode_transition_seq.shape)
        # pred_mode_transition_seq = baseline
        precision, recall, accuracy = mode_transition_consistency(pred_mode_transition_seq, gt_mode_transition_seq)
        precision_list.append(precision)
        recall_list.append(recall)
        accuracy_list.append(accuracy)
print(len(precision_list), len(recall_list), len(accuracy_list))
precision_list = np.mean(precision_list)
recall_list = np.mean(recall_list)
accuracy_list = np.mean(accuracy_list)

print('precision, recall, accu:', precision_list, recall_list, accuracy_list)
evaluation/peak_velocity.py
ADDED
@@ -0,0 +1,65 @@
import os
import sys
sys.path.append(os.getcwd())

from glob import glob

from argparse import ArgumentParser
import json

from evaluation.util import *
from evaluation.metrics import *
from tqdm import tqdm

parser = ArgumentParser()
parser.add_argument('--speaker', required=True, type=str)
parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
args = parser.parse_args()

speaker = args.speaker
test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))

gt_consistency_list = []
pred_consistency_list = []

for aud in tqdm(test_audios):
    base_name = os.path.splitext(aud)[0]
    gt_path = get_full_path(aud, speaker, 'val')
    _, gt_poses, _ = get_gts(gt_path)
    gt_poses = gt_poses[np.newaxis, ...]
    # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
    for post_fix in args.post_fix:
        pred_path = base_name + '_' + post_fix + '.json'
        pred_poses = np.array(json.load(open(pred_path)))
        # print(pred_poses.shape)#(B, seq_len, 108)
        pred_poses = cvt25(pred_poses, gt_poses)
        # print(pred_poses.shape)#(B, seq, pose_dim)

        gt_valid_points = hand_points(gt_poses)
        pred_valid_points = hand_points(pred_poses)

        gt_velocity = peak_velocity(gt_valid_points, order=2)
        pred_velocity = peak_velocity(pred_valid_points, order=2)

        gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
        pred_consistency = velocity_consistency(pred_velocity, gt_velocity)

        gt_consistency_list.append(gt_consistency)
        pred_consistency_list.append(pred_consistency)

gt_consistency_list = np.concatenate(gt_consistency_list)
pred_consistency_list = np.concatenate(pred_consistency_list)

print(gt_consistency_list.max(), gt_consistency_list.min())
print(pred_consistency_list.max(), pred_consistency_list.min())
print(np.mean(gt_consistency_list), np.mean(pred_consistency_list))
print(np.std(gt_consistency_list), np.std(pred_consistency_list))

draw_cdf(gt_consistency_list, save_name='%s_gt.jpg'%(speaker), color='slateblue')
draw_cdf(pred_consistency_list, save_name='%s_pred.jpg'%(speaker), color='lightskyblue')

to_excel(gt_consistency_list, '%s_gt.xlsx'%(speaker))
to_excel(pred_consistency_list, '%s_pred.xlsx'%(speaker))

np.save('%s_gt.npy'%(speaker), gt_consistency_list)
np.save('%s_pred.npy'%(speaker), pred_consistency_list)
evaluation/util.py
ADDED
@@ -0,0 +1,148 @@
import os
from glob import glob
import numpy as np
import json
from matplotlib import pyplot as plt
import pandas as pd

def get_gts(clip):
    '''
    clip: abs path to the clip dir
    '''
    keypoints_files = sorted(glob(os.path.join(clip, 'keypoints_new/person_1') + '/*.json'))

    upper_body_points = list(np.arange(0, 25))
    poses = []
    confs = []
    neck_to_nose_len = []
    mean_position = []
    for kp_file in keypoints_files:
        kp_load = json.load(open(kp_file, 'r'))['people'][0]
        posepts = kp_load['pose_keypoints_2d']
        lhandpts = kp_load['hand_left_keypoints_2d']
        rhandpts = kp_load['hand_right_keypoints_2d']
        facepts = kp_load['face_keypoints_2d']

        neck = np.array(posepts).reshape(-1, 3)[1]
        nose = np.array(posepts).reshape(-1, 3)[0]
        x_offset = abs(neck[0] - nose[0])
        y_offset = abs(neck[1] - nose[1])
        neck_to_nose_len.append(y_offset)
        mean_position.append([neck[0], neck[1]])

        keypoints = np.array(posepts + lhandpts + rhandpts + facepts).reshape(-1, 3)[:, :2]

        upper_body = keypoints[upper_body_points, :]
        hand_points = keypoints[25:, :]
        keypoints = np.vstack([upper_body, hand_points])

        poses.append(keypoints)

    if len(neck_to_nose_len) > 0:
        scale_factor = np.mean(neck_to_nose_len)
    else:
        raise ValueError(clip)
    mean_position = np.mean(np.array(mean_position), axis=0)

    unlocalized_poses = np.array(poses).copy()
    localized_poses = []
    for i in range(len(poses)):
        keypoints = poses[i]
        neck = keypoints[1].copy()

        keypoints[:, 0] = (keypoints[:, 0] - neck[0]) / scale_factor
        keypoints[:, 1] = (keypoints[:, 1] - neck[1]) / scale_factor
        localized_poses.append(keypoints.reshape(-1))

    localized_poses = np.array(localized_poses)
    return unlocalized_poses, localized_poses, (scale_factor, mean_position)

def get_full_path(wav_name, speaker, split):
    '''
    get clip path from aud file
    '''
    wav_name = os.path.basename(wav_name)
    wav_name = os.path.splitext(wav_name)[0]
    clip_name, vid_name = wav_name[:10], wav_name[11:]

    full_path = os.path.join('pose_dataset/videos/', speaker, 'clips', vid_name, 'images/half', split, clip_name)

    assert os.path.isdir(full_path), full_path

    return full_path

def smooth(res):
    '''
    res: (B, seq_len, pose_dim)
    '''
    window = [res[:, 7, :], res[:, 8, :], res[:, 9, :], res[:, 10, :], res[:, 11, :], res[:, 12, :]]
    w_size = 7
    for i in range(10, res.shape[1] - 3):
        window.append(res[:, i + 3, :])
        if len(window) > w_size:
            window = window[1:]

        if (i % 25) in [22, 23, 24, 0, 1, 2, 3]:
            res[:, i, :] = np.mean(window, axis=1)

    return res

def cvt25(pred_poses, gt_poses=None):
    '''
    gt_poses: (1, seq_len, 270), 135 * 2
    pred_poses: (B, seq_len, 108), 54 * 2
    '''
    if gt_poses is None:
        gt_poses = np.zeros_like(pred_poses)
    else:
        gt_poses = gt_poses.repeat(pred_poses.shape[0], axis=0)

    length = min(pred_poses.shape[1], gt_poses.shape[1])
    pred_poses = pred_poses[:, :length, :]
    gt_poses = gt_poses[:, :length, :]
    gt_poses = gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1, 2)
    pred_poses = pred_poses.reshape(pred_poses.shape[0], pred_poses.shape[1], -1, 2)

    gt_poses[:, :, [1, 2, 3, 4, 5, 6, 7], :] = pred_poses[:, :, 1:8, :]
    gt_poses[:, :, 25:25+21+21, :] = pred_poses[:, :, 12:, :]

    return gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1)

def hand_points(seq):
    '''
    seq: (B, seq_len, 135*2)
    hands only
    '''
    hand_idx = [1, 2, 3, 4, 5, 6, 7] + list(range(25, 25+21+21))
    seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
    return seq[:, :, hand_idx, :].reshape(seq.shape[0], seq.shape[1], -1)

def valid_points(seq):
    '''
    hands with some head points
    '''
    valid_idx = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + list(range(25, 25+21+21))
    seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)

    seq = seq[:, :, valid_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
    assert seq.shape[-1] == 108, seq.shape
    return seq

def draw_cdf(seq, save_name='cdf.jpg', color='slateblue'):
    plt.figure()
    plt.hist(seq, bins=100, range=(0, 100), color=color)
    plt.savefig(save_name)

def to_excel(seq, save_name='res.xlsx'):
    '''
    seq: (T)
    '''
    df = pd.DataFrame(seq)
    writer = pd.ExcelWriter(save_name)
    df.to_excel(writer, 'sheet1')
    writer.save()
    writer.close()


if __name__ == '__main__':
    random_data = np.random.randint(0, 10, 100)
    draw_cdf(random_data)
losses/__init__.py
ADDED
@@ -0,0 +1 @@
from .losses import *
losses/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (174 Bytes).
losses/__pycache__/losses.cpython-37.pyc
ADDED
Binary file (3.53 kB).
losses/losses.py
ADDED
@@ -0,0 +1,91 @@
import os
import sys

sys.path.append(os.getcwd())

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class KeypointLoss(nn.Module):
    def __init__(self):
        super(KeypointLoss, self).__init__()

    def forward(self, pred_seq, gt_seq, gt_conf=None):
        # pred_seq: (B, C, T)
        if gt_conf is not None:
            gt_conf = gt_conf >= 0.01
            return F.mse_loss(pred_seq[gt_conf], gt_seq[gt_conf], reduction='mean')
        else:
            return F.mse_loss(pred_seq, gt_seq)


class KLLoss(nn.Module):
    def __init__(self, kl_tolerance):
        super(KLLoss, self).__init__()
        self.kl_tolerance = kl_tolerance

    def forward(self, mu, var, mul=1):
        kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64
        kld_loss = -0.5 * torch.sum(1 + var - mu**2 - var.exp(), dim=1)
        # kld_loss = -0.5 * torch.sum(1 + (var-1) - (mu) ** 2 - (var-1).exp(), dim=1)
        if self.kl_tolerance is not None:
            # above_line = kld_loss[kld_loss > self.kl_tolerance]
            # if len(above_line) > 0:
            #     kld_loss = torch.mean(kld_loss)
            # else:
            #     kld_loss = 0
            kld_loss = torch.where(kld_loss > kl_tolerance, kld_loss, torch.tensor(kl_tolerance, device='cuda'))
        # else:
        kld_loss = torch.mean(kld_loss)
        return kld_loss


class L2KLLoss(nn.Module):
    def __init__(self, kl_tolerance):
        super(L2KLLoss, self).__init__()
        self.kl_tolerance = kl_tolerance

    def forward(self, x):
        # TODO: check
        kld_loss = torch.sum(x ** 2, dim=1)
        if self.kl_tolerance is not None:
            above_line = kld_loss[kld_loss > self.kl_tolerance]
            if len(above_line) > 0:
                kld_loss = torch.mean(kld_loss)
            else:
                kld_loss = 0
        else:
            kld_loss = torch.mean(kld_loss)
        return kld_loss

class L2RegLoss(nn.Module):
    def __init__(self):
        super(L2RegLoss, self).__init__()

    def forward(self, x):
        # TODO: check
        return torch.sum(x**2)


class L2Loss(nn.Module):
    def __init__(self):
        super(L2Loss, self).__init__()

    def forward(self, x):
        # TODO: check
        return torch.sum(x ** 2)


class AudioLoss(nn.Module):
    def __init__(self):
        super(AudioLoss, self).__init__()

    def forward(self, dynamics, gt_poses):
        # pay attention, normalized
        mean = torch.mean(gt_poses, dim=-1).unsqueeze(-1)
        gt = gt_poses - mean
        return F.mse_loss(dynamics, gt)

L1Loss = nn.L1Loss
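A small usage sketch for the confidence-masked KeypointLoss above; it assumes the repository root is on PYTHONPATH, and the tensor shapes are illustrative only.

import torch
from losses import KeypointLoss  # exposed via losses/__init__.py above

pred = torch.randn(2, 108, 88)   # (B, C, T) predicted keypoint sequence
gt = torch.randn(2, 108, 88)
conf = torch.rand(2, 108, 88)    # per-keypoint confidences; entries below 0.01 are dropped from the MSE
print(KeypointLoss()(pred, gt, gt_conf=conf))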
nets/LS3DCG.py
ADDED
@@ -0,0 +1,414 @@
'''
not exactly the same as the official repo but the results are good
'''
import sys
import os

from data_utils.lower_body import c_index_3d, c_index_6d

sys.path.append(os.getcwd())

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math

from nets.base import TrainWrapperBaseClass
from nets.layers import SeqEncoder1D
from losses import KeypointLoss, L1Loss, KLLoss
from data_utils.utils import get_melspec, get_mfcc_psf, get_mfcc_ta
from nets.utils import denormalize

class Conv1d_tf(nn.Conv1d):
    """
    Conv1d with the padding behavior from TF
    modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
    """

    def __init__(self, *args, **kwargs):
        super(Conv1d_tf, self).__init__(*args, **kwargs)
        self.padding = kwargs.get("padding", "same")

    def _compute_padding(self, input, dim):
        input_size = input.size(dim + 2)
        filter_size = self.weight.size(dim + 2)
        effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
        out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
        total_padding = max(
            0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
        )
        additional_padding = int(total_padding % 2 != 0)

        return additional_padding, total_padding

    def forward(self, input):
        if self.padding == "VALID":
            return F.conv1d(
                input,
                self.weight,
                self.bias,
                self.stride,
                padding=0,
                dilation=self.dilation,
                groups=self.groups,
            )
        rows_odd, padding_rows = self._compute_padding(input, dim=0)
        if rows_odd:
            input = F.pad(input, [0, rows_odd])

        return F.conv1d(
            input,
            self.weight,
            self.bias,
            self.stride,
            padding=(padding_rows // 2),
            dilation=self.dilation,
            groups=self.groups,
        )


def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, norm='bn', padding='valid'):
    if k is None and s is None:
        if not downsample:
            k = 3
            s = 1
        else:
            k = 4
            s = 2

    if type == '1d':
        conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
        if norm == 'bn':
            norm_block = nn.BatchNorm1d(out_channels)
        elif norm == 'ln':
            norm_block = nn.LayerNorm(out_channels)
    elif type == '2d':
        conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
        norm_block = nn.BatchNorm2d(out_channels)
    else:
        assert False

    return nn.Sequential(
        conv_block,
        norm_block,
        nn.LeakyReLU(0.2, True)
    )

class Decoder(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Decoder, self).__init__()
        self.up1 = nn.Sequential(
            ConvNormRelu(in_ch // 2 + in_ch, in_ch // 2),
            ConvNormRelu(in_ch // 2, in_ch // 2),
            nn.Upsample(scale_factor=2, mode='nearest')
        )
        self.up2 = nn.Sequential(
            ConvNormRelu(in_ch // 4 + in_ch // 2, in_ch // 4),
            ConvNormRelu(in_ch // 4, in_ch // 4),
            nn.Upsample(scale_factor=2, mode='nearest')
        )
        self.up3 = nn.Sequential(
            ConvNormRelu(in_ch // 8 + in_ch // 4, in_ch // 8),
            ConvNormRelu(in_ch // 8, in_ch // 8),
            nn.Conv1d(in_ch // 8, out_ch, 1, 1)
        )

    def forward(self, x, x1, x2, x3):
        x = F.interpolate(x, x3.shape[2])
        x = torch.cat([x, x3], dim=1)
        x = self.up1(x)
        x = F.interpolate(x, x2.shape[2])
        x = torch.cat([x, x2], dim=1)
        x = self.up2(x)
        x = F.interpolate(x, x1.shape[2])
        x = torch.cat([x, x1], dim=1)
        x = self.up3(x)
        return x


class EncoderDecoder(nn.Module):
    def __init__(self, n_frames, each_dim):
        super().__init__()
        self.n_frames = n_frames

        self.down1 = nn.Sequential(
            ConvNormRelu(64, 64, '1d', False),
            ConvNormRelu(64, 128, '1d', False),
        )
        self.down2 = nn.Sequential(
            ConvNormRelu(128, 128, '1d', False),
            ConvNormRelu(128, 256, '1d', False),
        )
        self.down3 = nn.Sequential(
            ConvNormRelu(256, 256, '1d', False),
            ConvNormRelu(256, 512, '1d', False),
        )
        self.down4 = nn.Sequential(
            ConvNormRelu(512, 512, '1d', False),
            ConvNormRelu(512, 1024, '1d', False),
        )

        self.down = nn.MaxPool1d(kernel_size=2)
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

        self.face_decoder = Decoder(1024, each_dim[0] + each_dim[3])
        self.body_decoder = Decoder(1024, each_dim[1])
        self.hand_decoder = Decoder(1024, each_dim[2])

    def forward(self, spectrogram, time_steps=None):
        if time_steps is None:
            time_steps = self.n_frames

        x1 = self.down1(spectrogram)
        x = self.down(x1)
        x2 = self.down2(x)
        x = self.down(x2)
        x3 = self.down3(x)
        x = self.down(x3)
        x = self.down4(x)
        x = self.up(x)

        face = self.face_decoder(x, x1, x2, x3)
        body = self.body_decoder(x, x1, x2, x3)
        hand = self.hand_decoder(x, x1, x2, x3)

        return face, body, hand


class Generator(nn.Module):
    def __init__(self,
                 each_dim,
                 training=False,
                 device=None
                 ):
        super().__init__()

        self.training = training
        self.device = device

        self.encoderdecoder = EncoderDecoder(15, each_dim)

    def forward(self, in_spec, time_steps=None):
        if time_steps is not None:
            self.gen_length = time_steps

        face, body, hand = self.encoderdecoder(in_spec)
        out = torch.cat([face, body, hand], dim=1)
        out = out.transpose(1, 2)

        return out


class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.net = nn.Sequential(
            ConvNormRelu(input_dim, 128, '1d'),
            ConvNormRelu(128, 256, '1d'),
            nn.MaxPool1d(kernel_size=2),
            ConvNormRelu(256, 256, '1d'),
            ConvNormRelu(256, 512, '1d'),
            nn.MaxPool1d(kernel_size=2),
            ConvNormRelu(512, 512, '1d'),
            ConvNormRelu(512, 1024, '1d'),
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(1024, 1, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.transpose(1, 2)

        out = self.net(x)
        return out


class TrainWrapper(TrainWrapperBaseClass):
    def __init__(self, args, config) -> None:
        self.args = args
        self.config = config
        self.device = torch.device(self.args.gpu)
        self.global_step = 0
        self.convert_to_6d = self.config.Data.pose.convert_to_6d
        self.init_params()

        self.generator = Generator(
            each_dim=self.each_dim,
            training=not self.args.infer,
            device=self.device,
        ).to(self.device)
        self.discriminator = Discriminator(
            input_dim=self.each_dim[1] + self.each_dim[2] + 64
        ).to(self.device)
        if self.convert_to_6d:
            self.c_index = c_index_6d
        else:
            self.c_index = c_index_3d
        self.MSELoss = KeypointLoss().to(self.device)
        self.L1Loss = L1Loss().to(self.device)
        super().__init__(args, config)

    def init_params(self):
        scale = 1

        global_orient = round(0 * scale)
        leye_pose = reye_pose = round(0 * scale)
        jaw_pose = round(3 * scale)
        body_pose = round((63 - 24) * scale)
        left_hand_pose = right_hand_pose = round(45 * scale)

        expression = 100

        b_j = 0
        jaw_dim = jaw_pose
        b_e = b_j + jaw_dim
        eye_dim = leye_pose + reye_pose
        b_b = b_e + eye_dim
        body_dim = global_orient + body_pose
        b_h = b_b + body_dim
        hand_dim = left_hand_pose + right_hand_pose
        b_f = b_h + hand_dim
        face_dim = expression

        self.dim_list = [b_j, b_e, b_b, b_h, b_f]
        self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
        self.pose = int(self.full_dim / round(3 * scale))
        self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]

    def __call__(self, bat):
        assert (not self.args.infer), "infer mode"
        self.global_step += 1

        loss_dict = {}

        aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
        expression = bat['expression'].to(self.device).to(torch.float32)
        jaw = poses[:, :3, :]
        poses = poses[:, self.c_index, :]

        pred = self.generator(in_spec=aud)

        D_loss, D_loss_dict = self.get_loss(
            pred_poses=pred.detach(),
            gt_poses=poses,
            aud=aud,
            mode='training_D',
        )

        self.discriminator_optimizer.zero_grad()
        D_loss.backward()
        self.discriminator_optimizer.step()

        G_loss, G_loss_dict = self.get_loss(
            pred_poses=pred,
            gt_poses=poses,
            aud=aud,
            expression=expression,
            jaw=jaw,
            mode='training_G',
        )
        self.generator_optimizer.zero_grad()
        G_loss.backward()
        self.generator_optimizer.step()

        total_loss = None
        loss_dict = {}
        for key in list(D_loss_dict.keys()) + list(G_loss_dict.keys()):
            loss_dict[key] = G_loss_dict.get(key, 0) + D_loss_dict.get(key, 0)

        return total_loss, loss_dict

    def get_loss(self,
                 pred_poses,
                 gt_poses,
                 aud=None,
                 jaw=None,
                 expression=None,
                 mode='training_G',
                 ):
        loss_dict = {}
        aud = aud.transpose(1, 2)
        gt_poses = gt_poses.transpose(1, 2)
        gt_aud = torch.cat([gt_poses, aud], dim=2)
        pred_aud = torch.cat([pred_poses[:, :, 103:], aud], dim=2)

        if mode == 'training_D':
            dis_real = self.discriminator(gt_aud)
            dis_fake = self.discriminator(pred_aud)
            dis_error = self.MSELoss(torch.ones_like(dis_real).to(self.device), dis_real) + self.MSELoss(
                torch.zeros_like(dis_fake).to(self.device), dis_fake)
            loss_dict['dis'] = dis_error

            return dis_error, loss_dict
        elif mode == 'training_G':
            jaw_loss = self.L1Loss(pred_poses[:, :, :3], jaw.transpose(1, 2))
            face_loss = self.MSELoss(pred_poses[:, :, 3:103], expression.transpose(1, 2))
            body_loss = self.L1Loss(pred_poses[:, :, 103:142], gt_poses[:, :, :39])
            hand_loss = self.L1Loss(pred_poses[:, :, 142:], gt_poses[:, :, 39:])
            l1_loss = jaw_loss + face_loss + body_loss + hand_loss

            dis_output = self.discriminator(pred_aud)
            gen_error = self.MSELoss(torch.ones_like(dis_output).to(self.device), dis_output)
            gen_loss = self.config.Train.weights.keypoint_loss_weight * l1_loss + self.config.Train.weights.gan_loss_weight * gen_error

            loss_dict['gen'] = gen_error
            loss_dict['jaw_loss'] = jaw_loss
            loss_dict['face_loss'] = face_loss
            loss_dict['body_loss'] = body_loss
            loss_dict['hand_loss'] = hand_loss
            return gen_loss, loss_dict
        else:
            raise ValueError(mode)

    def infer_on_audio(self, aud_fn, fps=30, initial_pose=None, norm_stats=None, id=None, B=1, **kwargs):
        output = []
        assert self.args.infer, "train mode"
        self.generator.eval()

        if self.config.Data.pose.normalization:
            assert norm_stats is not None
            data_mean = norm_stats[0]
            data_std = norm_stats[1]

        pre_length = self.config.Data.pose.pre_pose_length
        generate_length = self.config.Data.pose.generate_length
        # assert pre_length == initial_pose.shape[-1]
        # pre_poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32)
        # B = pre_poses.shape[0]

        aud_feat = get_mfcc_ta(aud_fn, sr=22000, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
        num_poses_to_generate = aud_feat.shape[-1]
        aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
        aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            pred_poses = self.generator(aud_feat)
            pred_poses = pred_poses.cpu().numpy()
        output = pred_poses.squeeze()

        return output

    def generate(self, aud, id):
        self.generator.eval()
        pred_poses = self.generator(aud)
        return pred_poses


if __name__ == '__main__':
    from trainer.options import parse_args

    parser = parse_args()
    args = parser.parse_args(
        ['--exp_name', '0', '--data_root', '0', '--speakers', '0', '--pre_pose_length', '4', '--generate_length', '64',
         '--infer'])

    generator = TrainWrapper(args)

    aud_fn = '../sample_audio/jon.wav'
    initial_pose = torch.randn(64, 108, 4)
    norm_stats = (np.random.randn(108), np.random.randn(108))
    output = generator.infer_on_audio(aud_fn, initial_pose, norm_stats)

    print(output.shape)
nets/__init__.py
ADDED
@@ -0,0 +1,8 @@
from .smplx_face import TrainWrapper as s2g_face
from .smplx_body_vq import TrainWrapper as s2g_body_vq
from .smplx_body_pixel import TrainWrapper as s2g_body_pixel
from .body_ae import TrainWrapper as s2g_body_ae
from .LS3DCG import TrainWrapper as LS3DCG
from .base import TrainWrapperBaseClass

from .utils import normalize, denormalize
ADDED
|
Binary file (407 Bytes). View file
|
|
|
nets/__pycache__/base.cpython-37.pyc
ADDED
|
Binary file (2.29 kB). View file
|
|
|
nets/__pycache__/init_model.cpython-37.pyc
ADDED
|
Binary file (460 Bytes). View file
|
|
|