sshravani committed on
Commit ec76118 · 1 Parent(s): f8b6a4b

Added visualise folder files from GitHub repo

This view is limited to 50 files because it contains too many changes. See raw diff.
.DS_Store DELETED
Binary file (8.2 kB)
 
.gitattributes DELETED
@@ -1 +0,0 @@
- demo_audio/1st-page.wav filter=lfs diff=lfs merge=lfs -text
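The deleted `.gitattributes` entry is the Git LFS rule that kept the demo audio out of regular Git storage. As a minimal sketch (assuming Git LFS is installed and the same rule were being recreated), such an entry is normally generated like this:

```bash
# Register the LFS filter once per machine, then track the demo audio file.
git lfs install
git lfs track "demo_audio/1st-page.wav"
# git lfs track writes the "filter=lfs diff=lfs merge=lfs -text" line into .gitattributes.
git add .gitattributes demo_audio/1st-page.wav
```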
 
 
.gitignore DELETED
@@ -1,13 +0,0 @@
- cat > .gitignore << EOF
- # Binary and large files
- *.pkl
- *.mp4
- *.npy
- # Demo binary files
- demo/**/*.mp4
- demo/**/*.npy
- # Large model files
- experiments/
- # Any other large files
- visualise/teaser_01.png
- EOF
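Note that the deleted `.gitignore` contains the shell heredoc wrapper (`cat > .gitignore << EOF` and `EOF`) as literal lines, which suggests the snippet was pasted into the file rather than executed. As a sketch of the apparent intent (same patterns, run from the repository root in a POSIX shell), the heredoc itself would be:

```bash
# Writes only the ignore rules; the cat/EOF wrapper lines never end up in the file.
cat > .gitignore << 'EOF'
# Binary and large files
*.pkl
*.mp4
*.npy
# Demo binary files
demo/**/*.mp4
demo/**/*.npy
# Large model files
experiments/
# Any other large files
visualise/teaser_01.png
EOF
```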
 
Dockerfile DELETED
@@ -1,41 +0,0 @@
- FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime
-
- # Install system dependencies
- RUN apt-get update && apt-get install -y \
- ffmpeg \
- libgl1-mesa-glx \
- git \
- wget \
- unzip \
- libsndfile1 \
- && rm -rf /var/lib/apt/lists/*
-
- # Set up a non-root user for Hugging Face Space compatibility
- RUN useradd -m -u 1000 user
- USER user
- WORKDIR /home/user
-
- # Copy project files
- COPY --chown=user requirements.txt .
- COPY --chown=user . .
-
- # Install Python dependencies
- RUN pip install --no-cache-dir -r requirements.txt
-
- # Create necessary directories
- RUN mkdir -p visualise/smplx_model \
- && mkdir -p experiments \
- && mkdir -p visualise/video/body-pixel \
- && mkdir -p visualise/video/body-pixel2 \
- && mkdir -p demo_audio
-
- # Set environment variables for GPU and Python
- ENV PYTHONUNBUFFERED=1
- ENV NVIDIA_VISIBLE_DEVICES=all
- ENV NVIDIA_DRIVER_CAPABILITIES=all
-
- # Expose Gradio port
- EXPOSE 7860
-
- # Default command
- CMD ["python", "app.py"]
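The Dockerfile above is written for a Hugging Face Space (non-root user, Gradio on port 7860), but it can also be built and run locally. A minimal sketch, assuming the NVIDIA Container Toolkit is installed for GPU access and using `talkshow` as an arbitrary image tag:

```bash
# Build from the repository root, where the Dockerfile lives.
docker build -t talkshow .
# Run with GPU access and publish the Gradio port declared by EXPOSE 7860.
docker run --gpus all -p 7860:7860 talkshow
```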
 
README.md DELETED
@@ -1,118 +0,0 @@
- ---
- title: TalkSHOW Speech-to-Motion Translation
- emoji: 🎙️
- colorFrom: blue
- colorTo: purple
- sdk: docker
- app_port: 7860
- pinned: false
- license: mit
- ---
-
- # Team 14 - TalkSHOW: Generating Holistic 3D Human Motion from Speech
-
- Contributors: Abinaya Odeti, Shipra, Shravani, Vishal
-
- ![teaser](visualise/teaser_01.png)
-
- ## About
-
- This repository hosts the implementation of "TalkSHOW: A Speech-to-Motion Translation System", which maps raw audio input to full-body 3D motion using the SMPL-X model. It enables synchronized generation of expressive human motion (face, hands, and body) from speech input, supporting real-time animation, virtual avatars, and digital storytelling.
-
- ## Highlights
-
- Translates raw .wav audio into natural whole-body motion (jaw, pose, expressions, hands) using deep learning.
-
- Based on the SMPL-X model for realistic 3D human mesh generation.
-
- Modular pipeline with support for face-body composition.
-
- Visualization with OpenGL and FFmpeg for the final rendered video.
-
- End-to-end customizable configuration covering audio models, latent generation, and rendering.
-
- ## Prerequisites
-
- Python 3.7+
-
- Anaconda for environment management
-
- Install the required packages:
-
- ```bash
- pip install -r requirements.txt
- ```
- Install FFmpeg
-
- ➤ Extract the FFmpeg ZIP and add its bin folder to the system PATH
-
-
- ## Getting started
-
- The visualization code was tested on `Windows 10`, and it requires:
-
- * Python 3.7
- * Anaconda3 or Miniconda3
- * a CUDA-capable GPU (one is enough)
-
-
-
- ### 1. Setup
-
- Clone the repo:
- ```bash
- git clone https://github.com/YOUR_USERNAME/TALKSHOW-speech-to-motion-translation-system.git
- cd TalkSHOW
- ```
- Create the conda environment:
- ```bash
- conda create -n talkshow python=3.7 -y
- conda activate talkshow
- pip install -r requirements.txt
- ```
-
- ### 2. Download models
- Download or place the required checkpoints:
- Download the [**pretrained models**](https://drive.google.com/file/d/1bC0ZTza8HOhLB46WOJ05sBywFvcotDZG/view?usp=sharing),
- unzip them, and place them in the TalkSHOW folder, i.e. ``path-to-TalkSHOW/experiments``.
-
- Download the [**smplx model**](https://drive.google.com/file/d/1Ly_hQNLQcZ89KG0Nj4jYZwccQiimSUVn/view?usp=share_link) (please register on the official [**SMPLX webpage**](https://smpl-x.is.tue.mpg.de) before using it)
- and place it in ``path-to-TalkSHOW/visualise/smplx_model``.
- To visualise the test set and the generated results (in each video, left: generated result | right: ground truth),
- the videos and generated motion data are saved in ``./visualise/video/body-pixel``.
-
- SMPLX model weights – visualise/smplx_model/SMPLX_NEUTRAL_2020.npz
-
- Extra joints, regressors, YAML configs – inside visualise/smplx_model/
-
- Also, ensure that vq_path in body_pixel.json points to a valid .pth model (in ./experiments/.../ckpt-*.pth).
-
-
- ### 3. 🎙️ Running Inference
-
- To generate a 3D animated video from an audio file:
- ```bash
- python scripts/demo.py \
- --config_file ./config/body_pixel.json \
- --infer \
- --audio_file ./demo_audio/1st-page.wav \
- --id 0 \
- --whole_body
- ```
- Change input:
- Replace the --audio_file value with your own .wav file path.
-
-
- ### 4. Output
- The final 3D animated video will be saved under:
- ```bash
- visualise/video/body-pixel2/<audio_file_name>/1st-page.mp4
- ```
- The exact command used to run the project:
- ```bash
- python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/1st-page.wav --id 0 --whole_body
- ```
-
- ### Contact
-
- For issues or questions, open an issue or contact the contributors directly!
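Before running inference, a quick sanity check of the layout described in the README above (FFmpeg on PATH, SMPL-X weights and pretrained checkpoints in place) can save a failed run. A hedged sketch, assuming the paths named in the setup steps:

```bash
# FFmpeg must be callable for the final video rendering.
ffmpeg -version
# SMPL-X neutral model and the unzipped pretrained checkpoints.
ls visualise/smplx_model/SMPLX_NEUTRAL_2020.npz
ls experiments/
# Demo audio used by the example command.
ls demo_audio/1st-page.wav
```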
 
__init__.py DELETED
File without changes
app.py DELETED
@@ -1,98 +0,0 @@
1
- import gradio as gr
2
- import os
3
- import subprocess
4
- import time
5
- import logging
6
- import traceback
7
-
8
- def process_audio(audio_file):
9
- # Configure detailed logging
10
- logging.basicConfig(level=logging.DEBUG,
11
- format='%(asctime)s - %(levelname)s - %(message)s')
12
- logger = logging.getLogger(__name__)
13
-
14
- try:
15
- # Detailed logging for input
16
- logger.info(f"Received audio file: {audio_file}")
17
- logger.info(f"Audio file exists: {os.path.exists(audio_file)}")
18
-
19
- # Validate input file
20
- if not audio_file or not os.path.exists(audio_file):
21
- raise ValueError(f"Invalid or non-existent audio file: {audio_file}")
22
-
23
- # Ensure output directory exists
24
- os.makedirs("visualise/video/body-pixel2", exist_ok=True)
25
-
26
- # Debugging: print current working directory and file details
27
- logger.debug(f"Current working directory: {os.getcwd()}")
28
- logger.debug(f"Audio file path: {os.path.abspath(audio_file)}")
29
- logger.debug(f"Audio file size: {os.path.getsize(audio_file)} bytes")
30
-
31
- # Construct command with full paths
32
- cmd = [
33
- "python",
34
- os.path.abspath("scripts/demo.py"),
35
- "--config_file", os.path.abspath("config/body_pixel.json"),
36
- "--infer",
37
- "--audio_file", os.path.abspath(audio_file),
38
- "--id", "0",
39
- "--whole_body"
40
- ]
41
-
42
- logger.info(f"Executing command: {' '.join(cmd)}")
43
-
44
- # Run with more detailed error capture
45
- result = subprocess.run(
46
- cmd,
47
- stdout=subprocess.PIPE,
48
- stderr=subprocess.PIPE,
49
- text=True,
50
- cwd=os.getcwd(), # Ensure correct working directory
51
- timeout=1800
52
- )
53
-
54
- # Log full command output
55
- logger.info(f"Command STDOUT: {result.stdout}")
56
- logger.error(f"Command STDERR: {result.stderr}")
57
-
58
- # Determine output video path
59
- audio_name = os.path.splitext(os.path.basename(audio_file))[0]
60
- output_dir = f"visualise/video/body-pixel2/{audio_name}"
61
- output_path = f"{output_dir}/1st-page.mp4"
62
-
63
- logger.info(f"Expected output path: {output_path}")
64
-
65
- # Check output video
66
- if os.path.exists(output_path):
67
- logger.info(f"Output video found: {output_path}")
68
- return output_path
69
- else:
70
- logger.error("Output video not generated")
71
- return None, f"Error: Output video not generated. STDERR: {result.stderr}"
72
-
73
- except subprocess.TimeoutExpired:
74
- logger.error("Inference process timed out")
75
- return None, "Error: Inference process took too long"
76
-
77
- except Exception as e:
78
- logger.error(f"Unexpected error: {str(e)}")
79
- logger.error(traceback.format_exc())
80
- return None, f"Unexpected error: {str(e)}"
81
-
82
- # Gradio Interface for 3.x compatibility
83
- demo = gr.Interface(
84
- fn=process_audio,
85
- inputs=gr.inputs.File(type="file", label="Upload Audio File"),
86
- outputs=gr.outputs.Video(label="Generated Motion Video"),
87
- title="TalkSHOW: Speech-to-Motion Translation System",
88
- description="Convert speech audio to realistic 3D human motion using the SMPL-X model.",
89
- examples=[["demo_audio/1st-page.wav"]]
90
- )
91
-
92
- # Launch with comprehensive logging
93
- if __name__ == "__main__":
94
- demo.launch(
95
- server_name="0.0.0.0",
96
- server_port=7860,
97
- debug=True # Enable Gradio debug mode
98
- )
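For local testing outside Docker, the Space entry point above can be launched directly. A minimal sketch, assuming the dependencies from requirements.txt and the model files described in the README are already in place:

```bash
pip install -r requirements.txt
python app.py
# The Gradio UI is then served on http://localhost:7860 (server_name 0.0.0.0, server_port 7860 above).
```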
 
config/LS3DCG.json DELETED
@@ -1,65 +0,0 @@
1
- {
2
- "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3
- "dataset_load_mode": "pickle",
4
- "store_file_path": "store.pkl",
5
- "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6
- "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7
- "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8
- "param": {
9
- "w_j": 1,
10
- "w_b": 1,
11
- "w_h": 1
12
- },
13
- "Data": {
14
- "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
15
- "pklname": "_3d_mfcc.pkl",
16
- "whole_video": false,
17
- "pose": {
18
- "normalization": false,
19
- "convert_to_6d": false,
20
- "norm_method": "all",
21
- "augmentation": false,
22
- "generate_length": 88,
23
- "pre_pose_length": 0,
24
- "pose_dim": 99,
25
- "expression": true
26
- },
27
- "aud": {
28
- "feat_method": "mfcc",
29
- "aud_feat_dim": 64,
30
- "aud_feat_win_size": null,
31
- "context_info": false
32
- }
33
- },
34
- "Model": {
35
- "model_type": "body",
36
- "model_name": "s2g_LS3DCG",
37
- "code_num": 2048,
38
- "AudioOpt": "Adam",
39
- "encoder_choice": "mfcc",
40
- "gan": false
41
- },
42
- "DataLoader": {
43
- "batch_size": 128,
44
- "num_workers": 0
45
- },
46
- "Train": {
47
- "epochs": 100,
48
- "max_gradient_norm": 5,
49
- "learning_rate": {
50
- "generator_learning_rate": 1e-4,
51
- "discriminator_learning_rate": 1e-4
52
- },
53
- "weights": {
54
- "keypoint_loss_weight": 1.0,
55
- "gan_loss_weight": 1.0
56
- }
57
- },
58
- "Log": {
59
- "save_every": 50,
60
- "print_every": 200,
61
- "name": "LS3DCG"
62
- },
63
- "device": "cpu"
64
- }
65
-
 
config/body_pixel.json DELETED
@@ -1,63 +0,0 @@
1
- {
2
- "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3
- "dataset_load_mode": "json",
4
- "store_file_path": "store.pkl",
5
- "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6
- "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7
- "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8
- "param": {
9
- "w_j": 1,
10
- "w_b": 1,
11
- "w_h": 1
12
- },
13
- "Data": {
14
- "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
15
- "pklname": "_3d_mfcc.pkl",
16
- "whole_video": false,
17
- "pose": {
18
- "normalization": false,
19
- "convert_to_6d": false,
20
- "norm_method": "all",
21
- "augmentation": false,
22
- "generate_length": 88,
23
- "pre_pose_length": 0,
24
- "pose_dim": 99,
25
- "expression": true
26
- },
27
- "aud": {
28
- "feat_method": "mfcc",
29
- "aud_feat_dim": 64,
30
- "aud_feat_win_size": null,
31
- "context_info": false
32
- }
33
- },
34
- "Model": {
35
- "model_type": "body",
36
- "model_name": "s2g_body_pixel",
37
- "composition": true,
38
- "code_num": 2048,
39
- "bh_model": true,
40
- "AudioOpt": "Adam",
41
- "encoder_choice": "mfcc",
42
- "gan": false,
43
- "vq_path": "./experiments/2022-10-31-smplx_S2G-body-vq-3d/ckpt-99.pth"
44
- },
45
- "DataLoader": {
46
- "batch_size": 128,
47
- "num_workers": 0
48
- },
49
- "Train": {
50
- "epochs": 100,
51
- "max_gradient_norm": 5,
52
- "learning_rate": {
53
- "generator_learning_rate": 1e-4,
54
- "discriminator_learning_rate": 1e-4
55
- }
56
- },
57
- "Log": {
58
- "save_every": 50,
59
- "print_every": 200,
60
- "name": "body-pixel2"
61
- }
62
- }
63
-
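The README above asks that `vq_path` in this config point to an existing checkpoint. A quick hedged check, using the value shown in the deleted config:

```bash
# Print the configured vq_path, then confirm the checkpoint file exists.
python -c "import json; print(json.load(open('config/body_pixel.json'))['Model']['vq_path'])"
ls -lh ./experiments/2022-10-31-smplx_S2G-body-vq-3d/ckpt-99.pth
```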
 
config/body_vq.json DELETED
@@ -1,62 +0,0 @@
1
- {
2
- "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3
- "dataset_load_mode": "json",
4
- "store_file_path": "store.pkl",
5
- "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6
- "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7
- "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8
- "param": {
9
- "w_j": 1,
10
- "w_b": 1,
11
- "w_h": 1
12
- },
13
- "Data": {
14
- "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
15
- "pklname": "_3d_mfcc.pkl",
16
- "whole_video": false,
17
- "pose": {
18
- "normalization": false,
19
- "convert_to_6d": false,
20
- "norm_method": "all",
21
- "augmentation": false,
22
- "generate_length": 88,
23
- "pre_pose_length": 0,
24
- "pose_dim": 99,
25
- "expression": true
26
- },
27
- "aud": {
28
- "feat_method": "mfcc",
29
- "aud_feat_dim": 64,
30
- "aud_feat_win_size": null,
31
- "context_info": false
32
- }
33
- },
34
- "Model": {
35
- "model_type": "body",
36
- "model_name": "s2g_body_vq",
37
- "composition": true,
38
- "code_num": 2048,
39
- "bh_model": true,
40
- "AudioOpt": "Adam",
41
- "encoder_choice": "mfcc",
42
- "gan": false
43
- },
44
- "DataLoader": {
45
- "batch_size": 128,
46
- "num_workers": 0
47
- },
48
- "Train": {
49
- "epochs": 100,
50
- "max_gradient_norm": 5,
51
- "learning_rate": {
52
- "generator_learning_rate": 1e-4,
53
- "discriminator_learning_rate": 1e-4
54
- }
55
- },
56
- "Log": {
57
- "save_every": 50,
58
- "print_every": 200,
59
- "name": "body-vq"
60
- }
61
- }
62
-
 
config/face.json DELETED
@@ -1,59 +0,0 @@
1
- {
2
- "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3
- "dataset_load_mode": "json",
4
- "store_file_path": "store.pkl",
5
- "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6
- "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7
- "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8
- "param": {
9
- "w_j": 1,
10
- "w_b": 1,
11
- "w_h": 1
12
- },
13
- "Data": {
14
- "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
15
- "pklname": "_3d_wv2.pkl",
16
- "whole_video": true,
17
- "pose": {
18
- "normalization": false,
19
- "convert_to_6d": false,
20
- "norm_method": "all",
21
- "augmentation": false,
22
- "generate_length": 88,
23
- "pre_pose_length": 0,
24
- "pose_dim": 99,
25
- "expression": true
26
- },
27
- "aud": {
28
- "feat_method": "mfcc",
29
- "aud_feat_dim": 64,
30
- "aud_feat_win_size": null,
31
- "context_info": false
32
- }
33
- },
34
- "Model": {
35
- "model_type": "face",
36
- "model_name": "s2g_face",
37
- "AudioOpt": "SGD",
38
- "encoder_choice": "faceformer",
39
- "gan": false
40
- },
41
- "DataLoader": {
42
- "batch_size": 1,
43
- "num_workers": 0
44
- },
45
- "Train": {
46
- "epochs": 100,
47
- "max_gradient_norm": 5,
48
- "learning_rate": {
49
- "generator_learning_rate": 1e-4,
50
- "discriminator_learning_rate": 1e-4
51
- }
52
- },
53
- "Log": {
54
- "save_every": 50,
55
- "print_every": 1000,
56
- "name": "face"
57
- }
58
- }
59
-
 
data_utils/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- # from .dataloader_csv import MultiVidData as csv_data
2
- from .dataloader_torch import MultiVidData as torch_data
3
- from .utils import get_melspec, get_mfcc, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
 
data_utils/apply_split.py DELETED
@@ -1,51 +0,0 @@
1
- import os
2
- from tqdm import tqdm
3
- import pickle
4
- import shutil
5
-
6
- speakers = ['seth', 'oliver', 'conan', 'chemistry']
7
- source_data_root = "../expressive_body-V0.7"
8
- data_root = "D:/Downloads/SHOW_dataset_v1.0/ExpressiveWholeBodyDatasetReleaseV1.0"
9
-
10
- f_read = open('split_more_than_2s.pkl', 'rb')
11
- f_save = open('none.pkl', 'wb')
12
- data_split = pickle.load(f_read)
13
- none_split = []
14
-
15
- train = val = test = 0
16
-
17
- for speaker_name in speakers:
18
- speaker_root = os.path.join(data_root, speaker_name)
19
-
20
- videos = [v for v in data_split[speaker_name]]
21
-
22
- for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
23
- for split in data_split[speaker_name][vid]:
24
- for seq in data_split[speaker_name][vid][split]:
25
-
26
- seq = seq.replace('\\', '/')
27
- old_file_path = os.path.join(data_root, speaker_name, vid, seq.split('/')[-1])
28
- old_file_path = old_file_path.replace('\\', '/')
29
- new_file_path = seq.replace(source_data_root.split('/')[-1], data_root.split('/')[-1])
30
- try:
31
- shutil.move(old_file_path, new_file_path)
32
- if split == 'train':
33
- train = train + 1
34
- elif split == 'test':
35
- test = test + 1
36
- elif split == 'val':
37
- val = val + 1
38
- except FileNotFoundError:
39
- none_split.append(old_file_path)
40
- print(f"The file {old_file_path} does not exists.")
41
- except shutil.Error:
42
- none_split.append(old_file_path)
43
- print(f"The file {old_file_path} does not exists.")
44
-
45
- print(none_split.__len__())
46
- pickle.dump(none_split, f_save)
47
- f_save.close()
48
-
49
- print(train, val, test)
50
-
51
-
 
data_utils/axis2matrix.py DELETED
@@ -1,29 +0,0 @@
1
- import numpy as np
2
- import math
3
- import scipy.linalg as linalg
4
-
5
-
6
- def rotate_mat(axis, radian):
7
-
8
- a = np.cross(np.eye(3), axis / linalg.norm(axis) * radian)
9
-
10
- rot_matrix = linalg.expm(a)
11
-
12
- return rot_matrix
13
-
14
- def aaa2mat(axis, sin, cos):
15
- i = np.eye(3)
16
- nnt = np.dot(axis.T, axis)
17
- s = np.asarray([[0, -axis[0,2], axis[0,1]],
18
- [axis[0,2], 0, -axis[0,0]],
19
- [-axis[0,1], axis[0,0], 0]])
20
- r = cos * i + (1-cos)*nnt +sin * s
21
- return r
22
-
23
- rand_axis = np.asarray([[1,0,0]])
24
- # rotation angle
25
- r = math.pi/2
26
- # return the rotation matrix
27
- rot_matrix = rotate_mat(rand_axis, r)
28
- r2 = aaa2mat(rand_axis, np.sin(r), np.cos(r))
29
- print(rot_matrix)
 
data_utils/consts.py DELETED
The diff for this file is too large to render. See raw diff
 
data_utils/dataloader_torch.py DELETED
@@ -1,279 +0,0 @@
1
- import sys
2
- import os
3
- sys.path.append(os.getcwd())
4
- import os
5
- from tqdm import tqdm
6
- from data_utils.utils import *
7
- import torch.utils.data as data
8
- from data_utils.mesh_dataset import SmplxDataset
9
- from transformers import Wav2Vec2Processor
10
-
11
-
12
- class MultiVidData():
13
- def __init__(self,
14
- data_root,
15
- speakers,
16
- split='train',
17
- limbscaling=False,
18
- normalization=False,
19
- norm_method='new',
20
- split_trans_zero=False,
21
- num_frames=25,
22
- num_pre_frames=25,
23
- num_generate_length=None,
24
- aud_feat_win_size=None,
25
- aud_feat_dim=64,
26
- feat_method='mel_spec',
27
- context_info=False,
28
- smplx=False,
29
- audio_sr=16000,
30
- convert_to_6d=False,
31
- expression=False,
32
- config=None
33
- ):
34
- self.data_root = data_root
35
- self.speakers = speakers
36
- self.split = split
37
- if split == 'pre':
38
- self.split = 'train'
39
- self.norm_method=norm_method
40
- self.normalization = normalization
41
- self.limbscaling = limbscaling
42
- self.convert_to_6d = convert_to_6d
43
- self.num_frames=num_frames
44
- self.num_pre_frames=num_pre_frames
45
- if num_generate_length is None:
46
- self.num_generate_length = num_frames
47
- else:
48
- self.num_generate_length = num_generate_length
49
- self.split_trans_zero=split_trans_zero
50
-
51
- dataset = SmplxDataset
52
-
53
- if self.split_trans_zero:
54
- self.trans_dataset_list = []
55
- self.zero_dataset_list = []
56
- else:
57
- self.all_dataset_list = []
58
- self.dataset={}
59
- self.complete_data=[]
60
- self.config=config
61
- load_mode=self.config.dataset_load_mode
62
-
63
- ######################load with pickle file
64
- if load_mode=='pickle':
65
- import pickle
66
- import subprocess
67
-
68
- # store_file_path='/tmp/store.pkl'
69
- # cp /is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts/store.pkl /tmp/store.pkl
70
- # subprocess.run(f'cp /is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts/store.pkl {store_file_path}',shell=True)
71
-
72
- # f = open(self.config.store_file_path, 'rb+')
73
- f = open(self.split+config.Data.pklname, 'rb+')
74
- self.dataset=pickle.load(f)
75
- f.close()
76
- for key in self.dataset:
77
- self.complete_data.append(self.dataset[key].complete_data)
78
- ######################load with pickle file
79
-
80
- ######################load with a csv file
81
- elif load_mode=='csv':
82
-
83
- # imported here from one of my code folders; to be integrated properly later
84
- try:
85
- sys.path.append(self.config.config_root_path)
86
- from config import config_path
87
- from csv_parser import csv_parse
88
-
89
- except ImportError as e:
90
- print(f'err: {e}')
91
- raise ImportError('config root path error...')
92
-
93
-
94
- for speaker_name in self.speakers:
95
- # df_intervals=pd.read_csv(self.config.voca_csv_file_path)
96
- df_intervals=None
97
- df_intervals=df_intervals[df_intervals['speaker']==speaker_name]
98
- df_intervals = df_intervals[df_intervals['dataset'] == self.split]
99
-
100
- print(f'speaker {speaker_name} train interval length: {len(df_intervals)}')
101
- for iter_index, (_, interval) in tqdm(
102
- (enumerate(df_intervals.iterrows())),desc=f'load {speaker_name}'
103
- ):
104
-
105
- (
106
- interval_index,
107
- interval_speaker,
108
- interval_video_fn,
109
- interval_id,
110
-
111
- start_time,
112
- end_time,
113
- duration_time,
114
- start_time_10,
115
- over_flow_flag,
116
- short_dur_flag,
117
-
118
- big_video_dir,
119
- small_video_dir_name,
120
- speaker_video_path,
121
-
122
- voca_basename,
123
- json_basename,
124
- wav_basename,
125
- voca_top_clip_path,
126
- voca_json_clip_path,
127
- voca_wav_clip_path,
128
-
129
- audio_output_fn,
130
- image_output_path,
131
- pifpaf_output_path,
132
- mp_output_path,
133
- op_output_path,
134
- deca_output_path,
135
- pixie_output_path,
136
- cam_output_path,
137
- ours_output_path,
138
- merge_output_path,
139
- multi_output_path,
140
- gt_output_path,
141
- ours_images_path,
142
- pkl_fil_path,
143
- )=csv_parse(interval)
144
-
145
- if not os.path.exists(pkl_fil_path) or not os.path.exists(audio_output_fn):
146
- continue
147
-
148
- key=f'{interval_video_fn}/{small_video_dir_name}'
149
- self.dataset[key] = dataset(
150
- data_root=pkl_fil_path,
151
- speaker=speaker_name,
152
- audio_fn=audio_output_fn,
153
- audio_sr=audio_sr,
154
- fps=num_frames,
155
- feat_method=feat_method,
156
- audio_feat_dim=aud_feat_dim,
157
- train=(self.split == 'train'),
158
- load_all=True,
159
- split_trans_zero=self.split_trans_zero,
160
- limbscaling=self.limbscaling,
161
- num_frames=self.num_frames,
162
- num_pre_frames=self.num_pre_frames,
163
- num_generate_length=self.num_generate_length,
164
- audio_feat_win_size=aud_feat_win_size,
165
- context_info=context_info,
166
- convert_to_6d=convert_to_6d,
167
- expression=expression,
168
- config=self.config
169
- )
170
- self.complete_data.append(self.dataset[key].complete_data)
171
- ######################load with a csv file
172
-
173
- ######################origin load method
174
- elif load_mode=='json':
175
-
176
- # if self.split == 'train':
177
- # import pickle
178
- # f = open('store.pkl', 'rb+')
179
- # self.dataset=pickle.load(f)
180
- # f.close()
181
- # for key in self.dataset:
182
- # self.complete_data.append(self.dataset[key].complete_data)
183
- # else:https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav
184
- # if config.Model.model_type == 'face':
185
- am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
186
- am_sr = 16000
187
- # else:
188
- # am, am_sr = None, None
189
- for speaker_name in self.speakers:
190
- speaker_root = os.path.join(self.data_root, speaker_name)
191
-
192
- videos=[v for v in os.listdir(speaker_root) ]
193
- print(videos)
194
-
195
- haode = huaide = 0
196
-
197
- for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
198
- source_vid=vid
199
- # vid_pth=os.path.join(speaker_root, source_vid, 'images/half', self.split)
200
- vid_pth = os.path.join(speaker_root, source_vid, self.split)
201
- if smplx == 'pose':
202
- seqs = [s for s in os.listdir(vid_pth) if (s.startswith('clip'))]
203
- else:
204
- try:
205
- seqs = [s for s in os.listdir(vid_pth)]
206
- except:
207
- continue
208
-
209
- for s in seqs:
210
- seq_root=os.path.join(vid_pth, s)
211
- key = seq_root # correspond to clip******
212
- audio_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.wav' % (s))
213
- motion_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.pkl' % (s))
214
- if not os.path.isfile(audio_fname) or not os.path.isfile(motion_fname):
215
- huaide = huaide + 1
216
- continue
217
-
218
- self.dataset[key]=dataset(
219
- data_root=seq_root,
220
- speaker=speaker_name,
221
- motion_fn=motion_fname,
222
- audio_fn=audio_fname,
223
- audio_sr=audio_sr,
224
- fps=num_frames,
225
- feat_method=feat_method,
226
- audio_feat_dim=aud_feat_dim,
227
- train=(self.split=='train'),
228
- load_all=True,
229
- split_trans_zero=self.split_trans_zero,
230
- limbscaling=self.limbscaling,
231
- num_frames=self.num_frames,
232
- num_pre_frames=self.num_pre_frames,
233
- num_generate_length=self.num_generate_length,
234
- audio_feat_win_size=aud_feat_win_size,
235
- context_info=context_info,
236
- convert_to_6d=convert_to_6d,
237
- expression=expression,
238
- config=self.config,
239
- am=am,
240
- am_sr=am_sr,
241
- whole_video=config.Data.whole_video
242
- )
243
- self.complete_data.append(self.dataset[key].complete_data)
244
- haode = haode + 1
245
- print("huaide:{}, haode:{}".format(huaide, haode))
246
- import pickle
247
-
248
- f = open(self.split+config.Data.pklname, 'wb')
249
- pickle.dump(self.dataset, f)
250
- f.close()
251
- ######################origin load method
252
-
253
- self.complete_data=np.concatenate(self.complete_data, axis=0)
254
-
255
- # assert self.complete_data.shape[-1] == (12+21+21)*2
256
- self.normalize_stats = {}
257
-
258
- self.data_mean = None
259
- self.data_std = None
260
-
261
- def get_dataset(self):
262
- self.normalize_stats['mean'] = self.data_mean
263
- self.normalize_stats['std'] = self.data_std
264
-
265
- for key in list(self.dataset.keys()):
266
- if self.dataset[key].complete_data.shape[0] < self.num_generate_length:
267
- continue
268
- self.dataset[key].num_generate_length = self.num_generate_length
269
- self.dataset[key].get_dataset(self.normalization, self.normalize_stats, self.split)
270
- self.all_dataset_list.append(self.dataset[key].all_dataset)
271
-
272
- if self.split_trans_zero:
273
- self.trans_dataset = data.ConcatDataset(self.trans_dataset_list)
274
- self.zero_dataset = data.ConcatDataset(self.zero_dataset_list)
275
- else:
276
- self.all_dataset = data.ConcatDataset(self.all_dataset_list)
277
-
278
-
279
-
 
data_utils/dataset_preprocess.py DELETED
@@ -1,170 +0,0 @@
1
- import os
2
- import pickle
3
- from tqdm import tqdm
4
- import shutil
5
- import torch
6
- import numpy as np
7
- import librosa
8
- import random
9
-
10
- speakers = ['seth', 'conan', 'oliver', 'chemistry']
11
- data_root = "../ExpressiveWholeBodyDatasetv1.0/"
12
- split = 'train'
13
-
14
-
15
-
16
- def split_list(full_list,shuffle=False,ratio=0.2):
17
- n_total = len(full_list)
18
- offset_0 = int(n_total * ratio)
19
- offset_1 = int(n_total * ratio * 2)
20
- if n_total==0 or offset_1<1:
21
- return [],full_list
22
- if shuffle:
23
- random.shuffle(full_list)
24
- sublist_0 = full_list[:offset_0]
25
- sublist_1 = full_list[offset_0:offset_1]
26
- sublist_2 = full_list[offset_1:]
27
- return sublist_0, sublist_1, sublist_2
28
-
29
-
30
- def moveto(list, file):
31
- for f in list:
32
- before, after = '/'.join(f.split('/')[:-1]), f.split('/')[-1]
33
- new_path = os.path.join(before, file)
34
- new_path = os.path.join(new_path, after)
35
- # os.makedirs(new_path)
36
- # os.path.isdir(new_path)
37
- # shutil.move(f, new_path)
38
-
39
- # move to the new directory
40
- shutil.copytree(f, new_path)
41
- # remove the original files from the train split
42
- shutil.rmtree(f)
43
- return None
44
-
45
-
46
- def read_pkl(data):
47
- betas = np.array(data['betas'])
48
-
49
- jaw_pose = np.array(data['jaw_pose'])
50
- leye_pose = np.array(data['leye_pose'])
51
- reye_pose = np.array(data['reye_pose'])
52
- global_orient = np.array(data['global_orient']).squeeze()
53
- body_pose = np.array(data['body_pose_axis'])
54
- left_hand_pose = np.array(data['left_hand_pose'])
55
- right_hand_pose = np.array(data['right_hand_pose'])
56
-
57
- full_body = np.concatenate(
58
- (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
59
-
60
- expression = np.array(data['expression'])
61
- full_body = np.concatenate((full_body, expression), axis=1)
62
-
63
- if (full_body.shape[0] < 90) or (torch.isnan(torch.from_numpy(full_body)).sum() > 0):
64
- return 1
65
- else:
66
- return 0
67
-
68
-
69
- for speaker_name in speakers:
70
- speaker_root = os.path.join(data_root, speaker_name)
71
-
72
- videos = [v for v in os.listdir(speaker_root)]
73
- print(videos)
74
-
75
- haode = huaide = 0
76
- total_seqs = []
77
-
78
- for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
79
- # for vid in videos:
80
- source_vid = vid
81
- vid_pth = os.path.join(speaker_root, source_vid)
82
- # vid_pth = os.path.join(speaker_root, source_vid, 'images/half', split)
83
- t = os.path.join(speaker_root, source_vid, 'test')
84
- v = os.path.join(speaker_root, source_vid, 'val')
85
-
86
- # if os.path.exists(t):
87
- # shutil.rmtree(t)
88
- # if os.path.exists(v):
89
- # shutil.rmtree(v)
90
- try:
91
- seqs = [s for s in os.listdir(vid_pth)]
92
- except:
93
- continue
94
- # if len(seqs) == 0:
95
- # shutil.rmtree(os.path.join(speaker_root, source_vid))
96
- # None
97
- for s in seqs:
98
- quality = 0
99
- total_seqs.append(os.path.join(vid_pth,s))
100
- seq_root = os.path.join(vid_pth, s)
101
- key = seq_root # correspond to clip******
102
- audio_fname = os.path.join(speaker_root, source_vid, s, '%s.wav' % (s))
103
-
104
- # delete the data without audio or the audio file could not be read
105
- if os.path.isfile(audio_fname):
106
- try:
107
- audio = librosa.load(audio_fname)
108
- except:
109
- # print(key)
110
- shutil.rmtree(key)
111
- huaide = huaide + 1
112
- continue
113
- else:
114
- huaide = huaide + 1
115
- # print(key)
116
- shutil.rmtree(key)
117
- continue
118
-
119
- # check motion file
120
- motion_fname = os.path.join(speaker_root, source_vid, s, '%s.pkl' % (s))
121
- try:
122
- f = open(motion_fname, 'rb+')
123
- except:
124
- shutil.rmtree(key)
125
- huaide = huaide + 1
126
- continue
127
-
128
- data = pickle.load(f)
129
- w = read_pkl(data)
130
- f.close()
131
- quality = quality + w
132
-
133
- if w == 1:
134
- shutil.rmtree(key)
135
- # print(key)
136
- huaide = huaide + 1
137
- continue
138
-
139
- haode = haode + 1
140
-
141
- print("huaide:{}, haode:{}, total_seqs:{}".format(huaide, haode, total_seqs.__len__()))
142
-
143
- for speaker_name in speakers:
144
- speaker_root = os.path.join(data_root, speaker_name)
145
-
146
- videos = [v for v in os.listdir(speaker_root)]
147
- print(videos)
148
-
149
- haode = huaide = 0
150
- total_seqs = []
151
-
152
- for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
153
- # for vid in videos:
154
- source_vid = vid
155
- vid_pth = os.path.join(speaker_root, source_vid)
156
- try:
157
- seqs = [s for s in os.listdir(vid_pth)]
158
- except:
159
- continue
160
- for s in seqs:
161
- quality = 0
162
- total_seqs.append(os.path.join(vid_pth, s))
163
- print("total_seqs:{}".format(total_seqs.__len__()))
164
- # split the dataset
165
- test_list, val_list, train_list = split_list(total_seqs, True, 0.1)
166
- print(len(test_list), len(val_list), len(train_list))
167
- moveto(train_list, 'train')
168
- moveto(test_list, 'test')
169
- moveto(val_list, 'val')
170
-
 
data_utils/get_j.py DELETED
@@ -1,51 +0,0 @@
1
- import torch
2
-
3
-
4
- def to3d(poses, config):
5
- if config.Data.pose.convert_to_6d:
6
- if config.Data.pose.expression:
7
- poses_exp = poses[:, -100:]
8
- poses = poses[:, :-100]
9
-
10
- poses = poses.reshape(poses.shape[0], -1, 5)
11
- sin, cos = poses[:, :, 3], poses[:, :, 4]
12
- pose_angle = torch.atan2(sin, cos)
13
- poses = (poses[:, :, :3] * pose_angle.unsqueeze(dim=-1)).reshape(poses.shape[0], -1)
14
-
15
- if config.Data.pose.expression:
16
- poses = torch.cat([poses, poses_exp], dim=-1)
17
- return poses
18
-
19
-
20
- def get_joint(smplx_model, betas, pred):
21
- joint = smplx_model(betas=betas.repeat(pred.shape[0], 1),
22
- expression=pred[:, 165:265],
23
- jaw_pose=pred[:, 0:3],
24
- leye_pose=pred[:, 3:6],
25
- reye_pose=pred[:, 6:9],
26
- global_orient=pred[:, 9:12],
27
- body_pose=pred[:, 12:75],
28
- left_hand_pose=pred[:, 75:120],
29
- right_hand_pose=pred[:, 120:165],
30
- return_verts=True)['joints']
31
- return joint
32
-
33
-
34
- def get_joints(smplx_model, betas, pred):
35
- if len(pred.shape) == 3:
36
- B = pred.shape[0]
37
- x = 4 if B>= 4 else B
38
- T = pred.shape[1]
39
- pred = pred.reshape(-1, 265)
40
- smplx_model.batch_size = L = T * x
41
-
42
- times = pred.shape[0] // smplx_model.batch_size
43
- joints = []
44
- for i in range(times):
45
- joints.append(get_joint(smplx_model, betas, pred[i*L:(i+1)*L]))
46
- joints = torch.cat(joints, dim=0)
47
- joints = joints.reshape(B, T, -1, 3)
48
- else:
49
- smplx_model.batch_size = pred.shape[0]
50
- joints = get_joint(smplx_model, betas, pred)
51
- return joints
 
data_utils/hand_component.json DELETED
The diff for this file is too large to render. See raw diff
 
data_utils/lower_body.py DELETED
@@ -1,143 +0,0 @@
1
- import numpy as np
2
- import torch
3
-
4
- lower_pose = torch.tensor(
5
- [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0747, -0.0158, -0.0152, -1.1826512813568115, 0.23866955935955048,
6
- 0.15146760642528534, -1.2604516744613647, -0.3160211145877838,
7
- -0.1603458970785141, 1.1654603481292725, 0.0, 0.0, 1.2521806955337524, 0.041598282754421234, -0.06312154978513718,
8
- 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
9
- lower_pose_stand = torch.tensor([
10
- 8.9759e-04, 7.1074e-04, -5.9163e-06, 8.9759e-04, 7.1074e-04, -5.9163e-06,
11
- 3.0747, -0.0158, -0.0152,
12
- -3.6665e-01, -8.8455e-03, 1.6113e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
13
- -3.9716e-01, -4.0229e-02, -1.2637e-01,
14
- 7.9163e-01, 6.8519e-02, -1.5091e-01, 7.9163e-01, 6.8519e-02, -1.5091e-01,
15
- 7.8632e-01, -4.3810e-02, 1.4375e-02,
16
- -1.0675e-01, 1.2635e-01, 1.6711e-02, -1.0675e-01, 1.2635e-01, 1.6711e-02, ])
17
- # lower_pose_stand = torch.tensor(
18
- # [6.4919e-02, 3.3018e-02, 1.7485e-02, 8.9759e-04, 7.1074e-04, -5.9163e-06,
19
- # 3.0747, -0.0158, -0.0152,
20
- # -3.3633e+00, -9.3915e-02, 3.0996e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
21
- # 1.1654603481292725, 0.0, 0.0,
22
- # 4.4167e-01, 6.7183e-03, -3.6379e-03, 7.9163e-01, 6.8519e-02, -1.5091e-01,
23
- # 0.0, 0.0, 0.0,
24
- # 2.2910e-02, -2.4797e-02, -5.5657e-03, -1.0675e-01, 1.2635e-01, 1.6711e-02,])
25
- lower_body = [0, 1, 3, 4, 6, 7, 9, 10]
26
- count_part = [6, 9, 12, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
27
- 29, 30, 31, 32, 33, 34, 35, 36, 37,
28
- 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54]
29
- fix_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
30
- 29,
31
- 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
32
- 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
33
- 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]
34
- all_index = np.ones(275)
35
- all_index[fix_index] = 0
36
- c_index = []
37
- i = 0
38
- for num in all_index:
39
- if num == 1:
40
- c_index.append(i)
41
- i = i + 1
42
- c_index = np.asarray(c_index)
43
-
44
- fix_index_3d = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
45
- 21, 22, 23, 24, 25, 26,
46
- 30, 31, 32, 33, 34, 35,
47
- 45, 46, 47, 48, 49, 50]
48
- all_index_3d = np.ones(165)
49
- all_index_3d[fix_index_3d] = 0
50
- c_index_3d = []
51
- i = 0
52
- for num in all_index_3d:
53
- if num == 1:
54
- c_index_3d.append(i)
55
- i = i + 1
56
- c_index_3d = np.asarray(c_index_3d)
57
-
58
- c_index_6d = []
59
- i = 0
60
- for num in all_index_3d:
61
- if num == 1:
62
- c_index_6d.append(2*i)
63
- c_index_6d.append(2 * i + 1)
64
- i = i + 1
65
- c_index_6d = np.asarray(c_index_6d)
66
-
67
-
68
- def part2full(input, stand=False):
69
- if stand:
70
- # lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
71
- lp = torch.zeros_like(lower_pose)
72
- lp[6:9] = torch.tensor([3.0747, -0.0158, -0.0152])
73
- lp = lp.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
74
- else:
75
- lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
76
-
77
- input = torch.cat([input[:, :3],
78
- lp[:, :15],
79
- input[:, 3:6],
80
- lp[:, 15:21],
81
- input[:, 6:9],
82
- lp[:, 21:27],
83
- input[:, 9:12],
84
- lp[:, 27:],
85
- input[:, 12:]]
86
- , dim=1)
87
- return input
88
-
89
-
90
- def pred2poses(input, gt):
91
- input = torch.cat([input[:, :3],
92
- gt[0:1, 3:18].repeat(input.shape[0], 1),
93
- input[:, 3:6],
94
- gt[0:1, 21:27].repeat(input.shape[0], 1),
95
- input[:, 6:9],
96
- gt[0:1, 30:36].repeat(input.shape[0], 1),
97
- input[:, 9:12],
98
- gt[0:1, 39:45].repeat(input.shape[0], 1),
99
- input[:, 12:]]
100
- , dim=1)
101
- return input
102
-
103
-
104
- def poses2poses(input, gt):
105
- input = torch.cat([input[:, :3],
106
- gt[0:1, 3:18].repeat(input.shape[0], 1),
107
- input[:, 18:21],
108
- gt[0:1, 21:27].repeat(input.shape[0], 1),
109
- input[:, 27:30],
110
- gt[0:1, 30:36].repeat(input.shape[0], 1),
111
- input[:, 36:39],
112
- gt[0:1, 39:45].repeat(input.shape[0], 1),
113
- input[:, 45:]]
114
- , dim=1)
115
- return input
116
-
117
- def poses2pred(input, stand=False):
118
- if stand:
119
- lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
120
- # lp = torch.zeros_like(lower_pose).unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
121
- else:
122
- lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
123
- input = torch.cat([input[:, :3],
124
- lp[:, :15],
125
- input[:, 18:21],
126
- lp[:, 15:21],
127
- input[:, 27:30],
128
- lp[:, 21:27],
129
- input[:, 36:39],
130
- lp[:, 27:],
131
- input[:, 45:]]
132
- , dim=1)
133
- return input
134
-
135
-
136
- rearrange = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]\
137
- # ,22, 23, 24, 25, 40, 26, 41,
138
- # 27, 42, 28, 43, 29, 44, 30, 45, 31, 46, 32, 47, 33, 48, 34, 49, 35, 50, 36, 51, 37, 52, 38, 53, 39, 54, 55,
139
- # 57, 56, 59, 58, 60, 63, 61, 64, 62, 65, 66, 71, 67, 72, 68, 73, 69, 74, 70, 75]
140
-
141
- symmetry = [0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1]#, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
142
- # 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
143
- # 1, 1, 1, 1, 1, 1]
 
 
data_utils/mesh_dataset.py DELETED
@@ -1,348 +0,0 @@
1
- import pickle
2
- import sys
3
- import os
4
-
5
- sys.path.append(os.getcwd())
6
-
7
- import json
8
- from glob import glob
9
- from data_utils.utils import *
10
- import torch.utils.data as data
11
- from data_utils.consts import speaker_id
12
- from data_utils.lower_body import count_part
13
- import random
14
- from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
15
-
16
- with open('data_utils/hand_component.json') as file_obj:
17
- comp = json.load(file_obj)
18
- left_hand_c = np.asarray(comp['left'])
19
- right_hand_c = np.asarray(comp['right'])
20
-
21
-
22
- def to3d(data):
23
- left_hand_pose = np.einsum('bi,ij->bj', data[:, 75:87], left_hand_c[:12, :])
24
- right_hand_pose = np.einsum('bi,ij->bj', data[:, 87:99], right_hand_c[:12, :])
25
- data = np.concatenate((data[:, :75], left_hand_pose, right_hand_pose), axis=-1)
26
- return data
27
-
28
-
29
- class SmplxDataset():
30
- '''
31
- create a dataset for every segment and concatenate.
32
- '''
33
-
34
- def __init__(self,
35
- data_root,
36
- speaker,
37
- motion_fn,
38
- audio_fn,
39
- audio_sr,
40
- fps,
41
- feat_method='mel_spec',
42
- audio_feat_dim=64,
43
- audio_feat_win_size=None,
44
-
45
- train=True,
46
- load_all=False,
47
- split_trans_zero=False,
48
- limbscaling=False,
49
- num_frames=25,
50
- num_pre_frames=25,
51
- num_generate_length=25,
52
- context_info=False,
53
- convert_to_6d=False,
54
- expression=False,
55
- config=None,
56
- am=None,
57
- am_sr=None,
58
- whole_video=False
59
- ):
60
-
61
- self.data_root = data_root
62
- self.speaker = speaker
63
-
64
- self.feat_method = feat_method
65
- self.audio_fn = audio_fn
66
- self.audio_sr = audio_sr
67
- self.fps = fps
68
- self.audio_feat_dim = audio_feat_dim
69
- self.audio_feat_win_size = audio_feat_win_size
70
- self.context_info = context_info # for aud feat
71
- self.convert_to_6d = convert_to_6d
72
- self.expression = expression
73
-
74
- self.train = train
75
- self.load_all = load_all
76
- self.split_trans_zero = split_trans_zero
77
- self.limbscaling = limbscaling
78
- self.num_frames = num_frames
79
- self.num_pre_frames = num_pre_frames
80
- self.num_generate_length = num_generate_length
81
- # print('num_generate_length ', self.num_generate_length)
82
-
83
- self.config = config
84
- self.am_sr = am_sr
85
- self.whole_video = whole_video
86
- load_mode = self.config.dataset_load_mode
87
-
88
- if load_mode == 'pickle':
89
- raise NotImplementedError
90
-
91
- elif load_mode == 'csv':
92
- import pickle
93
- with open(data_root, 'rb') as f:
94
- u = pickle._Unpickler(f)
95
- data = u.load()
96
- self.data = data[0]
97
- if self.load_all:
98
- self._load_npz_all()
99
-
100
- elif load_mode == 'json':
101
- self.annotations = glob(data_root + '/*pkl')
102
- if len(self.annotations) == 0:
103
- raise FileNotFoundError(data_root + ' are empty')
104
- self.annotations = sorted(self.annotations)
105
- self.img_name_list = self.annotations
106
-
107
- if self.load_all:
108
- self._load_them_all(am, am_sr, motion_fn)
109
-
110
- def _load_npz_all(self):
111
- self.loaded_data = {}
112
- self.complete_data = []
113
- data = self.data
114
- shape = data['body_pose_axis'].shape[0]
115
- self.betas = data['betas']
116
- self.img_name_list = []
117
- for index in range(shape):
118
- img_name = f'{index:6d}'
119
- self.img_name_list.append(img_name)
120
-
121
- jaw_pose = data['jaw_pose'][index]
122
- leye_pose = data['leye_pose'][index]
123
- reye_pose = data['reye_pose'][index]
124
- global_orient = data['global_orient'][index]
125
- body_pose = data['body_pose_axis'][index]
126
- left_hand_pose = data['left_hand_pose'][index]
127
- right_hand_pose = data['right_hand_pose'][index]
128
-
129
- full_body = np.concatenate(
130
- (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose))
131
- assert full_body.shape[0] == 99
132
- if self.convert_to_6d:
133
- full_body = to3d(full_body)
134
- full_body = torch.from_numpy(full_body)
135
- full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body))
136
- full_body = np.asarray(full_body)
137
- if self.expression:
138
- expression = data['expression'][index]
139
- full_body = np.concatenate((full_body, expression))
140
- # full_body = np.concatenate((full_body, non_zero))
141
- else:
142
- full_body = to3d(full_body)
143
- if self.expression:
144
- expression = data['expression'][index]
145
- full_body = np.concatenate((full_body, expression))
146
-
147
- self.loaded_data[img_name] = full_body.reshape(-1)
148
- self.complete_data.append(full_body.reshape(-1))
149
-
150
- self.complete_data = np.array(self.complete_data)
151
-
152
- if self.audio_feat_win_size is not None:
153
- self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
154
- # print(self.audio_feat.shape)
155
- else:
156
- if self.feat_method == 'mel_spec':
157
- self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
158
- elif self.feat_method == 'mfcc':
159
- self.audio_feat = get_mfcc(self.audio_fn,
160
- smlpx=True,
161
- sr=self.audio_sr,
162
- n_mfcc=self.audio_feat_dim,
163
- win_size=self.audio_feat_win_size
164
- )
165
-
166
- def _load_them_all(self, am, am_sr, motion_fn):
167
- self.loaded_data = {}
168
- self.complete_data = []
169
- f = open(motion_fn, 'rb+')
170
- data = pickle.load(f)
171
-
172
- self.betas = np.array(data['betas'])
173
-
174
- jaw_pose = np.array(data['jaw_pose'])
175
- leye_pose = np.array(data['leye_pose'])
176
- reye_pose = np.array(data['reye_pose'])
177
- global_orient = np.array(data['global_orient']).squeeze()
178
- body_pose = np.array(data['body_pose_axis'])
179
- left_hand_pose = np.array(data['left_hand_pose'])
180
- right_hand_pose = np.array(data['right_hand_pose'])
181
-
182
- full_body = np.concatenate(
183
- (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
184
- assert full_body.shape[1] == 99
185
-
186
-
187
- if self.convert_to_6d:
188
- full_body = to3d(full_body)
189
- full_body = torch.from_numpy(full_body)
190
- full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body.reshape(-1, 55, 3))).reshape(-1, 330)
191
- full_body = np.asarray(full_body)
192
- if self.expression:
193
- expression = np.array(data['expression'])
194
- full_body = np.concatenate((full_body, expression), axis=1)
195
-
196
- else:
197
- full_body = to3d(full_body)
198
- expression = np.array(data['expression'])
199
- full_body = np.concatenate((full_body, expression), axis=1)
200
-
201
- self.complete_data = full_body
202
- self.complete_data = np.array(self.complete_data)
203
-
204
- if self.audio_feat_win_size is not None:
205
- self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
206
- else:
207
- # if self.feat_method == 'mel_spec':
208
- # self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
209
- # elif self.feat_method == 'mfcc':
210
- self.audio_feat = get_mfcc_ta(self.audio_fn,
211
- smlpx=True,
212
- fps=30,
213
- sr=self.audio_sr,
214
- n_mfcc=self.audio_feat_dim,
215
- win_size=self.audio_feat_win_size,
216
- type=self.feat_method,
217
- am=am,
218
- am_sr=am_sr,
219
- encoder_choice=self.config.Model.encoder_choice,
220
- )
221
- # with open(audio_file, 'w', encoding='utf-8') as file:
222
- # file.write(json.dumps(self.audio_feat.__array__().tolist(), indent=0, ensure_ascii=False))
223
-
224
- def get_dataset(self, normalization=False, normalize_stats=None, split='train'):
225
-
226
- class __Worker__(data.Dataset):
227
- def __init__(child, index_list, normalization, normalize_stats, split='train') -> None:
228
- super().__init__()
229
- child.index_list = index_list
230
- child.normalization = normalization
231
- child.normalize_stats = normalize_stats
232
- child.split = split
233
-
234
- def __getitem__(child, index):
235
- num_generate_length = self.num_generate_length
236
- num_pre_frames = self.num_pre_frames
237
- seq_len = num_generate_length + num_pre_frames
238
- # print(num_generate_length)
239
-
240
- index = child.index_list[index]
241
- index_new = index + random.randrange(0, 5, 3)
242
- if index_new + seq_len > self.complete_data.shape[0]:
243
- index_new = index
244
- index = index_new
245
-
246
- if child.split in ['val', 'pre', 'test'] or self.whole_video:
247
- index = 0
248
- seq_len = self.complete_data.shape[0]
249
- seq_data = []
250
- assert index + seq_len <= self.complete_data.shape[0]
251
- # print(seq_len)
252
- seq_data = self.complete_data[index:(index + seq_len), :]
253
- seq_data = np.array(seq_data)
254
-
255
- '''
256
- audio feature,
257
- '''
258
- if not self.context_info:
259
- if not self.whole_video:
260
- audio_feat = self.audio_feat[index:index + seq_len, ...]
261
- if audio_feat.shape[0] < seq_len:
262
- audio_feat = np.pad(audio_feat, [[0, seq_len - audio_feat.shape[0]], [0, 0]],
263
- mode='reflect')
264
-
265
- assert audio_feat.shape[0] == seq_len and audio_feat.shape[1] == self.audio_feat_dim
266
- else:
267
- audio_feat = self.audio_feat
268
-
269
- else: # including feature and history
270
- if self.audio_feat_win_size is None:
271
- audio_feat = self.audio_feat[index:index + seq_len + num_pre_frames, ...]
272
- if audio_feat.shape[0] < seq_len + num_pre_frames:
273
- audio_feat = np.pad(audio_feat,
274
- [[0, seq_len + self.num_frames - audio_feat.shape[0]], [0, 0]],
275
- mode='constant')
276
-
277
- assert audio_feat.shape[0] == self.num_frames + seq_len and audio_feat.shape[
278
- 1] == self.audio_feat_dim
279
-
280
- if child.normalization:
281
- data_mean = child.normalize_stats['mean'].reshape(1, -1)
282
- data_std = child.normalize_stats['std'].reshape(1, -1)
283
- seq_data[:, :330] = (seq_data[:, :330] - data_mean) / data_std
284
- if child.split in['train', 'test']:
285
- if self.convert_to_6d:
286
- if self.expression:
287
- data_sample = {
288
- 'poses': seq_data[:, :330].astype(np.float).transpose(1, 0),
289
- 'expression': seq_data[:, 330:].astype(np.float).transpose(1, 0),
290
- # 'nzero': seq_data[:, 375:].astype(np.float).transpose(1, 0),
291
- 'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
292
- 'speaker': speaker_id[self.speaker],
293
- 'betas': self.betas,
294
- 'aud_file': self.audio_fn,
295
- }
296
- else:
297
- data_sample = {
298
- 'poses': seq_data[:, :330].astype(np.float).transpose(1, 0),
299
- 'nzero': seq_data[:, 330:].astype(np.float).transpose(1, 0),
300
- 'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
301
- 'speaker': speaker_id[self.speaker],
302
- 'betas': self.betas
303
- }
304
- else:
305
- if self.expression:
306
- data_sample = {
307
- 'poses': seq_data[:, :165].astype(np.float).transpose(1, 0),
308
- 'expression': seq_data[:, 165:].astype(np.float).transpose(1, 0),
309
- 'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
310
- # 'wv2_feat': wv2_feat.astype(np.float).transpose(1, 0),
311
- 'speaker': speaker_id[self.speaker],
312
- 'aud_file': self.audio_fn,
313
- 'betas': self.betas
314
- }
315
- else:
316
- data_sample = {
317
- 'poses': seq_data.astype(np.float).transpose(1, 0),
318
- 'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
319
- 'speaker': speaker_id[self.speaker],
320
- 'betas': self.betas
321
- }
322
- return data_sample
323
- else:
324
- data_sample = {
325
- 'poses': seq_data[:, :330].astype(np.float).transpose(1, 0),
326
- 'expression': seq_data[:, 330:].astype(np.float).transpose(1, 0),
327
- # 'nzero': seq_data[:, 325:].astype(np.float).transpose(1, 0),
328
- 'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
329
- 'aud_file': self.audio_fn,
330
- 'speaker': speaker_id[self.speaker],
331
- 'betas': self.betas
332
- }
333
- return data_sample
334
- def __len__(child):
335
- return len(child.index_list)
336
-
337
- if split == 'train':
338
- index_list = list(
339
- range(0, min(self.complete_data.shape[0], self.audio_feat.shape[0]) - self.num_generate_length - self.num_pre_frames,
340
- 6))
341
- elif split in ['val', 'test']:
342
- index_list = list([0])
343
- if self.whole_video:
344
- index_list = list([0])
345
- self.all_dataset = __Worker__(index_list, normalization, normalize_stats, split)
346
-
347
- def __len__(self):
348
- return len(self.img_name_list)
 
 
data_utils/rotation_conversion.py DELETED
@@ -1,551 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
- # Check PYTORCH3D_LICENCE before use
3
-
4
- import functools
5
- from typing import Optional
6
-
7
- import torch
8
- import torch.nn.functional as F
9
-
10
-
11
- """
12
- The transformation matrices returned from the functions in this file assume
13
- the points on which the transformation will be applied are column vectors.
14
- i.e. the R matrix is structured as
15
-
16
- R = [
17
- [Rxx, Rxy, Rxz],
18
- [Ryx, Ryy, Ryz],
19
- [Rzx, Rzy, Rzz],
20
- ] # (3, 3)
21
-
22
- This matrix can be applied to column vectors by post multiplication
23
- by the points e.g.
24
-
25
- points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
26
- transformed_points = R * points
27
-
28
- To apply the same matrix to points which are row vectors, the R matrix
29
- can be transposed and pre multiplied by the points:
30
-
31
- e.g.
32
- points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
33
- transformed_points = points * R.transpose(1, 0)
34
- """
35
-
36
-
37
- def quaternion_to_matrix(quaternions):
38
- """
39
- Convert rotations given as quaternions to rotation matrices.
40
-
41
- Args:
42
- quaternions: quaternions with real part first,
43
- as tensor of shape (..., 4).
44
-
45
- Returns:
46
- Rotation matrices as tensor of shape (..., 3, 3).
47
- """
48
- r, i, j, k = torch.unbind(quaternions, -1)
49
- two_s = 2.0 / (quaternions * quaternions).sum(-1)
50
-
51
- o = torch.stack(
52
- (
53
- 1 - two_s * (j * j + k * k),
54
- two_s * (i * j - k * r),
55
- two_s * (i * k + j * r),
56
- two_s * (i * j + k * r),
57
- 1 - two_s * (i * i + k * k),
58
- two_s * (j * k - i * r),
59
- two_s * (i * k - j * r),
60
- two_s * (j * k + i * r),
61
- 1 - two_s * (i * i + j * j),
62
- ),
63
- -1,
64
- )
65
- return o.reshape(quaternions.shape[:-1] + (3, 3))
66
-
67
-
68
- def _copysign(a, b):
69
- """
70
- Return a tensor where each element has the absolute value taken from the,
71
- corresponding element of a, with sign taken from the corresponding
72
- element of b. This is like the standard copysign floating-point operation,
73
- but is not careful about negative 0 and NaN.
74
-
75
- Args:
76
- a: source tensor.
77
- b: tensor whose signs will be used, of the same shape as a.
78
-
79
- Returns:
80
- Tensor of the same shape as a with the signs of b.
81
- """
82
- signs_differ = (a < 0) != (b < 0)
83
- return torch.where(signs_differ, -a, a)
84
-
85
-
86
- def _sqrt_positive_part(x):
87
- """
88
- Returns torch.sqrt(torch.max(0, x))
89
- but with a zero subgradient where x is 0.
90
- """
91
- ret = torch.zeros_like(x)
92
- positive_mask = x > 0
93
- ret[positive_mask] = torch.sqrt(x[positive_mask])
94
- return ret
95
-
96
-
97
- def matrix_to_quaternion(matrix):
98
- """
99
- Convert rotations given as rotation matrices to quaternions.
100
-
101
- Args:
102
- matrix: Rotation matrices as tensor of shape (..., 3, 3).
103
-
104
- Returns:
105
- quaternions with real part first, as tensor of shape (..., 4).
106
- """
107
- if matrix.size(-1) != 3 or matrix.size(-2) != 3:
108
- raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
109
- m00 = matrix[..., 0, 0]
110
- m11 = matrix[..., 1, 1]
111
- m22 = matrix[..., 2, 2]
112
- o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
113
- x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
114
- y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
115
- z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
116
- o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
117
- o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
118
- o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
119
- return torch.stack((o0, o1, o2, o3), -1)
120
-
121
-
122
- def _axis_angle_rotation(axis: str, angle):
123
- """
124
- Return the rotation matrices for one of the rotations about an axis
125
- of which Euler angles describe, for each value of the angle given.
126
-
127
- Args:
128
- axis: Axis label "X" or "Y or "Z".
129
- angle: any shape tensor of Euler angles in radians
130
-
131
- Returns:
132
- Rotation matrices as tensor of shape (..., 3, 3).
133
- """
134
-
135
- cos = torch.cos(angle)
136
- sin = torch.sin(angle)
137
- one = torch.ones_like(angle)
138
- zero = torch.zeros_like(angle)
139
-
140
- if axis == "X":
141
- R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
142
- if axis == "Y":
143
- R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
144
- if axis == "Z":
145
- R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
146
-
147
- return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
148
-
149
-
150
- def euler_angles_to_matrix(euler_angles, convention: str):
151
- """
152
- Convert rotations given as Euler angles in radians to rotation matrices.
153
-
154
- Args:
155
- euler_angles: Euler angles in radians as tensor of shape (..., 3).
156
- convention: Convention string of three uppercase letters from
157
- {"X", "Y", and "Z"}.
158
-
159
- Returns:
160
- Rotation matrices as tensor of shape (..., 3, 3).
161
- """
162
- if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
163
- raise ValueError("Invalid input euler angles.")
164
- if len(convention) != 3:
165
- raise ValueError("Convention must have 3 letters.")
166
- if convention[1] in (convention[0], convention[2]):
167
- raise ValueError(f"Invalid convention {convention}.")
168
- for letter in convention:
169
- if letter not in ("X", "Y", "Z"):
170
- raise ValueError(f"Invalid letter {letter} in convention string.")
171
- matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
172
- return functools.reduce(torch.matmul, matrices)
173
-
174
-
175
- def _angle_from_tan(
176
- axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
177
- ):
178
- """
179
- Extract the first or third Euler angle from the two members of
180
- the matrix which are positive constant times its sine and cosine.
181
-
182
- Args:
183
- axis: Axis label "X" or "Y or "Z" for the angle we are finding.
184
- other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
185
- convention.
186
- data: Rotation matrices as tensor of shape (..., 3, 3).
187
- horizontal: Whether we are looking for the angle for the third axis,
188
- which means the relevant entries are in the same row of the
189
- rotation matrix. If not, they are in the same column.
190
- tait_bryan: Whether the first and third axes in the convention differ.
191
-
192
- Returns:
193
- Euler Angles in radians for each matrix in data as a tensor
194
- of shape (...).
195
- """
196
-
197
- i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
198
- if horizontal:
199
- i2, i1 = i1, i2
200
- even = (axis + other_axis) in ["XY", "YZ", "ZX"]
201
- if horizontal == even:
202
- return torch.atan2(data[..., i1], data[..., i2])
203
- if tait_bryan:
204
- return torch.atan2(-data[..., i2], data[..., i1])
205
- return torch.atan2(data[..., i2], -data[..., i1])
206
-
207
-
208
- def _index_from_letter(letter: str):
209
- if letter == "X":
210
- return 0
211
- if letter == "Y":
212
- return 1
213
- if letter == "Z":
214
- return 2
215
-
216
-
217
- def matrix_to_euler_angles(matrix, convention: str):
218
- """
219
- Convert rotations given as rotation matrices to Euler angles in radians.
220
-
221
- Args:
222
- matrix: Rotation matrices as tensor of shape (..., 3, 3).
223
- convention: Convention string of three uppercase letters.
224
-
225
- Returns:
226
- Euler angles in radians as tensor of shape (..., 3).
227
- """
228
- if len(convention) != 3:
229
- raise ValueError("Convention must have 3 letters.")
230
- if convention[1] in (convention[0], convention[2]):
231
- raise ValueError(f"Invalid convention {convention}.")
232
- for letter in convention:
233
- if letter not in ("X", "Y", "Z"):
234
- raise ValueError(f"Invalid letter {letter} in convention string.")
235
- if matrix.size(-1) != 3 or matrix.size(-2) != 3:
236
- raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
237
- i0 = _index_from_letter(convention[0])
238
- i2 = _index_from_letter(convention[2])
239
- tait_bryan = i0 != i2
240
- if tait_bryan:
241
- central_angle = torch.asin(
242
- matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
243
- )
244
- else:
245
- central_angle = torch.acos(matrix[..., i0, i0])
246
-
247
- o = (
248
- _angle_from_tan(
249
- convention[0], convention[1], matrix[..., i2], False, tait_bryan
250
- ),
251
- central_angle,
252
- _angle_from_tan(
253
- convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
254
- ),
255
- )
256
- return torch.stack(o, -1)
257
-
258
-
259
- def random_quaternions(
260
- n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
261
- ):
262
- """
263
- Generate random quaternions representing rotations,
264
- i.e. versors with nonnegative real part.
265
-
266
- Args:
267
- n: Number of quaternions in a batch to return.
268
- dtype: Type to return.
269
- device: Desired device of returned tensor. Default:
270
- uses the current device for the default tensor type.
271
- requires_grad: Whether the resulting tensor should have the gradient
272
- flag set.
273
-
274
- Returns:
275
- Quaternions as tensor of shape (N, 4).
276
- """
277
- o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
278
- s = (o * o).sum(1)
279
- o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
280
- return o
281
-
282
-
283
- def random_rotations(
284
- n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
285
- ):
286
- """
287
- Generate random rotations as 3x3 rotation matrices.
288
-
289
- Args:
290
- n: Number of rotation matrices in a batch to return.
291
- dtype: Type to return.
292
- device: Device of returned tensor. Default: if None,
293
- uses the current device for the default tensor type.
294
- requires_grad: Whether the resulting tensor should have the gradient
295
- flag set.
296
-
297
- Returns:
298
- Rotation matrices as tensor of shape (n, 3, 3).
299
- """
300
- quaternions = random_quaternions(
301
- n, dtype=dtype, device=device, requires_grad=requires_grad
302
- )
303
- return quaternion_to_matrix(quaternions)
304
-
305
-
306
- def random_rotation(
307
- dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
308
- ):
309
- """
310
- Generate a single random 3x3 rotation matrix.
311
-
312
- Args:
313
- dtype: Type to return
314
- device: Device of returned tensor. Default: if None,
315
- uses the current device for the default tensor type
316
- requires_grad: Whether the resulting tensor should have the gradient
317
- flag set
318
-
319
- Returns:
320
- Rotation matrix as tensor of shape (3, 3).
321
- """
322
- return random_rotations(1, dtype, device, requires_grad)[0]
323
-
324
-
325
- def standardize_quaternion(quaternions):
326
- """
327
- Convert a unit quaternion to a standard form: one in which the real
328
- part is non negative.
329
-
330
- Args:
331
- quaternions: Quaternions with real part first,
332
- as tensor of shape (..., 4).
333
-
334
- Returns:
335
- Standardized quaternions as tensor of shape (..., 4).
336
- """
337
- return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
338
-
339
-
340
- def quaternion_raw_multiply(a, b):
341
- """
342
- Multiply two quaternions.
343
- Usual torch rules for broadcasting apply.
344
-
345
- Args:
346
- a: Quaternions as tensor of shape (..., 4), real part first.
347
- b: Quaternions as tensor of shape (..., 4), real part first.
348
-
349
- Returns:
350
- The product of a and b, a tensor of quaternions shape (..., 4).
351
- """
352
- aw, ax, ay, az = torch.unbind(a, -1)
353
- bw, bx, by, bz = torch.unbind(b, -1)
354
- ow = aw * bw - ax * bx - ay * by - az * bz
355
- ox = aw * bx + ax * bw + ay * bz - az * by
356
- oy = aw * by - ax * bz + ay * bw + az * bx
357
- oz = aw * bz + ax * by - ay * bx + az * bw
358
- return torch.stack((ow, ox, oy, oz), -1)
359
-
360
-
361
- def quaternion_multiply(a, b):
362
- """
363
- Multiply two quaternions representing rotations, returning the quaternion
364
- representing their composition, i.e. the versor with nonnegative real part.
365
- Usual torch rules for broadcasting apply.
366
-
367
- Args:
368
- a: Quaternions as tensor of shape (..., 4), real part first.
369
- b: Quaternions as tensor of shape (..., 4), real part first.
370
-
371
- Returns:
372
- The product of a and b, a tensor of quaternions of shape (..., 4).
373
- """
374
- ab = quaternion_raw_multiply(a, b)
375
- return standardize_quaternion(ab)
376
-
377
-
378
- def quaternion_invert(quaternion):
379
- """
380
- Given a quaternion representing rotation, get the quaternion representing
381
- its inverse.
382
-
383
- Args:
384
- quaternion: Quaternions as tensor of shape (..., 4), with real part
385
- first, which must be versors (unit quaternions).
386
-
387
- Returns:
388
- The inverse, a tensor of quaternions of shape (..., 4).
389
- """
390
-
391
- return quaternion * quaternion.new_tensor([1, -1, -1, -1])
392
-
393
-
394
- def quaternion_apply(quaternion, point):
395
- """
396
- Apply the rotation given by a quaternion to a 3D point.
397
- Usual torch rules for broadcasting apply.
398
-
399
- Args:
400
- quaternion: Tensor of quaternions, real part first, of shape (..., 4).
401
- point: Tensor of 3D points of shape (..., 3).
402
-
403
- Returns:
404
- Tensor of rotated points of shape (..., 3).
405
- """
406
- if point.size(-1) != 3:
407
- raise ValueError(f"Points are not in 3D, f{point.shape}.")
408
- real_parts = point.new_zeros(point.shape[:-1] + (1,))
409
- point_as_quaternion = torch.cat((real_parts, point), -1)
410
- out = quaternion_raw_multiply(
411
- quaternion_raw_multiply(quaternion, point_as_quaternion),
412
- quaternion_invert(quaternion),
413
- )
414
- return out[..., 1:]
415
-
416
-
417
- def axis_angle_to_matrix(axis_angle):
418
- """
419
- Convert rotations given as axis/angle to rotation matrices.
420
-
421
- Args:
422
- axis_angle: Rotations given as a vector in axis angle form,
423
- as a tensor of shape (..., 3), where the magnitude is
424
- the angle turned anticlockwise in radians around the
425
- vector's direction.
426
-
427
- Returns:
428
- Rotation matrices as tensor of shape (..., 3, 3).
429
- """
430
- return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
431
-
432
-
433
- def matrix_to_axis_angle(matrix):
434
- """
435
- Convert rotations given as rotation matrices to axis/angle.
436
-
437
- Args:
438
- matrix: Rotation matrices as tensor of shape (..., 3, 3).
439
-
440
- Returns:
441
- Rotations given as a vector in axis angle form, as a tensor
442
- of shape (..., 3), where the magnitude is the angle
443
- turned anticlockwise in radians around the vector's
444
- direction.
445
- """
446
- return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
447
-
448
-
449
- def axis_angle_to_quaternion(axis_angle):
450
- """
451
- Convert rotations given as axis/angle to quaternions.
452
-
453
- Args:
454
- axis_angle: Rotations given as a vector in axis angle form,
455
- as a tensor of shape (..., 3), where the magnitude is
456
- the angle turned anticlockwise in radians around the
457
- vector's direction.
458
-
459
- Returns:
460
- quaternions with real part first, as tensor of shape (..., 4).
461
- """
462
- angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
463
- half_angles = 0.5 * angles
464
- eps = 1e-6
465
- small_angles = angles.abs() < eps
466
- sin_half_angles_over_angles = torch.empty_like(angles)
467
- sin_half_angles_over_angles[~small_angles] = (
468
- torch.sin(half_angles[~small_angles]) / angles[~small_angles]
469
- )
470
- # for x small, sin(x/2) is about x/2 - (x/2)^3/6
471
- # so sin(x/2)/x is about 1/2 - (x*x)/48
472
- sin_half_angles_over_angles[small_angles] = (
473
- 0.5 - (angles[small_angles] * angles[small_angles]) / 48
474
- )
475
- quaternions = torch.cat(
476
- [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
477
- )
478
- return quaternions
479
-
480
-
481
- def quaternion_to_axis_angle(quaternions):
482
- """
483
- Convert rotations given as quaternions to axis/angle.
484
-
485
- Args:
486
- quaternions: quaternions with real part first,
487
- as tensor of shape (..., 4).
488
-
489
- Returns:
490
- Rotations given as a vector in axis angle form, as a tensor
491
- of shape (..., 3), where the magnitude is the angle
492
- turned anticlockwise in radians around the vector's
493
- direction.
494
- """
495
- norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
496
- half_angles = torch.atan2(norms, quaternions[..., :1])
497
- angles = 2 * half_angles
498
- eps = 1e-6
499
- small_angles = angles.abs() < eps
500
- sin_half_angles_over_angles = torch.empty_like(angles)
501
- sin_half_angles_over_angles[~small_angles] = (
502
- torch.sin(half_angles[~small_angles]) / angles[~small_angles]
503
- )
504
- # for x small, sin(x/2) is about x/2 - (x/2)^3/6
505
- # so sin(x/2)/x is about 1/2 - (x*x)/48
506
- sin_half_angles_over_angles[small_angles] = (
507
- 0.5 - (angles[small_angles] * angles[small_angles]) / 48
508
- )
509
- return quaternions[..., 1:] / sin_half_angles_over_angles
510
-
511
-
512
- def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
513
- """
514
- Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
515
- using Gram--Schmidt orthogonalisation per Section B of [1].
516
- Args:
517
- d6: 6D rotation representation, of size (*, 6)
518
-
519
- Returns:
520
- batch of rotation matrices of size (*, 3, 3)
521
-
522
- [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
523
- On the Continuity of Rotation Representations in Neural Networks.
524
- IEEE Conference on Computer Vision and Pattern Recognition, 2019.
525
- Retrieved from http://arxiv.org/abs/1812.07035
526
- """
527
-
528
- a1, a2 = d6[..., :3], d6[..., 3:]
529
- b1 = F.normalize(a1, dim=-1)
530
- b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
531
- b2 = F.normalize(b2, dim=-1)
532
- b3 = torch.cross(b1, b2, dim=-1)
533
- return torch.stack((b1, b2, b3), dim=-2)
534
-
535
-
536
- def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
537
- """
538
- Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
539
- by dropping the last row. Note that 6D representation is not unique.
540
- Args:
541
- matrix: batch of rotation matrices of size (*, 3, 3)
542
-
543
- Returns:
544
- 6D rotation representation, of size (*, 6)
545
-
546
- [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
547
- On the Continuity of Rotation Representations in Neural Networks.
548
- IEEE Conference on Computer Vision and Pattern Recognition, 2019.
549
- Retrieved from http://arxiv.org/abs/1812.07035
550
- """
551
- return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
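
A short usage sketch, not part of the deleted file, of the converters above (taken from PyTorch3D): they move between axis-angle rotations, rotation matrices, quaternions and the continuous 6D representation of Zhou et al. The batch size and joint count below are illustrative, and the import path refers to the module as it existed before this deletion.

```python
# Round-trip check for the rotation helpers above (illustrative shapes).
import torch
from data_utils.rotation_conversion import (
    axis_angle_to_matrix, matrix_to_axis_angle,
    matrix_to_rotation_6d, rotation_6d_to_matrix)

aa = 0.3 * torch.randn(8, 55, 3)              # batch of 55 joints in axis-angle
rot = axis_angle_to_matrix(aa)                # (8, 55, 3, 3)
d6 = matrix_to_rotation_6d(rot)               # (8, 55, 6) continuous representation
rot_back = rotation_6d_to_matrix(d6)          # Gram-Schmidt back to matrices
aa_back = matrix_to_axis_angle(rot_back)      # back to axis-angle
print(torch.allclose(rot, rot_back, atol=1e-5))                       # True
print(torch.allclose(axis_angle_to_matrix(aa_back), rot, atol=1e-4))  # True
```
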
data_utils/split_train_val_test.py DELETED
@@ -1,27 +0,0 @@
1
- import os
2
- import json
3
- import shutil
4
-
5
- if __name__ =='__main__':
6
- id_list = "chemistry conan oliver seth"
7
- id_list = id_list.split(' ')
8
-
9
- old_root = '/home/usename/talkshow_data/ExpressiveWholeBodyDatasetReleaseV1.0'
10
- new_root = '/home/usename/talkshow_data/ExpressiveWholeBodyDatasetReleaseV1.0/talkshow_data_splited'
11
-
12
- with open('train_val_test.json') as f:
13
- split_info = json.load(f)
14
- phase_list = ['train', 'val', 'test']
15
- for phase in phase_list:
16
- phase_path_list = split_info[phase]
17
- for p in phase_path_list:
18
- old_path = os.path.join(old_root, p)
19
- if not os.path.exists(old_path):
20
- print(f'{old_path} not found, continue' )
21
- continue
22
- new_path = os.path.join(new_root, phase, p)
23
- dir_name = os.path.dirname(new_path)
24
- if not os.path.isdir(dir_name):
25
- os.makedirs(dir_name, exist_ok=True)
26
- shutil.move(old_path, new_path)
27
-
data_utils/train_val_test.json DELETED
The diff for this file is too large to render. See raw diff
 
data_utils/utils.py DELETED
@@ -1,318 +0,0 @@
1
- import numpy as np
2
- # import librosa #has to do this cause librosa is not supported on my server
3
- import python_speech_features
4
- from scipy.io import wavfile
5
- from scipy import signal
6
- import librosa
7
- import torch
8
- import torchaudio as ta
9
- import torchaudio.functional as ta_F
10
- import torchaudio.transforms as ta_T
11
- # import pyloudnorm as pyln
12
-
13
-
14
- def load_wav_old(audio_fn, sr = 16000):
15
- sample_rate, sig = wavfile.read(audio_fn)
16
- if sample_rate != sr:
17
- result = int((sig.shape[0]) / sample_rate * sr)
18
- x_resampled = signal.resample(sig, result)
19
- x_resampled = x_resampled.astype(np.float64)
20
- return x_resampled, sr
21
-
22
- sig = sig / (2**15)
23
- return sig, sample_rate
24
-
25
-
26
- def get_mfcc(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
27
-
28
- y, sr = librosa.load(audio_fn, sr=sr, mono=True)
29
-
30
- if win_size is None:
31
- hop_len=int(sr / fps)
32
- else:
33
- hop_len=int(sr / win_size)
34
-
35
- n_fft=2048
36
-
37
- C = librosa.feature.mfcc(
38
- y = y,
39
- sr = sr,
40
- n_mfcc = n_mfcc,
41
- hop_length = hop_len,
42
- n_fft = n_fft
43
- )
44
-
45
- if C.shape[0] == n_mfcc:
46
- C = C.transpose(1, 0)
47
-
48
- return C
49
-
50
-
51
- def get_melspec(audio_fn, eps=1e-6, fps = 25, sr=16000, n_mels=64):
52
- raise NotImplementedError
53
- '''
54
- # y, sr = load_wav(audio_fn=audio_fn, sr=sr)
55
-
56
- # hop_len = int(sr / fps)
57
- # n_fft = 2048
58
-
59
- # C = librosa.feature.melspectrogram(
60
- # y = y,
61
- # sr = sr,
62
- # n_fft=n_fft,
63
- # hop_length=hop_len,
64
- # n_mels = n_mels,
65
- # fmin=0,
66
- # fmax=8000)
67
-
68
-
69
- # mask = (C == 0).astype(np.float)
70
- # C = mask * eps + (1-mask) * C
71
-
72
- # C = np.log(C)
73
- # #wierd error may occur here
74
- # assert not (np.isnan(C).any()), audio_fn
75
- # if C.shape[0] == n_mels:
76
- # C = C.transpose(1, 0)
77
-
78
- # return C
79
- '''
80
-
81
- def extract_mfcc(audio,sample_rate=16000):
82
- mfcc = zip(*python_speech_features.mfcc(audio,sample_rate, numcep=64, nfilt=64, nfft=2048, winstep=0.04))
83
- mfcc = np.stack([np.array(i) for i in mfcc])
84
- return mfcc
85
-
86
- def get_mfcc_psf(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
87
- y, sr = load_wav_old(audio_fn, sr=sr)
88
-
89
- if y.shape.__len__() > 1:
90
- y = (y[:,0]+y[:,1])/2
91
-
92
- if win_size is None:
93
- hop_len=int(sr / fps)
94
- else:
95
- hop_len=int(sr/ win_size)
96
-
97
- n_fft=2048
98
-
99
- #hard coded for 25 fps
100
- if not smlpx:
101
- C = python_speech_features.mfcc(y, sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=0.04)
102
- else:
103
- C = python_speech_features.mfcc(y, sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01/15)
104
- # if C.shape[0] == n_mfcc:
105
- # C = C.transpose(1, 0)
106
-
107
- return C
108
-
109
-
110
- def get_mfcc_psf_min(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
111
- y, sr = load_wav_old(audio_fn, sr=sr)
112
-
113
- if y.shape.__len__() > 1:
114
- y = (y[:, 0] + y[:, 1]) / 2
115
- n_fft = 2048
116
-
117
- slice_len = 22000 * 5
118
- slice = y.size // slice_len
119
-
120
- C = []
121
-
122
- for i in range(slice):
123
- if i != (slice - 1):
124
- feat = python_speech_features.mfcc(y[i*slice_len:(i+1)*slice_len], sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01 / 15)
125
- else:
126
- feat = python_speech_features.mfcc(y[i * slice_len:], sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01 / 15)
127
-
128
- C.append(feat)
129
-
130
- return C
131
-
132
-
133
- def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
134
- """
135
- :param audio: 1 x T tensor containing a 16kHz audio signal
136
- :param frame_rate: frame rate for video (we need one audio chunk per video frame)
137
- :param chunk_size: number of audio samples per chunk
138
- :return: num_chunks x chunk_size tensor containing sliced audio
139
- """
140
- samples_per_frame = chunk_size // frame_rate
141
- padding = (chunk_size - samples_per_frame) // 2
142
- audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
143
- anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
144
- audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
145
- return audio
146
-
147
-
148
- def get_mfcc_ta(audio_fn, eps=1e-6, fps=15, smlpx=False, sr=16000, n_mfcc=64, win_size=None, type='mfcc', am=None, am_sr=None, encoder_choice='mfcc'):
149
- if am is None:
150
- audio, sr_0 = ta.load(audio_fn)
151
- if sr != sr_0:
152
- audio = ta.transforms.Resample(sr_0, sr)(audio)
153
- if audio.shape[0] > 1:
154
- audio = torch.mean(audio, dim=0, keepdim=True)
155
-
156
- n_fft = 2048
157
- if fps == 15:
158
- hop_length = 1467
159
- elif fps == 30:
160
- hop_length = 734
161
- win_length = hop_length * 2
162
- n_mels = 256
163
- n_mfcc = 64
164
-
165
- if type == 'mfcc':
166
- mfcc_transform = ta_T.MFCC(
167
- sample_rate=sr,
168
- n_mfcc=n_mfcc,
169
- melkwargs={
170
- "n_fft": n_fft,
171
- "n_mels": n_mels,
172
- # "win_length": win_length,
173
- "hop_length": hop_length,
174
- "mel_scale": "htk",
175
- },
176
- )
177
- audio_ft = mfcc_transform(audio).squeeze(dim=0).transpose(0,1).numpy()
178
- elif type == 'mel':
179
- # audio = 0.01 * audio / torch.mean(torch.abs(audio))
180
- mel_transform = ta_T.MelSpectrogram(
181
- sample_rate=sr, n_fft=n_fft, win_length=None, hop_length=hop_length, n_mels=n_mels
182
- )
183
- audio_ft = mel_transform(audio).squeeze(0).transpose(0,1).numpy()
184
- # audio_ft = torch.log(audio_ft.clamp(min=1e-10, max=None)).transpose(0,1).numpy()
185
- elif type == 'mel_mul':
186
- audio = 0.01 * audio / torch.mean(torch.abs(audio))
187
- audio = audio_chunking(audio, frame_rate=fps, chunk_size=sr)
188
- mel_transform = ta_T.MelSpectrogram(
189
- sample_rate=sr, n_fft=n_fft, win_length=int(sr/20), hop_length=int(sr/100), n_mels=n_mels
190
- )
191
- audio_ft = mel_transform(audio).squeeze(1)
192
- audio_ft = torch.log(audio_ft.clamp(min=1e-10, max=None)).numpy()
193
- else:
194
- speech_array, sampling_rate = librosa.load(audio_fn, sr=16000)
195
-
196
- if encoder_choice == 'faceformer':
197
- # audio_ft = np.squeeze(am(speech_array, sampling_rate=16000).input_values).reshape(-1, 1)
198
- audio_ft = speech_array.reshape(-1, 1)
199
- elif encoder_choice == 'meshtalk':
200
- audio_ft = 0.01 * speech_array / np.mean(np.abs(speech_array))
201
- elif encoder_choice == 'onset':
202
- audio_ft = librosa.onset.onset_detect(y=speech_array, sr=16000, units='time').reshape(-1, 1)
203
- else:
204
- audio, sr_0 = ta.load(audio_fn)
205
- if sr != sr_0:
206
- audio = ta.transforms.Resample(sr_0, sr)(audio)
207
- if audio.shape[0] > 1:
208
- audio = torch.mean(audio, dim=0, keepdim=True)
209
-
210
- n_fft = 2048
211
- if fps == 15:
212
- hop_length = 1467
213
- elif fps == 30:
214
- hop_length = 734
215
- win_length = hop_length * 2
216
- n_mels = 256
217
- n_mfcc = 64
218
-
219
- mfcc_transform = ta_T.MFCC(
220
- sample_rate=sr,
221
- n_mfcc=n_mfcc,
222
- melkwargs={
223
- "n_fft": n_fft,
224
- "n_mels": n_mels,
225
- # "win_length": win_length,
226
- "hop_length": hop_length,
227
- "mel_scale": "htk",
228
- },
229
- )
230
- audio_ft = mfcc_transform(audio).squeeze(dim=0).transpose(0, 1).numpy()
231
- return audio_ft
232
-
233
-
234
- def get_mfcc_sepa(audio_fn, fps=15, sr=16000):
235
- audio, sr_0 = ta.load(audio_fn)
236
- if sr != sr_0:
237
- audio = ta.transforms.Resample(sr_0, sr)(audio)
238
- if audio.shape[0] > 1:
239
- audio = torch.mean(audio, dim=0, keepdim=True)
240
-
241
- n_fft = 2048
242
- if fps == 15:
243
- hop_length = 1467
244
- elif fps == 30:
245
- hop_length = 734
246
- n_mels = 256
247
- n_mfcc = 64
248
-
249
- mfcc_transform = ta_T.MFCC(
250
- sample_rate=sr,
251
- n_mfcc=n_mfcc,
252
- melkwargs={
253
- "n_fft": n_fft,
254
- "n_mels": n_mels,
255
- # "win_length": win_length,
256
- "hop_length": hop_length,
257
- "mel_scale": "htk",
258
- },
259
- )
260
- audio_ft_0 = mfcc_transform(audio[0, :sr*2]).squeeze(dim=0).transpose(0,1).numpy()
261
- audio_ft_1 = mfcc_transform(audio[0, sr*2:]).squeeze(dim=0).transpose(0,1).numpy()
262
- audio_ft = np.concatenate((audio_ft_0, audio_ft_1), axis=0)
263
- return audio_ft, audio_ft_0.shape[0]
264
-
265
-
266
- def get_mfcc_old(wav_file):
267
- sig, sample_rate = load_wav_old(wav_file)
268
- mfcc = extract_mfcc(sig)
269
- return mfcc
270
-
271
-
272
- def smooth_geom(geom, mask: torch.Tensor = None, filter_size: int = 9, sigma: float = 2.0):
273
- """
274
- :param geom: T x V x 3 tensor containing a temporal sequence of length T with V vertices in each frame
275
- :param mask: V-dimensional Tensor containing a mask with vertices to be smoothed
276
- :param filter_size: size of the Gaussian filter
277
- :param sigma: standard deviation of the Gaussian filter
278
- :return: T x V x 3 tensor containing smoothed geometry (i.e., smoothed in the area indicated by the mask)
279
- """
280
- assert filter_size % 2 == 1, f"filter size must be odd but is {filter_size}"
281
- # Gaussian smoothing (low-pass filtering)
282
- fltr = np.arange(-(filter_size // 2), filter_size // 2 + 1)
283
- fltr = np.exp(-0.5 * fltr ** 2 / sigma ** 2)
284
- fltr = torch.Tensor(fltr) / np.sum(fltr)
285
- # apply fltr
286
- fltr = fltr.view(1, 1, -1).to(device=geom.device)
287
- T, V = geom.shape[1], geom.shape[2]
288
- g = torch.nn.functional.pad(
289
- geom.permute(2, 0, 1).view(V, 1, T),
290
- pad=[filter_size // 2, filter_size // 2], mode='replicate'
291
- )
292
- g = torch.nn.functional.conv1d(g, fltr).view(V, 1, T)
293
- smoothed = g.permute(1, 2, 0).contiguous()
294
- # blend smoothed signal with original signal
295
- if mask is None:
296
- return smoothed
297
- else:
298
- return smoothed * mask[None, :, None] + geom * (-mask[None, :, None] + 1)
299
-
300
- if __name__ == '__main__':
301
- audio_fn = '../sample_audio/clip000028_tCAkv4ggPgI.wav'
302
-
303
- C = get_mfcc_psf(audio_fn)
304
- print(C.shape)
305
-
306
- C_2 = get_mfcc_librosa(audio_fn)
307
- print(C.shape)
308
-
309
- print(C)
310
- print(C_2)
311
- print((C == C_2).all())
312
- # print(y.shape, sr)
313
- # mel_spec = get_melspec(audio_fn)
314
- # print(mel_spec.shape)
315
- # mfcc = get_mfcc(audio_fn, sr = 16000)
316
- # print(mfcc.shape)
317
- # print(mel_spec.max(), mel_spec.min())
318
- # print(mfcc.max(), mfcc.min())
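
The block below is a sketch, not repository code, showing how the torchaudio-based extractor above can be called on a raw `.wav`; the file path is a placeholder and the printed shape depends on the clip length. The import path refers to the module as it existed before this deletion.

```python
# Illustrative call of the MFCC front-end defined above.
from data_utils.utils import get_mfcc_ta

audio_fn = 'demo_audio/1st-page.wav'   # placeholder path to a 16 kHz-compatible wav
feat = get_mfcc_ta(audio_fn, fps=30, sr=16000, type='mfcc')
print(feat.shape)   # (num_frames, 64): one 64-dim MFCC vector per output frame
```
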
demo_audio/1st-page.wav DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5fd78f4976c2fded490d274a9d4f20b5ebbc8e3c4e9f08ff2f69b38f92786818
3
- size 410190
demo_audio/yoy.py DELETED
File without changes
download_models.py DELETED
@@ -1,28 +0,0 @@
1
- import os
2
- import urllib.request
3
- import zipfile
4
- import subprocess
5
-
6
- def download_file(url, output_path):
7
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
8
- if not os.path.exists(output_path):
9
- print(f"Downloading {url} to {output_path}...")
10
- urllib.request.urlretrieve(url, output_path)
11
- print("Download complete!")
12
- else:
13
- print(f"File already exists: {output_path}")
14
-
15
- def main():
16
- # Create necessary directories
17
- os.makedirs("experiments", exist_ok=True)
18
- os.makedirs("visualise/smplx_model", exist_ok=True)
19
-
20
- # Here you would need to add URLs to download your models
21
- # For example:
22
- # download_file("YOUR_MODEL_URL", "experiments/your_model.pth")
23
- # download_file("SMPLX_MODEL_URL", "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz")
24
-
25
- print("Setup complete!")
26
-
27
- if __name__ == "__main__":
28
- main()
evaluation/FGD.py DELETED
@@ -1,199 +0,0 @@
1
- import time
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn.functional as F
6
- from scipy import linalg
7
- import math
8
- from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
9
-
10
- import warnings
11
- warnings.filterwarnings("ignore", category=RuntimeWarning) # ignore warnings
12
-
13
-
14
- change_angle = torch.tensor([6.0181e-05, 5.1597e-05, 2.1344e-04, 2.1899e-04])
15
- class EmbeddingSpaceEvaluator:
16
- def __init__(self, ae, vae, device):
17
-
18
- # init embed net
19
- self.ae = ae
20
- # self.vae = vae
21
-
22
- # storage
23
- self.real_feat_list = []
24
- self.generated_feat_list = []
25
- self.real_joints_list = []
26
- self.generated_joints_list = []
27
- self.real_6d_list = []
28
- self.generated_6d_list = []
29
- self.audio_beat_list = []
30
-
31
- def reset(self):
32
- self.real_feat_list = []
33
- self.generated_feat_list = []
34
-
35
- def get_no_of_samples(self):
36
- return len(self.real_feat_list)
37
-
38
- def push_samples(self, generated_poses, real_poses):
39
- # self.net.eval()
40
- # convert poses to latent features
41
- real_feat, real_poses = self.ae.extract(real_poses)
42
- generated_feat, generated_poses = self.ae.extract(generated_poses)
43
-
44
- num_joints = real_poses.shape[2] // 3
45
-
46
- real_feat = real_feat.squeeze()
47
- generated_feat = generated_feat.reshape(generated_feat.shape[0]*generated_feat.shape[1], -1)
48
-
49
- self.real_feat_list.append(real_feat.data.cpu().numpy())
50
- self.generated_feat_list.append(generated_feat.data.cpu().numpy())
51
-
52
- # real_poses = matrix_to_rotation_6d(axis_angle_to_matrix(real_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
53
- # generated_poses = matrix_to_rotation_6d(axis_angle_to_matrix(generated_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
54
- #
55
- # self.real_feat_list.append(real_poses.data.cpu().numpy())
56
- # self.generated_feat_list.append(generated_poses.data.cpu().numpy())
57
-
58
- def push_joints(self, generated_poses, real_poses):
59
- self.real_joints_list.append(real_poses.data.cpu())
60
- self.generated_joints_list.append(generated_poses.squeeze().data.cpu())
61
-
62
- def push_aud(self, aud):
63
- self.audio_beat_list.append(aud.squeeze().data.cpu())
64
-
65
- def get_MAAC(self):
66
- ang_vel_list = []
67
- for real_joints in self.real_joints_list:
68
- real_joints[:, 15:21] = real_joints[:, 16:22]
69
- vec = real_joints[:, 15:21] - real_joints[:, 13:19]
70
- inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
71
- inner_product = torch.clamp(inner_product, -1, 1, out=None)
72
- angle = torch.acos(inner_product) / math.pi
73
- ang_vel = (angle[1:] - angle[:-1]).abs().mean(dim=0)
74
- ang_vel_list.append(ang_vel.unsqueeze(dim=0))
75
- all_vel = torch.cat(ang_vel_list, dim=0)
76
- MAAC = all_vel.mean(dim=0)
77
- return MAAC
78
-
79
- def get_BCscore(self):
80
- thres = 0.01
81
- sigma = 0.1
82
- sum_1 = 0
83
- total_beat = 0
84
- for joints, audio_beat_time in zip(self.generated_joints_list, self.audio_beat_list):
85
- motion_beat_time = []
86
- if joints.dim() == 4:
87
- joints = joints[0]
88
- joints[:, 15:21] = joints[:, 16:22]
89
- vec = joints[:, 15:21] - joints[:, 13:19]
90
- inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
91
- inner_product = torch.clamp(inner_product, -1, 1, out=None)
92
- angle = torch.acos(inner_product) / math.pi
93
- ang_vel = (angle[1:] - angle[:-1]).abs() / change_angle / len(change_angle)
94
-
95
- angle_diff = torch.cat((torch.zeros(1, 4), ang_vel), dim=0)
96
-
97
- sum_2 = 0
98
- for i in range(angle_diff.shape[1]):
99
- motion_beat_time = []
100
- for t in range(1, joints.shape[0]-1):
101
- if (angle_diff[t][i] < angle_diff[t - 1][i] and angle_diff[t][i] < angle_diff[t + 1][i]):
102
- if (angle_diff[t - 1][i] - angle_diff[t][i] >= thres or angle_diff[t + 1][i] - angle_diff[
103
- t][i] >= thres):
104
- motion_beat_time.append(float(t) / 30.0)
105
- if (len(motion_beat_time) == 0):
106
- continue
107
- motion_beat_time = torch.tensor(motion_beat_time)
108
- sum = 0
109
- for audio in audio_beat_time:
110
- sum += np.power(math.e, -(np.power((audio.item() - motion_beat_time), 2)).min() / (2 * sigma * sigma))
111
- sum_2 = sum_2 + sum
112
- total_beat = total_beat + len(audio_beat_time)
113
- sum_1 = sum_1 + sum_2
114
- return sum_1/total_beat
115
-
116
-
117
- def get_scores(self):
118
- generated_feats = np.vstack(self.generated_feat_list)
119
- real_feats = np.vstack(self.real_feat_list)
120
-
121
- def frechet_distance(samples_A, samples_B):
122
- A_mu = np.mean(samples_A, axis=0)
123
- A_sigma = np.cov(samples_A, rowvar=False)
124
- B_mu = np.mean(samples_B, axis=0)
125
- B_sigma = np.cov(samples_B, rowvar=False)
126
- try:
127
- frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma)
128
- except ValueError:
129
- frechet_dist = 1e+10
130
- return frechet_dist
131
-
132
- ####################################################################
133
- # frechet distance
134
- frechet_dist = frechet_distance(generated_feats, real_feats)
135
-
136
- ####################################################################
137
- # distance between real and generated samples on the latent feature space
138
- dists = []
139
- for i in range(real_feats.shape[0]):
140
- d = np.sum(np.absolute(real_feats[i] - generated_feats[i])) # MAE
141
- dists.append(d)
142
- feat_dist = np.mean(dists)
143
-
144
- return frechet_dist, feat_dist
145
-
146
- @staticmethod
147
- def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
148
- """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """
149
- """Numpy implementation of the Frechet Distance.
150
- The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
151
- and X_2 ~ N(mu_2, C_2) is
152
- d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
153
- Stable version by Dougal J. Sutherland.
154
- Params:
155
- -- mu1 : Numpy array containing the activations of a layer of the
156
- inception net (like returned by the function 'get_predictions')
157
- for generated samples.
158
- -- mu2 : The sample mean over activations, precalculated on an
159
- representative data set.
160
- -- sigma1: The covariance matrix over activations for generated samples.
161
- -- sigma2: The covariance matrix over activations, precalculated on an
162
- representative data set.
163
- Returns:
164
- -- : The Frechet Distance.
165
- """
166
-
167
- mu1 = np.atleast_1d(mu1)
168
- mu2 = np.atleast_1d(mu2)
169
-
170
- sigma1 = np.atleast_2d(sigma1)
171
- sigma2 = np.atleast_2d(sigma2)
172
-
173
- assert mu1.shape == mu2.shape, \
174
- 'Training and test mean vectors have different lengths'
175
- assert sigma1.shape == sigma2.shape, \
176
- 'Training and test covariances have different dimensions'
177
-
178
- diff = mu1 - mu2
179
-
180
- # Product might be almost singular
181
- covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
182
- if not np.isfinite(covmean).all():
183
- msg = ('fid calculation produces singular product; '
184
- 'adding %s to diagonal of cov estimates') % eps
185
- print(msg)
186
- offset = np.eye(sigma1.shape[0]) * eps
187
- covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
188
-
189
- # Numerical error might give slight imaginary component
190
- if np.iscomplexobj(covmean):
191
- if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
192
- m = np.max(np.abs(covmean.imag))
193
- raise ValueError('Imaginary component {}'.format(m))
194
- covmean = covmean.real
195
-
196
- tr_covmean = np.trace(covmean)
197
-
198
- return (diff.dot(diff) + np.trace(sigma1) +
199
- np.trace(sigma2) - 2 * tr_covmean)
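
As a sanity check on the Fréchet-distance helper above, the sketch below (not repository code; the feature sets are synthetic) feeds it two Gaussian samples. A pure mean shift of 0.5 in every one of 32 dimensions should give a distance close to 32 × 0.5² = 8, up to sampling noise.

```python
# Self-contained check of calculate_frechet_distance on illustrative data.
import numpy as np
from evaluation.FGD import EmbeddingSpaceEvaluator

rng = np.random.default_rng(0)
feats_real = rng.standard_normal((2000, 32))
feats_fake = rng.standard_normal((2000, 32)) + 0.5   # shifted distribution

def stats(x):
    return np.mean(x, axis=0), np.cov(x, rowvar=False)

mu_r, sig_r = stats(feats_real)
mu_f, sig_f = stats(feats_fake)
fgd = EmbeddingSpaceEvaluator.calculate_frechet_distance(mu_r, sig_r, mu_f, sig_f)
print(fgd)   # roughly 8 for this mean shift
```
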
evaluation/__init__.py DELETED
File without changes
evaluation/diversity_LVD.py DELETED
@@ -1,64 +0,0 @@
1
- '''
2
- LVD: different initial pose
3
- diversity: same initial pose
4
- '''
5
- import os
6
- import sys
7
- sys.path.append(os.getcwd())
8
-
9
- from glob import glob
10
-
11
- from argparse import ArgumentParser
12
- import json
13
-
14
- from evaluation.util import *
15
- from evaluation.metrics import *
16
- from tqdm import tqdm
17
-
18
- parser = ArgumentParser()
19
- parser.add_argument('--speaker', required=True, type=str)
20
- parser.add_argument('--post_fix', nargs='+', default=['base'], type=str)
21
- args = parser.parse_args()
22
-
23
- speaker = args.speaker
24
- test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
25
-
26
- LVD_list = []
27
- diversity_list = []
28
-
29
- for aud in tqdm(test_audios):
30
- base_name = os.path.splitext(aud)[0]
31
- gt_path = get_full_path(aud, speaker, 'val')
32
- _, gt_poses, _ = get_gts(gt_path)
33
- gt_poses = gt_poses[np.newaxis,...]
34
- # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
35
- for post_fix in args.post_fix:
36
- pred_path = base_name + '_'+post_fix+'.json'
37
- pred_poses = np.array(json.load(open(pred_path)))
38
- # print(pred_poses.shape)#(B, seq_len, 108)
39
- pred_poses = cvt25(pred_poses, gt_poses)
40
- # print(pred_poses.shape)#(B, seq, pose_dim)
41
-
42
- gt_valid_points = hand_points(gt_poses)
43
- pred_valid_points = hand_points(pred_poses)
44
-
45
- lvd = LVD(gt_valid_points, pred_valid_points)
46
- # div = diversity(pred_valid_points)
47
-
48
- LVD_list.append(lvd)
49
- # diversity_list.append(div)
50
-
51
- # gt_velocity = peak_velocity(gt_valid_points, order=2)
52
- # pred_velocity = peak_velocity(pred_valid_points, order=2)
53
-
54
- # gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
55
- # pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
56
-
57
- # gt_consistency_list.append(gt_consistency)
58
- # pred_consistency_list.append(pred_consistency)
59
-
60
- lvd = np.mean(LVD_list)
61
- # diversity_list = np.mean(diversity_list)
62
-
63
- print('LVD:', lvd)
64
- # print("diversity:", diversity_list)
evaluation/get_quality_samples.py DELETED
@@ -1,62 +0,0 @@
1
- '''
2
- '''
3
- import os
4
- import sys
5
- sys.path.append(os.getcwd())
6
-
7
- from glob import glob
8
-
9
- from argparse import ArgumentParser
10
- import json
11
-
12
- from evaluation.util import *
13
- from evaluation.metrics import *
14
- from tqdm import tqdm
15
-
16
- parser = ArgumentParser()
17
- parser.add_argument('--speaker', required=True, type=str)
18
- parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
19
- args = parser.parse_args()
20
-
21
- speaker = args.speaker
22
- test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
23
-
24
- quality_samples={'gt':[]}
25
- for post_fix in args.post_fix:
26
- quality_samples[post_fix] = []
27
-
28
- for aud in tqdm(test_audios):
29
- base_name = os.path.splitext(aud)[0]
30
- gt_path = get_full_path(aud, speaker, 'val')
31
- _, gt_poses, _ = get_gts(gt_path)
32
- gt_poses = gt_poses[np.newaxis,...]
33
- gt_valid_points = valid_points(gt_poses)
34
- # print(gt_valid_points.shape)
35
- quality_samples['gt'].append(gt_valid_points)
36
-
37
- for post_fix in args.post_fix:
38
- pred_path = base_name + '_'+post_fix+'.json'
39
- pred_poses = np.array(json.load(open(pred_path)))
40
- # print(pred_poses.shape)#(B, seq_len, 108)
41
- pred_poses = cvt25(pred_poses, gt_poses)
42
- # print(pred_poses.shape)#(B, seq, pose_dim)
43
-
44
- pred_valid_points = valid_points(pred_poses)[0:1]
45
- quality_samples[post_fix].append(pred_valid_points)
46
-
47
- quality_samples['gt'] = np.concatenate(quality_samples['gt'], axis=1)
48
- for post_fix in args.post_fix:
49
- quality_samples[post_fix] = np.concatenate(quality_samples[post_fix], axis=1)
50
-
51
- print('gt:', quality_samples['gt'].shape)
52
- quality_samples['gt'] = quality_samples['gt'].tolist()
53
- for post_fix in args.post_fix:
54
- print(post_fix, ':', quality_samples[post_fix].shape)
55
- quality_samples[post_fix] = quality_samples[post_fix].tolist()
56
-
57
- save_dir = '../../experiments/'
58
- os.makedirs(save_dir, exist_ok=True)
59
- save_name = os.path.join(save_dir, 'quality_samples_%s.json'%(speaker))
60
- with open(save_name, 'w') as f:
61
- json.dump(quality_samples, f)
62
-
evaluation/metrics.py DELETED
@@ -1,109 +0,0 @@
1
- '''
2
- Warning: metrics are for reference only, may have limited significance
3
- '''
4
- import os
5
- import sys
6
- sys.path.append(os.getcwd())
7
- import numpy as np
8
- import torch
9
-
10
- from data_utils.lower_body import rearrange, symmetry
11
- import torch.nn.functional as F
12
-
13
- def data_driven_baselines(gt_kps):
14
- '''
15
- gt_kps: T, D
16
- '''
17
- gt_velocity = np.abs(gt_kps[1:] - gt_kps[:-1])
18
-
19
- mean= np.mean(gt_velocity, axis=0)[np.newaxis] #(1, D)
20
- mean = np.mean(np.abs(gt_velocity-mean))
21
- last_step = gt_kps[1] - gt_kps[0]
22
- last_step = last_step[np.newaxis] #(1, D)
23
- last_step = np.mean(np.abs(gt_velocity-last_step))
24
- return last_step, mean
25
-
26
- def Batch_LVD(gt_kps, pr_kps, symmetrical, weight):
27
- if gt_kps.shape[0] > pr_kps.shape[1]:
28
- length = pr_kps.shape[1]
29
- else:
30
- length = gt_kps.shape[0]
31
- gt_kps = gt_kps[:length]
32
- pr_kps = pr_kps[:, :length]
33
- global symmetry
34
- symmetry = torch.tensor(symmetry).bool()
35
-
36
- if symmetrical:
37
- # rearrange for compute symmetric. ns means non-symmetrical joints, ys means symmetrical joints.
38
- gt_kps = gt_kps[:, rearrange]
39
- ns_gt_kps = gt_kps[:, ~symmetry]
40
- ys_gt_kps = gt_kps[:, symmetry]
41
- ys_gt_kps = ys_gt_kps.reshape(ys_gt_kps.shape[0], -1, 2, 3)
42
- ns_gt_velocity = (ns_gt_kps[1:] - ns_gt_kps[:-1]).norm(p=2, dim=-1)
43
- ys_gt_velocity = (ys_gt_kps[1:] - ys_gt_kps[:-1]).norm(p=2, dim=-1)
44
- left_gt_vel = ys_gt_velocity[:, :, 0].sum(dim=-1)
45
- right_gt_vel = ys_gt_velocity[:, :, 1].sum(dim=-1)
46
- move_side = torch.where(left_gt_vel>right_gt_vel, torch.ones(left_gt_vel.shape).cuda(), torch.zeros(left_gt_vel.shape).cuda())
47
- ys_gt_velocity = torch.mul(ys_gt_velocity[:, :, 0].transpose(0,1), move_side) + torch.mul(ys_gt_velocity[:, :, 1].transpose(0,1), ~move_side.bool())
48
- ys_gt_velocity = ys_gt_velocity.transpose(0,1)
49
- gt_velocity = torch.cat([ns_gt_velocity, ys_gt_velocity], dim=1)
50
-
51
- pr_kps = pr_kps[:, :, rearrange]
52
- ns_pr_kps = pr_kps[:, :, ~symmetry]
53
- ys_pr_kps = pr_kps[:, :, symmetry]
54
- ys_pr_kps = ys_pr_kps.reshape(ys_pr_kps.shape[0], ys_pr_kps.shape[1], -1, 2, 3)
55
- ns_pr_velocity = (ns_pr_kps[:, 1:] - ns_pr_kps[:, :-1]).norm(p=2, dim=-1)
56
- ys_pr_velocity = (ys_pr_kps[:, 1:] - ys_pr_kps[:, :-1]).norm(p=2, dim=-1)
57
- left_pr_vel = ys_pr_velocity[:, :, :, 0].sum(dim=-1)
58
- right_pr_vel = ys_pr_velocity[:, :, :, 1].sum(dim=-1)
59
- move_side = torch.where(left_pr_vel > right_pr_vel, torch.ones(left_pr_vel.shape).cuda(),
60
- torch.zeros(left_pr_vel.shape).cuda())
61
- ys_pr_velocity = torch.mul(ys_pr_velocity[..., 0].permute(2, 0, 1), move_side) + torch.mul(
62
- ys_pr_velocity[..., 1].permute(2, 0, 1), ~move_side.long())
63
- ys_pr_velocity = ys_pr_velocity.permute(1, 2, 0)
64
- pr_velocity = torch.cat([ns_pr_velocity, ys_pr_velocity], dim=2)
65
- else:
66
- gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
67
- pr_velocity = (pr_kps[:, 1:] - pr_kps[:, :-1]).norm(p=2, dim=-1)
68
-
69
- if weight:
70
- w = F.softmax(gt_velocity.sum(dim=1).normal_(), dim=0)
71
- else:
72
- w = 1 / gt_velocity.shape[0]
73
-
74
- v_diff = ((pr_velocity - gt_velocity).abs().sum(dim=-1) * w).sum(dim=-1).mean()
75
-
76
- return v_diff
77
-
78
-
79
- def LVD(gt_kps, pr_kps, symmetrical=False, weight=False):
80
- gt_kps = gt_kps.squeeze()
81
- pr_kps = pr_kps.squeeze()
82
- if len(pr_kps.shape) == 4:
83
- return Batch_LVD(gt_kps, pr_kps, symmetrical, weight)
84
- # length = np.minimum(gt_kps.shape[0], pr_kps.shape[0])
85
- length = gt_kps.shape[0]-10
86
- # gt_kps = gt_kps[25:length]
87
- # pr_kps = pr_kps[25:length] #(T, D)
88
- # if pr_kps.shape[0] < gt_kps.shape[0]:
89
- # pr_kps = np.pad(pr_kps, [[0, int(gt_kps.shape[0]-pr_kps.shape[0])], [0, 0]], mode='constant')
90
-
91
- gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
92
- pr_velocity = (pr_kps[1:] - pr_kps[:-1]).norm(p=2, dim=-1)
93
-
94
- return (pr_velocity-gt_velocity).abs().sum(dim=-1).mean()
95
-
96
- def diversity(kps):
97
- '''
98
- kps: bs, seq, dim
99
- '''
100
- dis_list = []
101
- #the distance between each pair
102
- for i in range(kps.shape[0]):
103
- for j in range(i+1, kps.shape[0]):
104
- seq_i = kps[i]
105
- seq_j = kps[j]
106
-
107
- dis = np.mean(np.abs(seq_i - seq_j))
108
- dis_list.append(dis)
109
- return np.mean(dis_list)
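
For orientation, here is a minimal call of the non-batched `LVD` metric above on random joint trajectories; it is a sketch with assumed shapes, not repository code, and it returns the per-frame velocity error summed over joints and averaged over time.

```python
# Illustrative LVD call; shapes (T frames, 54 joints, 3D) are assumptions.
import torch
from evaluation.metrics import LVD

gt_kps = torch.randn(120, 54, 3)
pr_kps = torch.randn(120, 54, 3)
print(LVD(gt_kps, pr_kps))   # scalar tensor: |Δvelocity| summed over joints, averaged over frames
```
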
evaluation/mode_transition.py DELETED
@@ -1,60 +0,0 @@
1
- import os
2
- import sys
3
- sys.path.append(os.getcwd())
4
-
5
- from glob import glob
6
-
7
- from argparse import ArgumentParser
8
- import json
9
-
10
- from evaluation.util import *
11
- from evaluation.metrics import *
12
- from tqdm import tqdm
13
-
14
- parser = ArgumentParser()
15
- parser.add_argument('--speaker', required=True, type=str)
16
- parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
17
- args = parser.parse_args()
18
-
19
- speaker = args.speaker
20
- test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
21
-
22
- precision_list=[]
23
- recall_list=[]
24
- accuracy_list=[]
25
-
26
- for aud in tqdm(test_audios):
27
- base_name = os.path.splitext(aud)[0]
28
- gt_path = get_full_path(aud, speaker, 'val')
29
- _, gt_poses, _ = get_gts(gt_path)
30
- if gt_poses.shape[0] < 50:
31
- continue
32
- gt_poses = gt_poses[np.newaxis,...]
33
- # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
34
- for post_fix in args.post_fix:
35
- pred_path = base_name + '_'+post_fix+'.json'
36
- pred_poses = np.array(json.load(open(pred_path)))
37
- # print(pred_poses.shape)#(B, seq_len, 108)
38
- pred_poses = cvt25(pred_poses, gt_poses)
39
- # print(pred_poses.shape)#(B, seq, pose_dim)
40
-
41
- gt_valid_points = valid_points(gt_poses)
42
- pred_valid_points = valid_points(pred_poses)
43
-
44
- # print(gt_valid_points.shape, pred_valid_points.shape)
45
-
46
- gt_mode_transition_seq = mode_transition_seq(gt_valid_points, speaker)#(B, N)
47
- pred_mode_transition_seq = mode_transition_seq(pred_valid_points, speaker)#(B, N)
48
-
49
- # baseline = np.random.randint(0, 2, size=pred_mode_transition_seq.shape)
50
- # pred_mode_transition_seq = baseline
51
- precision, recall, accuracy = mode_transition_consistency(pred_mode_transition_seq, gt_mode_transition_seq)
52
- precision_list.append(precision)
53
- recall_list.append(recall)
54
- accuracy_list.append(accuracy)
55
- print(len(precision_list), len(recall_list), len(accuracy_list))
56
- precision_list = np.mean(precision_list)
57
- recall_list = np.mean(recall_list)
58
- accuracy_list = np.mean(accuracy_list)
59
-
60
- print('precision, recall, accu:', precision_list, recall_list, accuracy_list)
evaluation/peak_velocity.py DELETED
@@ -1,65 +0,0 @@
1
- import os
2
- import sys
3
- sys.path.append(os.getcwd())
4
-
5
- from glob import glob
6
-
7
- from argparse import ArgumentParser
8
- import json
9
-
10
- from evaluation.util import *
11
- from evaluation.metrics import *
12
- from tqdm import tqdm
13
-
14
- parser = ArgumentParser()
15
- parser.add_argument('--speaker', required=True, type=str)
16
- parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
17
- args = parser.parse_args()
18
-
19
- speaker = args.speaker
20
- test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
21
-
22
- gt_consistency_list=[]
23
- pred_consistency_list=[]
24
-
25
- for aud in tqdm(test_audios):
26
- base_name = os.path.splitext(aud)[0]
27
- gt_path = get_full_path(aud, speaker, 'val')
28
- _, gt_poses, _ = get_gts(gt_path)
29
- gt_poses = gt_poses[np.newaxis,...]
30
- # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
31
- for post_fix in args.post_fix:
32
- pred_path = base_name + '_'+post_fix+'.json'
33
- pred_poses = np.array(json.load(open(pred_path)))
34
- # print(pred_poses.shape)#(B, seq_len, 108)
35
- pred_poses = cvt25(pred_poses, gt_poses)
36
- # print(pred_poses.shape)#(B, seq, pose_dim)
37
-
38
- gt_valid_points = hand_points(gt_poses)
39
- pred_valid_points = hand_points(pred_poses)
40
-
41
- gt_velocity = peak_velocity(gt_valid_points, order=2)
42
- pred_velocity = peak_velocity(pred_valid_points, order=2)
43
-
44
- gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
45
- pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
46
-
47
- gt_consistency_list.append(gt_consistency)
48
- pred_consistency_list.append(pred_consistency)
49
-
50
- gt_consistency_list = np.concatenate(gt_consistency_list)
51
- pred_consistency_list = np.concatenate(pred_consistency_list)
52
-
53
- print(gt_consistency_list.max(), gt_consistency_list.min())
54
- print(pred_consistency_list.max(), pred_consistency_list.min())
55
- print(np.mean(gt_consistency_list), np.mean(pred_consistency_list))
56
- print(np.std(gt_consistency_list), np.std(pred_consistency_list))
57
-
58
- draw_cdf(gt_consistency_list, save_name='%s_gt.jpg'%(speaker), color='slateblue')
59
- draw_cdf(pred_consistency_list, save_name='%s_pred.jpg'%(speaker), color='lightskyblue')
60
-
61
- to_excel(gt_consistency_list, '%s_gt.xlsx'%(speaker))
62
- to_excel(pred_consistency_list, '%s_pred.xlsx'%(speaker))
63
-
64
- np.save('%s_gt.npy'%(speaker), gt_consistency_list)
65
- np.save('%s_pred.npy'%(speaker), pred_consistency_list)
evaluation/util.py DELETED
@@ -1,148 +0,0 @@
1
- import os
2
- from glob import glob
3
- import numpy as np
4
- import json
5
- from matplotlib import pyplot as plt
6
- import pandas as pd
7
- def get_gts(clip):
8
- '''
9
- clip: abs path to the clip dir
10
- '''
11
- keypoints_files = sorted(glob(os.path.join(clip, 'keypoints_new/person_1')+'/*.json'))
12
-
13
- upper_body_points = list(np.arange(0, 25))
14
- poses = []
15
- confs = []
16
- neck_to_nose_len = []
17
- mean_position = []
18
- for kp_file in keypoints_files:
19
- kp_load = json.load(open(kp_file, 'r'))['people'][0]
20
- posepts = kp_load['pose_keypoints_2d']
21
- lhandpts = kp_load['hand_left_keypoints_2d']
22
- rhandpts = kp_load['hand_right_keypoints_2d']
23
- facepts = kp_load['face_keypoints_2d']
24
-
25
- neck = np.array(posepts).reshape(-1,3)[1]
26
- nose = np.array(posepts).reshape(-1,3)[0]
27
- x_offset = abs(neck[0]-nose[0])
28
- y_offset = abs(neck[1]-nose[1])
29
- neck_to_nose_len.append(y_offset)
30
- mean_position.append([neck[0],neck[1]])
31
-
32
- keypoints=np.array(posepts+lhandpts+rhandpts+facepts).reshape(-1,3)[:,:2]
33
-
34
- upper_body = keypoints[upper_body_points, :]
35
- hand_points = keypoints[25:, :]
36
- keypoints = np.vstack([upper_body, hand_points])
37
-
38
- poses.append(keypoints)
39
-
40
- if len(neck_to_nose_len) > 0:
41
- scale_factor = np.mean(neck_to_nose_len)
42
- else:
43
- raise ValueError(clip)
44
- mean_position = np.mean(np.array(mean_position), axis=0)
45
-
46
- unlocalized_poses = np.array(poses).copy()
47
- localized_poses = []
48
- for i in range(len(poses)):
49
- keypoints = poses[i]
50
- neck = keypoints[1].copy()
51
-
52
- keypoints[:, 0] = (keypoints[:, 0] - neck[0]) / scale_factor
53
- keypoints[:, 1] = (keypoints[:, 1] - neck[1]) / scale_factor
54
- localized_poses.append(keypoints.reshape(-1))
55
-
56
- localized_poses=np.array(localized_poses)
57
- return unlocalized_poses, localized_poses, (scale_factor, mean_position)
58
-
59
- def get_full_path(wav_name, speaker, split):
60
- '''
61
- get clip path from aud file
62
- '''
63
- wav_name = os.path.basename(wav_name)
64
- wav_name = os.path.splitext(wav_name)[0]
65
- clip_name, vid_name = wav_name[:10], wav_name[11:]
66
-
67
- full_path = os.path.join('pose_dataset/videos/', speaker, 'clips', vid_name, 'images/half', split, clip_name)
68
-
69
- assert os.path.isdir(full_path), full_path
70
-
71
- return full_path
72
-
73
- def smooth(res):
74
- '''
75
- res: (B, seq_len, pose_dim)
76
- '''
77
- window = [res[:, 7, :], res[:, 8, :], res[:, 9, :], res[:, 10, :], res[:, 11, :], res[:, 12, :]]
78
- w_size=7
79
- for i in range(10, res.shape[1]-3):
80
- window.append(res[:, i+3, :])
81
- if len(window) > w_size:
82
- window = window[1:]
83
-
84
- if (i%25) in [22, 23, 24, 0, 1, 2, 3]:
85
- res[:, i, :] = np.mean(window, axis=1)
86
-
87
- return res
88
-
89
- def cvt25(pred_poses, gt_poses=None):
90
- '''
91
- gt_poses: (1, seq_len, 270), 135 *2
92
- pred_poses: (B, seq_len, 108), 54 * 2
93
- '''
94
- if gt_poses is None:
95
- gt_poses = np.zeros_like(pred_poses)
96
- else:
97
- gt_poses = gt_poses.repeat(pred_poses.shape[0], axis=0)
98
-
99
- length = min(pred_poses.shape[1], gt_poses.shape[1])
100
- pred_poses = pred_poses[:, :length, :]
101
- gt_poses = gt_poses[:, :length, :]
102
- gt_poses = gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1, 2)
103
- pred_poses = pred_poses.reshape(pred_poses.shape[0], pred_poses.shape[1], -1, 2)
104
-
105
- gt_poses[:, :, [1, 2, 3, 4, 5, 6, 7], :] = pred_poses[:, :, 1:8, :]
106
- gt_poses[:, :, 25:25+21+21, :] = pred_poses[:, :, 12:, :]
107
-
108
- return gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1)
109
-
110
- def hand_points(seq):
111
- '''
112
- seq: (B, seq_len, 135*2)
113
- hands only
114
- '''
115
- hand_idx = [1, 2, 3, 4,5 ,6,7] + list(range(25, 25+21+21))
116
- seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
117
- return seq[:, :, hand_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
118
-
119
- def valid_points(seq):
120
- '''
121
- hands with some head points
122
- '''
123
- valid_idx = [0, 1, 2, 3, 4,5 ,6,7, 8, 9, 10, 11] + list(range(25, 25+21+21))
124
- seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
125
-
126
- seq = seq[:, :, valid_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
127
- assert seq.shape[-1] == 108, seq.shape
128
- return seq
129
-
130
- def draw_cdf(seq, save_name='cdf.jpg', color='slateblue'):
131
- plt.figure()
132
- plt.hist(seq, bins=100, range=(0, 100), color=color)
133
- plt.savefig(save_name)
134
-
135
- def to_excel(seq, save_name='res.xlsx'):
136
- '''
137
- seq: (T)
138
- '''
139
- df = pd.DataFrame(seq)
140
- writer = pd.ExcelWriter(save_name)
141
- df.to_excel(writer, 'sheet1')
142
- writer.save()
143
- writer.close()
144
-
145
-
146
- if __name__ == '__main__':
147
- random_data = np.random.randint(0, 10, 100)
148
- draw_cdf(random_data)
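A minimal usage sketch of the helpers above (assuming `evaluation/util.py` is still importable and its dependencies are installed); the dummy shapes follow the docstrings: predictions carry 54 upper-body/hand joints, ground truth the full 135 OpenPose joints, each as (x, y) pairs:

```python
import numpy as np
from evaluation.util import cvt25, hand_points, valid_points

B, T = 4, 88
pred = np.random.randn(B, T, 108)   # (B, seq_len, 54 * 2) predicted keypoints
gt = np.random.randn(1, T, 270)     # (1, seq_len, 135 * 2) ground-truth keypoints

full = cvt25(pred, gt)              # paste predictions into the 135-joint layout -> (B, T, 270)
hands = hand_points(full)           # arms + both hands only -> (B, T, 98)
valid = valid_points(full)          # arms, hands and a few head joints -> (B, T, 108)
print(full.shape, hands.shape, valid.shape)
```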
 
losses/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .losses import *
 
 
losses/losses.py DELETED
@@ -1,91 +0,0 @@
1
- import os
2
- import sys
3
-
4
- sys.path.append(os.getcwd())
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- import numpy as np
10
-
11
- class KeypointLoss(nn.Module):
12
- def __init__(self):
13
- super(KeypointLoss, self).__init__()
14
-
15
- def forward(self, pred_seq, gt_seq, gt_conf=None):
16
- #pred_seq: (B, C, T)
17
- if gt_conf is not None:
18
- gt_conf = gt_conf >= 0.01
19
- return F.mse_loss(pred_seq[gt_conf], gt_seq[gt_conf], reduction='mean')
20
- else:
21
- return F.mse_loss(pred_seq, gt_seq)
22
-
23
-
24
- class KLLoss(nn.Module):
25
- def __init__(self, kl_tolerance):
26
- super(KLLoss, self).__init__()
27
- self.kl_tolerance = kl_tolerance
28
-
29
- def forward(self, mu, var, mul=1):
30
- kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64
31
- kld_loss = -0.5 * torch.sum(1 + var - mu**2 - var.exp(), dim=1)
32
- # kld_loss = -0.5 * torch.sum(1 + (var-1) - (mu) ** 2 - (var-1).exp(), dim=1)
33
- if self.kl_tolerance is not None:
34
- # above_line = kld_loss[kld_loss > self.kl_tolerance]
35
- # if len(above_line) > 0:
36
- # kld_loss = torch.mean(kld_loss)
37
- # else:
38
- # kld_loss = 0
39
- kld_loss = torch.where(kld_loss > kl_tolerance, kld_loss, torch.tensor(kl_tolerance, device=kld_loss.device))
40
- # else:
41
- kld_loss = torch.mean(kld_loss)
42
- return kld_loss
43
-
44
-
45
- class L2KLLoss(nn.Module):
46
- def __init__(self, kl_tolerance):
47
- super(L2KLLoss, self).__init__()
48
- self.kl_tolerance = kl_tolerance
49
-
50
- def forward(self, x):
51
- # TODO: check
52
- kld_loss = torch.sum(x ** 2, dim=1)
53
- if self.kl_tolerance is not None:
54
- above_line = kld_loss[kld_loss > self.kl_tolerance]
55
- if len(above_line) > 0:
56
- kld_loss = torch.mean(kld_loss)
57
- else:
58
- kld_loss = 0
59
- else:
60
- kld_loss = torch.mean(kld_loss)
61
- return kld_loss
62
-
63
- class L2RegLoss(nn.Module):
64
- def __init__(self):
65
- super(L2RegLoss, self).__init__()
66
-
67
- def forward(self, x):
68
- #TODO: check
69
- return torch.sum(x**2)
70
-
71
-
72
- class L2Loss(nn.Module):
73
- def __init__(self):
74
- super(L2Loss, self).__init__()
75
-
76
- def forward(self, x):
77
- # TODO: check
78
- return torch.sum(x ** 2)
79
-
80
-
81
- class AudioLoss(nn.Module):
82
- def __init__(self):
83
- super(AudioLoss, self).__init__()
84
-
85
- def forward(self, dynamics, gt_poses):
86
- #pay attention, normalized
87
- mean = torch.mean(gt_poses, dim=-1).unsqueeze(-1)
88
- gt = gt_poses - mean
89
- return F.mse_loss(dynamics, gt)
90
-
91
- L1Loss = nn.L1Loss
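A minimal sketch of how these losses compose in a training step, on dummy tensors; the 0.01 KL weight is illustrative, not a value taken from the repo's configs:

```python
import torch
from losses import KeypointLoss, KLLoss

keypoint_loss = KeypointLoss()
kl_loss_fn = KLLoss(kl_tolerance=0.02)

pred = torch.randn(8, 129, 88)      # (B, C, T) predicted pose sequence
gt = torch.randn(8, 129, 88)
conf = torch.rand(8, 129, 88)       # per-keypoint confidences, thresholded at 0.01 inside

rec = keypoint_loss(pred, gt, gt_conf=conf)     # masked MSE over confident keypoints
mu, logvar = torch.randn(8, 64), torch.randn(8, 64)
kl = kl_loss_fn(mu, logvar)                     # KL hinged at the tolerance
total = rec + 0.01 * kl                         # illustrative weighting
```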
 
nets/LS3DCG.py DELETED
@@ -1,414 +0,0 @@
1
- '''
2
- not exactly the same as the official repo but the results are good
3
- '''
4
- import sys
5
- import os
6
-
7
- from data_utils.lower_body import c_index_3d, c_index_6d
8
-
9
- sys.path.append(os.getcwd())
10
-
11
- import numpy as np
12
- import torch
13
- import torch.nn as nn
14
- import torch.optim as optim
15
- import torch.nn.functional as F
16
- import math
17
-
18
- from nets.base import TrainWrapperBaseClass
19
- from nets.layers import SeqEncoder1D
20
- from losses import KeypointLoss, L1Loss, KLLoss
21
- from data_utils.utils import get_melspec, get_mfcc_psf, get_mfcc_ta
22
- from nets.utils import denormalize
23
-
24
- class Conv1d_tf(nn.Conv1d):
25
- """
26
- Conv1d with the padding behavior from TF
27
- modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
28
- """
29
-
30
- def __init__(self, *args, **kwargs):
31
- super(Conv1d_tf, self).__init__(*args, **kwargs)
32
- self.padding = kwargs.get("padding", "same")
33
-
34
- def _compute_padding(self, input, dim):
35
- input_size = input.size(dim + 2)
36
- filter_size = self.weight.size(dim + 2)
37
- effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
38
- out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
39
- total_padding = max(
40
- 0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
41
- )
42
- additional_padding = int(total_padding % 2 != 0)
43
-
44
- return additional_padding, total_padding
45
-
46
- def forward(self, input):
47
- if self.padding == "VALID":
48
- return F.conv1d(
49
- input,
50
- self.weight,
51
- self.bias,
52
- self.stride,
53
- padding=0,
54
- dilation=self.dilation,
55
- groups=self.groups,
56
- )
57
- rows_odd, padding_rows = self._compute_padding(input, dim=0)
58
- if rows_odd:
59
- input = F.pad(input, [0, rows_odd])
60
-
61
- return F.conv1d(
62
- input,
63
- self.weight,
64
- self.bias,
65
- self.stride,
66
- padding=(padding_rows // 2),
67
- dilation=self.dilation,
68
- groups=self.groups,
69
- )
70
-
71
-
72
- def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, norm='bn', padding='valid'):
73
- if k is None and s is None:
74
- if not downsample:
75
- k = 3
76
- s = 1
77
- else:
78
- k = 4
79
- s = 2
80
-
81
- if type == '1d':
82
- conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
83
- if norm == 'bn':
84
- norm_block = nn.BatchNorm1d(out_channels)
85
- elif norm == 'ln':
86
- norm_block = nn.LayerNorm(out_channels)
87
- elif type == '2d':
88
- conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
89
- norm_block = nn.BatchNorm2d(out_channels)
90
- else:
91
- assert False
92
-
93
- return nn.Sequential(
94
- conv_block,
95
- norm_block,
96
- nn.LeakyReLU(0.2, True)
97
- )
98
-
99
- class Decoder(nn.Module):
100
- def __init__(self, in_ch, out_ch):
101
- super(Decoder, self).__init__()
102
- self.up1 = nn.Sequential(
103
- ConvNormRelu(in_ch // 2 + in_ch, in_ch // 2),
104
- ConvNormRelu(in_ch // 2, in_ch // 2),
105
- nn.Upsample(scale_factor=2, mode='nearest')
106
- )
107
- self.up2 = nn.Sequential(
108
- ConvNormRelu(in_ch // 4 + in_ch // 2, in_ch // 4),
109
- ConvNormRelu(in_ch // 4, in_ch // 4),
110
- nn.Upsample(scale_factor=2, mode='nearest')
111
- )
112
- self.up3 = nn.Sequential(
113
- ConvNormRelu(in_ch // 8 + in_ch // 4, in_ch // 8),
114
- ConvNormRelu(in_ch // 8, in_ch // 8),
115
- nn.Conv1d(in_ch // 8, out_ch, 1, 1)
116
- )
117
-
118
- def forward(self, x, x1, x2, x3):
119
- x = F.interpolate(x, x3.shape[2])
120
- x = torch.cat([x, x3], dim=1)
121
- x = self.up1(x)
122
- x = F.interpolate(x, x2.shape[2])
123
- x = torch.cat([x, x2], dim=1)
124
- x = self.up2(x)
125
- x = F.interpolate(x, x1.shape[2])
126
- x = torch.cat([x, x1], dim=1)
127
- x = self.up3(x)
128
- return x
129
-
130
-
131
- class EncoderDecoder(nn.Module):
132
- def __init__(self, n_frames, each_dim):
133
- super().__init__()
134
- self.n_frames = n_frames
135
-
136
- self.down1 = nn.Sequential(
137
- ConvNormRelu(64, 64, '1d', False),
138
- ConvNormRelu(64, 128, '1d', False),
139
- )
140
- self.down2 = nn.Sequential(
141
- ConvNormRelu(128, 128, '1d', False),
142
- ConvNormRelu(128, 256, '1d', False),
143
- )
144
- self.down3 = nn.Sequential(
145
- ConvNormRelu(256, 256, '1d', False),
146
- ConvNormRelu(256, 512, '1d', False),
147
- )
148
- self.down4 = nn.Sequential(
149
- ConvNormRelu(512, 512, '1d', False),
150
- ConvNormRelu(512, 1024, '1d', False),
151
- )
152
-
153
- self.down = nn.MaxPool1d(kernel_size=2)
154
- self.up = nn.Upsample(scale_factor=2, mode='nearest')
155
-
156
- self.face_decoder = Decoder(1024, each_dim[0] + each_dim[3])
157
- self.body_decoder = Decoder(1024, each_dim[1])
158
- self.hand_decoder = Decoder(1024, each_dim[2])
159
-
160
- def forward(self, spectrogram, time_steps=None):
161
- if time_steps is None:
162
- time_steps = self.n_frames
163
-
164
- x1 = self.down1(spectrogram)
165
- x = self.down(x1)
166
- x2 = self.down2(x)
167
- x = self.down(x2)
168
- x3 = self.down3(x)
169
- x = self.down(x3)
170
- x = self.down4(x)
171
- x = self.up(x)
172
-
173
- face = self.face_decoder(x, x1, x2, x3)
174
- body = self.body_decoder(x, x1, x2, x3)
175
- hand = self.hand_decoder(x, x1, x2, x3)
176
-
177
- return face, body, hand
178
-
179
-
180
- class Generator(nn.Module):
181
- def __init__(self,
182
- each_dim,
183
- training=False,
184
- device=None
185
- ):
186
- super().__init__()
187
-
188
- self.training = training
189
- self.device = device
190
-
191
- self.encoderdecoder = EncoderDecoder(15, each_dim)
192
-
193
- def forward(self, in_spec, time_steps=None):
194
- if time_steps is not None:
195
- self.gen_length = time_steps
196
-
197
- face, body, hand = self.encoderdecoder(in_spec)
198
- out = torch.cat([face, body, hand], dim=1)
199
- out = out.transpose(1, 2)
200
-
201
- return out
202
-
203
-
204
- class Discriminator(nn.Module):
205
- def __init__(self, input_dim):
206
- super().__init__()
207
- self.net = nn.Sequential(
208
- ConvNormRelu(input_dim, 128, '1d'),
209
- ConvNormRelu(128, 256, '1d'),
210
- nn.MaxPool1d(kernel_size=2),
211
- ConvNormRelu(256, 256, '1d'),
212
- ConvNormRelu(256, 512, '1d'),
213
- nn.MaxPool1d(kernel_size=2),
214
- ConvNormRelu(512, 512, '1d'),
215
- ConvNormRelu(512, 1024, '1d'),
216
- nn.MaxPool1d(kernel_size=2),
217
- nn.Conv1d(1024, 1, 1, 1),
218
- nn.Sigmoid()
219
- )
220
-
221
- def forward(self, x):
222
- x = x.transpose(1, 2)
223
-
224
- out = self.net(x)
225
- return out
226
-
227
-
228
- class TrainWrapper(TrainWrapperBaseClass):
229
- def __init__(self, args, config) -> None:
230
- self.args = args
231
- self.config = config
232
- self.device = torch.device(self.args.gpu)
233
- self.global_step = 0
234
- self.convert_to_6d = self.config.Data.pose.convert_to_6d
235
- self.init_params()
236
-
237
- self.generator = Generator(
238
- each_dim=self.each_dim,
239
- training=not self.args.infer,
240
- device=self.device,
241
- ).to(self.device)
242
- self.discriminator = Discriminator(
243
- input_dim=self.each_dim[1] + self.each_dim[2] + 64
244
- ).to(self.device)
245
- if self.convert_to_6d:
246
- self.c_index = c_index_6d
247
- else:
248
- self.c_index = c_index_3d
249
- self.MSELoss = KeypointLoss().to(self.device)
250
- self.L1Loss = L1Loss().to(self.device)
251
- super().__init__(args, config)
252
-
253
- def init_params(self):
254
- scale = 1
255
-
256
- global_orient = round(0 * scale)
257
- leye_pose = reye_pose = round(0 * scale)
258
- jaw_pose = round(3 * scale)
259
- body_pose = round((63 - 24) * scale)
260
- left_hand_pose = right_hand_pose = round(45 * scale)
261
-
262
- expression = 100
263
-
264
- b_j = 0
265
- jaw_dim = jaw_pose
266
- b_e = b_j + jaw_dim
267
- eye_dim = leye_pose + reye_pose
268
- b_b = b_e + eye_dim
269
- body_dim = global_orient + body_pose
270
- b_h = b_b + body_dim
271
- hand_dim = left_hand_pose + right_hand_pose
272
- b_f = b_h + hand_dim
273
- face_dim = expression
274
-
275
- self.dim_list = [b_j, b_e, b_b, b_h, b_f]
276
- self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
277
- self.pose = int(self.full_dim / round(3 * scale))
278
- self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
279
-
280
- def __call__(self, bat):
281
- assert (not self.args.infer), "infer mode"
282
- self.global_step += 1
283
-
284
- loss_dict = {}
285
-
286
- aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
287
- expression = bat['expression'].to(self.device).to(torch.float32)
288
- jaw = poses[:, :3, :]
289
- poses = poses[:, self.c_index, :]
290
-
291
- pred = self.generator(in_spec=aud)
292
-
293
- D_loss, D_loss_dict = self.get_loss(
294
- pred_poses=pred.detach(),
295
- gt_poses=poses,
296
- aud=aud,
297
- mode='training_D',
298
- )
299
-
300
- self.discriminator_optimizer.zero_grad()
301
- D_loss.backward()
302
- self.discriminator_optimizer.step()
303
-
304
- G_loss, G_loss_dict = self.get_loss(
305
- pred_poses=pred,
306
- gt_poses=poses,
307
- aud=aud,
308
- expression=expression,
309
- jaw=jaw,
310
- mode='training_G',
311
- )
312
- self.generator_optimizer.zero_grad()
313
- G_loss.backward()
314
- self.generator_optimizer.step()
315
-
316
- total_loss = None
317
- loss_dict = {}
318
- for key in list(D_loss_dict.keys()) + list(G_loss_dict.keys()):
319
- loss_dict[key] = G_loss_dict.get(key, 0) + D_loss_dict.get(key, 0)
320
-
321
- return total_loss, loss_dict
322
-
323
- def get_loss(self,
324
- pred_poses,
325
- gt_poses,
326
- aud=None,
327
- jaw=None,
328
- expression=None,
329
- mode='training_G',
330
- ):
331
- loss_dict = {}
332
- aud = aud.transpose(1, 2)
333
- gt_poses = gt_poses.transpose(1, 2)
334
- gt_aud = torch.cat([gt_poses, aud], dim=2)
335
- pred_aud = torch.cat([pred_poses[:, :, 103:], aud], dim=2)
336
-
337
- if mode == 'training_D':
338
- dis_real = self.discriminator(gt_aud)
339
- dis_fake = self.discriminator(pred_aud)
340
- dis_error = self.MSELoss(torch.ones_like(dis_real).to(self.device), dis_real) + self.MSELoss(
341
- torch.zeros_like(dis_fake).to(self.device), dis_fake)
342
- loss_dict['dis'] = dis_error
343
-
344
- return dis_error, loss_dict
345
- elif mode == 'training_G':
346
- jaw_loss = self.L1Loss(pred_poses[:, :, :3], jaw.transpose(1, 2))
347
- face_loss = self.MSELoss(pred_poses[:, :, 3:103], expression.transpose(1, 2))
348
- body_loss = self.L1Loss(pred_poses[:, :, 103:142], gt_poses[:, :, :39])
349
- hand_loss = self.L1Loss(pred_poses[:, :, 142:], gt_poses[:, :, 39:])
350
- l1_loss = jaw_loss + face_loss + body_loss + hand_loss
351
-
352
- dis_output = self.discriminator(pred_aud)
353
- gen_error = self.MSELoss(torch.ones_like(dis_output).to(self.device), dis_output)
354
- gen_loss = self.config.Train.weights.keypoint_loss_weight * l1_loss + self.config.Train.weights.gan_loss_weight * gen_error
355
-
356
- loss_dict['gen'] = gen_error
357
- loss_dict['jaw_loss'] = jaw_loss
358
- loss_dict['face_loss'] = face_loss
359
- loss_dict['body_loss'] = body_loss
360
- loss_dict['hand_loss'] = hand_loss
361
- return gen_loss, loss_dict
362
- else:
363
- raise ValueError(mode)
364
-
365
- def infer_on_audio(self, aud_fn, fps=30, initial_pose=None, norm_stats=None, id=None, B=1, **kwargs):
366
- output = []
367
- assert self.args.infer, "train mode"
368
- self.generator.eval()
369
-
370
- if self.config.Data.pose.normalization:
371
- assert norm_stats is not None
372
- data_mean = norm_stats[0]
373
- data_std = norm_stats[1]
374
-
375
- pre_length = self.config.Data.pose.pre_pose_length
376
- generate_length = self.config.Data.pose.generate_length
377
- # assert pre_length == initial_pose.shape[-1]
378
- # pre_poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32)
379
- # B = pre_poses.shape[0]
380
-
381
- aud_feat = get_mfcc_ta(aud_fn, sr=22000, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
382
- num_poses_to_generate = aud_feat.shape[-1]
383
- aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
384
- aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
385
-
386
- with torch.no_grad():
387
- pred_poses = self.generator(aud_feat)
388
- pred_poses = pred_poses.cpu().numpy()
389
- output = pred_poses.squeeze()
390
-
391
- return output
392
-
393
- def generate(self, aud, id):
394
- self.generator.eval()
395
- pred_poses = self.generator(aud)
396
- return pred_poses
397
-
398
-
399
- if __name__ == '__main__':
400
- from trainer.options import parse_args
401
-
402
- parser = parse_args()
403
- args = parser.parse_args(
404
- ['--exp_name', '0', '--data_root', '0', '--speakers', '0', '--pre_pose_length', '4', '--generate_length', '64',
405
- '--infer'])
406
-
407
- generator = TrainWrapper(args)
408
-
409
- aud_fn = '../sample_audio/jon.wav'
410
- initial_pose = torch.randn(64, 108, 4)
411
- norm_stats = (np.random.randn(108), np.random.randn(108))
412
- output = generator.infer_on_audio(aud_fn, initial_pose, norm_stats)
413
-
414
- print(output.shape)
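As a side note, `Conv1d_tf` above reproduces TensorFlow's "same" padding, so the temporal length of its output is always `ceil(T / stride)` regardless of kernel size; a minimal sketch (assuming `nets/LS3DCG.py` and the repo's dependencies are importable) that checks this:

```python
import torch
from nets.LS3DCG import Conv1d_tf

conv = Conv1d_tf(64, 128, kernel_size=4, stride=2)   # no explicit padding -> TF-style "same"
x = torch.randn(2, 64, 75)
y = conv(x)
print(y.shape)   # torch.Size([2, 128, 38]), i.e. ceil(75 / 2) frames
```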
 
nets/__init__.py DELETED
@@ -1,8 +0,0 @@
1
- from .smplx_face import TrainWrapper as s2g_face
2
- from .smplx_body_vq import TrainWrapper as s2g_body_vq
3
- from .smplx_body_pixel import TrainWrapper as s2g_body_pixel
4
- from .body_ae import TrainWrapper as s2g_body_ae
5
- from .LS3DCG import TrainWrapper as LS3DCG
6
- from .base import TrainWrapperBaseClass
7
-
8
- from .utils import normalize, denormalize
 
nets/base.py DELETED
@@ -1,89 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
-
5
- class TrainWrapperBaseClass():
6
- def __init__(self, args, config) -> None:
7
- self.init_optimizer()
8
-
9
- def init_optimizer(self) -> None:
10
- print('using Adam')
11
- self.generator_optimizer = optim.Adam(
12
- self.generator.parameters(),
13
- lr = self.config.Train.learning_rate.generator_learning_rate,
14
- betas=[0.9, 0.999]
15
- )
16
- if self.discriminator is not None:
17
- self.discriminator_optimizer = optim.Adam(
18
- self.discriminator.parameters(),
19
- lr = self.config.Train.learning_rate.discriminator_learning_rate,
20
- betas=[0.9, 0.999]
21
- )
22
-
23
- def __call__(self, bat):
24
- raise NotImplementedError
25
-
26
- def get_loss(self, **kwargs):
27
- raise NotImplementedError
28
-
29
- def state_dict(self):
30
- model_state = {
31
- 'generator': self.generator.state_dict(),
32
- 'generator_optim': self.generator_optimizer.state_dict(),
33
- 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
34
- 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
35
- }
36
- return model_state
37
-
38
- def parameters(self):
39
- return self.generator.parameters()
40
-
41
- def load_state_dict(self, state_dict):
42
- if 'generator' in state_dict:
43
- self.generator.load_state_dict(state_dict['generator'])
44
- else:
45
- self.generator.load_state_dict(state_dict)
46
-
47
- if 'generator_optim' in state_dict and self.generator_optimizer is not None:
48
- self.generator_optimizer.load_state_dict(state_dict['generator_optim'])
49
-
50
- if self.discriminator is not None:
51
- self.discriminator.load_state_dict(state_dict['discriminator'])
52
-
53
- if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None:
54
- self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim'])
55
-
56
- def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, **kwargs):
57
- raise NotImplementedError
58
-
59
- def init_params(self):
60
- if self.config.Data.pose.convert_to_6d:
61
- scale = 2
62
- else:
63
- scale = 1
64
-
65
- global_orient = round(0 * scale)
66
- leye_pose = reye_pose = round(0 * scale)
67
- jaw_pose = round(0 * scale)
68
- body_pose = round((63 - 24) * scale)
69
- left_hand_pose = right_hand_pose = round(45 * scale)
70
- if self.expression:
71
- expression = 100
72
- else:
73
- expression = 0
74
-
75
- b_j = 0
76
- jaw_dim = jaw_pose
77
- b_e = b_j + jaw_dim
78
- eye_dim = leye_pose + reye_pose
79
- b_b = b_e + eye_dim
80
- body_dim = global_orient + body_pose
81
- b_h = b_b + body_dim
82
- hand_dim = left_hand_pose + right_hand_pose
83
- b_f = b_h + hand_dim
84
- face_dim = expression
85
-
86
- self.dim_list = [b_j, b_e, b_b, b_h, b_f]
87
- self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
88
- self.pose = int(self.full_dim / round(3 * scale))
89
- self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
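For orientation, a minimal sketch of the contract this base class expects from its subclasses: build `self.generator` (and optionally `self.discriminator`) and store `args`/`config` before calling `super().__init__`, which then creates the Adam optimizers. The subclass below is hypothetical and for illustration only:

```python
import torch.nn as nn
from nets.base import TrainWrapperBaseClass

class MyWrapper(TrainWrapperBaseClass):
    def __init__(self, args, config):
        self.args = args
        self.config = config                     # must expose Train.learning_rate.* fields
        self.generator = nn.Linear(64, 64)       # stand-in for a real generator network
        self.discriminator = None                # no GAN branch in this sketch
        super().__init__(args, config)           # init_optimizer() builds generator_optimizer

    def __call__(self, bat):
        raise NotImplementedError                # per-batch loss computation goes here
```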
 
nets/body_ae.py DELETED
@@ -1,152 +0,0 @@
1
- import os
2
- import sys
3
-
4
- sys.path.append(os.getcwd())
5
-
6
- from nets.base import TrainWrapperBaseClass
7
- from nets.spg.s2glayers import Discriminator as D_S2G
8
- from nets.spg.vqvae_1d import AE as s2g_body
9
- import torch
10
- import torch.optim as optim
11
- import torch.nn.functional as F
12
-
13
- from data_utils.lower_body import c_index, c_index_3d, c_index_6d
14
-
15
-
16
- def separate_aa(aa):
17
- aa = aa[:, :, :].reshape(aa.shape[0], aa.shape[1], -1, 5)
18
- axis = F.normalize(aa[:, :, :, :3], dim=-1)
19
- angle = F.normalize(aa[:, :, :, 3:5], dim=-1)
20
- return axis, angle
21
-
22
-
23
- class TrainWrapper(TrainWrapperBaseClass):
24
- '''
25
- a wrapper receiving a batch from data_utils and calculating the loss
26
- '''
27
-
28
- def __init__(self, args, config):
29
- self.args = args
30
- self.config = config
31
- self.device = torch.device(self.args.gpu)
32
- self.global_step = 0
33
-
34
- self.gan = False
35
- self.convert_to_6d = self.config.Data.pose.convert_to_6d
36
- self.preleng = self.config.Data.pose.pre_pose_length
37
- self.expression = self.config.Data.pose.expression
38
- self.epoch = 0
39
- self.init_params()
40
- self.num_classes = 4
41
- self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=0,
42
- num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)
43
- if self.gan:
44
- self.discriminator = D_S2G(
45
- pose_dim=110 + 64, pose=self.pose
46
- ).to(self.device)
47
- else:
48
- self.discriminator = None
49
-
50
- if self.convert_to_6d:
51
- self.c_index = c_index_6d
52
- else:
53
- self.c_index = c_index_3d
54
-
55
- super().__init__(args, config)
56
-
57
- def init_optimizer(self):
58
-
59
- self.g_optimizer = optim.Adam(
60
- self.g.parameters(),
61
- lr=self.config.Train.learning_rate.generator_learning_rate,
62
- betas=[0.9, 0.999]
63
- )
64
-
65
- def state_dict(self):
66
- model_state = {
67
- 'g': self.g.state_dict(),
68
- 'g_optim': self.g_optimizer.state_dict(),
69
- 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
70
- 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
71
- }
72
- return model_state
73
-
74
-
75
- def __call__(self, bat):
76
- # assert (not self.args.infer), "infer mode"
77
- self.global_step += 1
78
-
79
- total_loss = None
80
- loss_dict = {}
81
-
82
- aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
83
-
84
- # id = bat['speaker'].to(self.device) - 20
85
- # id = F.one_hot(id, self.num_classes)
86
-
87
- poses = poses[:, self.c_index, :]
88
- gt_poses = poses[:, :, self.preleng:].permute(0, 2, 1)
89
-
90
- loss = 0
91
- loss_dict, loss = self.vq_train(gt_poses[:, :], 'g', self.g, loss_dict, loss)
92
-
93
- return total_loss, loss_dict
94
-
95
- def vq_train(self, gt, name, model, dict, total_loss, pre=None):
96
- x_recon = model(gt_poses=gt, pre_state=pre)
97
- loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, pre=pre)
98
- # total_loss = total_loss + loss
99
-
100
- if name == 'g':
101
- optimizer_name = 'g_optimizer'
102
-
103
- optimizer = getattr(self, optimizer_name)
104
- optimizer.zero_grad()
105
- loss.backward()
106
- optimizer.step()
107
-
108
- for key in list(loss_dict.keys()):
109
- dict[name + key] = loss_dict.get(key, 0).item()
110
- return dict, total_loss
111
-
112
- def get_loss(self,
113
- pred_poses,
114
- gt_poses,
115
- pre=None
116
- ):
117
- loss_dict = {}
118
-
119
-
120
- rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
121
- v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
122
- v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
123
- velocity_loss = torch.mean(torch.abs(v_pr - v_gt))
124
-
125
- if pre is None:
126
- f0_vel = 0
127
- else:
128
- v0_pr = pred_poses[:, 0] - pre[:, -1]
129
- v0_gt = gt_poses[:, 0] - pre[:, -1]
130
- f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))
131
-
132
- gen_loss = rec_loss + velocity_loss + f0_vel
133
-
134
- loss_dict['rec_loss'] = rec_loss
135
- loss_dict['velocity_loss'] = velocity_loss
136
- # loss_dict['e_q_loss'] = e_q_loss
137
- if pre is not None:
138
- loss_dict['f0_vel'] = f0_vel
139
-
140
- return gen_loss, loss_dict
141
-
142
- def load_state_dict(self, state_dict):
143
- self.g.load_state_dict(state_dict['g'])
144
-
145
- def extract(self, x):
146
- self.g.eval()
147
- if x.shape[2] > self.full_dim:
148
- if x.shape[2] == 239:
149
- x = x[:, :, 102:]
150
- x = x[:, :, self.c_index]
151
- feat = self.g.encode(x)
152
- return feat.transpose(1, 2), x
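The objective in `get_loss` above is plain L1 on the poses plus L1 on frame-to-frame velocities (and, when a pre-pose window is given, on the first-frame velocity); a self-contained sketch of the same computation on dummy tensors:

```python
import torch

pred = torch.randn(4, 88, 128)   # (B, T, pose_dim)
gt = torch.randn(4, 88, 128)

rec_loss = (pred - gt).abs().mean()
v_pred = pred[:, 1:] - pred[:, :-1]
v_gt = gt[:, 1:] - gt[:, :-1]
velocity_loss = (v_pred - v_gt).abs().mean()

loss = rec_loss + velocity_loss
print(loss.item())
```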
 
nets/init_model.py DELETED
@@ -1,35 +0,0 @@
1
- from nets import *
2
-
3
-
4
- def init_model(model_name, args, config):
5
-
6
- if model_name == 's2g_face':
7
- generator = s2g_face(
8
- args,
9
- config,
10
- )
11
- elif model_name == 's2g_body_vq':
12
- generator = s2g_body_vq(
13
- args,
14
- config,
15
- )
16
- elif model_name == 's2g_body_pixel':
17
- generator = s2g_body_pixel(
18
- args,
19
- config,
20
- )
21
- elif model_name == 's2g_body_ae':
22
- generator = s2g_body_ae(
23
- args,
24
- config,
25
- )
26
- elif model_name == 's2g_LS3DCG':
27
- generator = LS3DCG(
28
- args,
29
- config,
30
- )
31
- else:
32
- raise ValueError(model_name)
33
- return generator
34
-
35
-
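A usage sketch of the factory above, wrapped in a helper so it stays self-contained; `args` and `config` are assumed to be the parsed command-line options and the loaded experiment config used throughout the repo:

```python
from nets.init_model import init_model

def build_generators(args, config):
    """Build the body/hand and face generators from the same args/config objects."""
    body_generator = init_model('s2g_body_pixel', args, config)
    face_generator = init_model('s2g_face', args, config)
    return body_generator, face_generator
```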
 
nets/layers.py DELETED
@@ -1,1052 +0,0 @@
1
- import os
2
- import sys
3
-
4
- sys.path.append(os.getcwd())
5
-
6
- import torch
7
- import torch.nn as nn
8
- import numpy as np
9
-
10
-
11
- # TODO: be aware of the actual network structures
12
-
13
- def get_log(x):
14
- log = 0
15
- while x > 1:
16
- if x % 2 == 0:
17
- x = x // 2
18
- log += 1
19
- else:
20
- raise ValueError('x is not a power of 2')
21
-
22
- return log
23
-
24
-
25
- class ConvNormRelu(nn.Module):
26
- '''
27
- (B,C_in,H,W) -> (B, C_out, H, W)
28
- some kernel/stride combinations make the output length differ from H/s
29
- # TODO: there might be some problems with the residual path
30
- '''
31
-
32
- def __init__(self,
33
- in_channels,
34
- out_channels,
35
- type='1d',
36
- leaky=False,
37
- downsample=False,
38
- kernel_size=None,
39
- stride=None,
40
- padding=None,
41
- p=0,
42
- groups=1,
43
- residual=False,
44
- norm='bn'):
45
- '''
46
- conv-bn-relu
47
- '''
48
- super(ConvNormRelu, self).__init__()
49
- self.residual = residual
50
- self.norm_type = norm
51
- # kernel_size = k
52
- # stride = s
53
-
54
- if kernel_size is None and stride is None:
55
- if not downsample:
56
- kernel_size = 3
57
- stride = 1
58
- else:
59
- kernel_size = 4
60
- stride = 2
61
-
62
- if padding is None:
63
- if isinstance(kernel_size, int) and isinstance(stride, tuple):
64
- padding = tuple(int((kernel_size - st) / 2) for st in stride)
65
- elif isinstance(kernel_size, tuple) and isinstance(stride, int):
66
- padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
67
- elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
68
- padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
69
- else:
70
- padding = int((kernel_size - stride) / 2)
71
-
72
- if self.residual:
73
- if downsample:
74
- if type == '1d':
75
- self.residual_layer = nn.Sequential(
76
- nn.Conv1d(
77
- in_channels=in_channels,
78
- out_channels=out_channels,
79
- kernel_size=kernel_size,
80
- stride=stride,
81
- padding=padding
82
- )
83
- )
84
- elif type == '2d':
85
- self.residual_layer = nn.Sequential(
86
- nn.Conv2d(
87
- in_channels=in_channels,
88
- out_channels=out_channels,
89
- kernel_size=kernel_size,
90
- stride=stride,
91
- padding=padding
92
- )
93
- )
94
- else:
95
- if in_channels == out_channels:
96
- self.residual_layer = nn.Identity()
97
- else:
98
- if type == '1d':
99
- self.residual_layer = nn.Sequential(
100
- nn.Conv1d(
101
- in_channels=in_channels,
102
- out_channels=out_channels,
103
- kernel_size=kernel_size,
104
- stride=stride,
105
- padding=padding
106
- )
107
- )
108
- elif type == '2d':
109
- self.residual_layer = nn.Sequential(
110
- nn.Conv2d(
111
- in_channels=in_channels,
112
- out_channels=out_channels,
113
- kernel_size=kernel_size,
114
- stride=stride,
115
- padding=padding
116
- )
117
- )
118
-
119
- in_channels = in_channels * groups
120
- out_channels = out_channels * groups
121
- if type == '1d':
122
- self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
123
- kernel_size=kernel_size, stride=stride, padding=padding,
124
- groups=groups)
125
- self.norm = nn.BatchNorm1d(out_channels)
126
- self.dropout = nn.Dropout(p=p)
127
- elif type == '2d':
128
- self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
129
- kernel_size=kernel_size, stride=stride, padding=padding,
130
- groups=groups)
131
- self.norm = nn.BatchNorm2d(out_channels)
132
- self.dropout = nn.Dropout2d(p=p)
133
- if norm == 'gn':
134
- self.norm = nn.GroupNorm(2, out_channels)
135
- elif norm == 'ln':
136
- self.norm = nn.LayerNorm(out_channels)
137
- if leaky:
138
- self.relu = nn.LeakyReLU(negative_slope=0.2)
139
- else:
140
- self.relu = nn.ReLU()
141
-
142
- def forward(self, x, **kwargs):
143
- if self.norm_type == 'ln':
144
- out = self.dropout(self.conv(x))
145
- out = self.norm(out.transpose(1,2)).transpose(1,2)
146
- else:
147
- out = self.norm(self.dropout(self.conv(x)))
148
- if self.residual:
149
- residual = self.residual_layer(x)
150
- out += residual
151
- return self.relu(out)
152
-
153
-
154
- class UNet1D(nn.Module):
155
- def __init__(self,
156
- input_channels,
157
- output_channels,
158
- max_depth=5,
159
- kernel_size=None,
160
- stride=None,
161
- p=0,
162
- groups=1):
163
- super(UNet1D, self).__init__()
164
- self.pre_downsampling_conv = nn.ModuleList([])
165
- self.conv1 = nn.ModuleList([])
166
- self.conv2 = nn.ModuleList([])
167
- self.upconv = nn.Upsample(scale_factor=2, mode='nearest')
168
- self.max_depth = max_depth
169
- self.groups = groups
170
-
171
- self.pre_downsampling_conv.append(ConvNormRelu(input_channels, output_channels,
172
- type='1d', leaky=True, downsample=False,
173
- kernel_size=kernel_size, stride=stride, p=p, groups=groups))
174
- self.pre_downsampling_conv.append(ConvNormRelu(output_channels, output_channels,
175
- type='1d', leaky=True, downsample=False,
176
- kernel_size=kernel_size, stride=stride, p=p, groups=groups))
177
-
178
- for i in range(self.max_depth):
179
- self.conv1.append(ConvNormRelu(output_channels, output_channels,
180
- type='1d', leaky=True, downsample=True,
181
- kernel_size=kernel_size, stride=stride, p=p, groups=groups))
182
-
183
- for i in range(self.max_depth):
184
- self.conv2.append(ConvNormRelu(output_channels, output_channels,
185
- type='1d', leaky=True, downsample=False,
186
- kernel_size=kernel_size, stride=stride, p=p, groups=groups))
187
-
188
- def forward(self, x):
189
-
190
- input_size = x.shape[-1]
191
-
192
- assert get_log(
193
- input_size) >= self.max_depth, 'num_frames must be a power of 2 and its log2 must be at least max_depth'
194
-
195
- x = nn.Sequential(*self.pre_downsampling_conv)(x)
196
-
197
- residuals = []
198
- residuals.append(x)
199
- for i, conv1 in enumerate(self.conv1):
200
- x = conv1(x)
201
- if i < self.max_depth - 1:
202
- residuals.append(x)
203
-
204
- for i, conv2 in enumerate(self.conv2):
205
- x = self.upconv(x) + residuals[self.max_depth - i - 1]
206
- x = conv2(x)
207
-
208
- return x
209
-
210
-
211
- class UNet2D(nn.Module):
212
- def __init__(self):
213
- super(UNet2D, self).__init__()
214
- raise NotImplementedError('2D UNet is weird')
215
-
216
-
217
- class AudioPoseEncoder1D(nn.Module):
218
- '''
219
- (B, C, T) -> (B, C*2, T) -> ... -> (B, C_out, T)
220
- '''
221
-
222
- def __init__(self,
223
- C_in,
224
- C_out,
225
- kernel_size=None,
226
- stride=None,
227
- min_layer_nums=None
228
- ):
229
- super(AudioPoseEncoder1D, self).__init__()
230
- self.C_in = C_in
231
- self.C_out = C_out
232
-
233
- conv_layers = nn.ModuleList([])
234
- cur_C = C_in
235
- num_layers = 0
236
- while cur_C < self.C_out:
237
- conv_layers.append(ConvNormRelu(
238
- in_channels=cur_C,
239
- out_channels=cur_C * 2,
240
- kernel_size=kernel_size,
241
- stride=stride
242
- ))
243
- cur_C *= 2
244
- num_layers += 1
245
-
246
- if (cur_C != C_out) or (min_layer_nums is not None and num_layers < min_layer_nums):
247
- while (cur_C != C_out) or num_layers < min_layer_nums:
248
- conv_layers.append(ConvNormRelu(
249
- in_channels=cur_C,
250
- out_channels=C_out,
251
- kernel_size=kernel_size,
252
- stride=stride
253
- ))
254
- num_layers += 1
255
- cur_C = C_out
256
-
257
- self.conv_layers = nn.Sequential(*conv_layers)
258
-
259
- def forward(self, x):
260
- '''
261
- x: (B, C, T)
262
- '''
263
- x = self.conv_layers(x)
264
- return x
265
-
266
-
267
- class AudioPoseEncoder2D(nn.Module):
268
- '''
269
- (B, C, T) -> (B, 1, C, T) -> ... -> (B, C_out, T)
270
- '''
271
-
272
- def __init__(self):
273
- raise NotImplementedError
274
-
275
-
276
- class AudioPoseEncoderRNN(nn.Module):
277
- '''
278
- (B, C, T)->(B, T, C)->(B, T, C_out)->(B, C_out, T)
279
- '''
280
-
281
- def __init__(self,
282
- C_in,
283
- hidden_size,
284
- num_layers,
285
- rnn_cell='gru',
286
- bidirectional=False
287
- ):
288
- super(AudioPoseEncoderRNN, self).__init__()
289
- if rnn_cell == 'gru':
290
- self.cell = nn.GRU(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
291
- bidirectional=bidirectional)
292
- elif rnn_cell == 'lstm':
293
- self.cell = nn.LSTM(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
294
- bidirectional=bidirectional)
295
- else:
296
- raise ValueError('invalid rnn cell:%s' % (rnn_cell))
297
-
298
- def forward(self, x, state=None):
299
-
300
- x = x.permute(0, 2, 1)
301
- x, state = self.cell(x, state)
302
- x = x.permute(0, 2, 1)
303
-
304
- return x
305
-
306
-
307
- class AudioPoseEncoderGraph(nn.Module):
308
- '''
309
- (B, C, T)->(B, 2, V, T)->(B, 2, T, V)->(B, D, T, V)
310
- '''
311
-
312
- def __init__(self,
313
- layers_config, # expected to be a list of (C_in, C_out, kernel_size) tuples
314
- A, # adjacent matrix (num_parts, V, V)
315
- residual,
316
- local_bn=False,
317
- share_weights=False
318
- ) -> None:
319
- super().__init__()
320
- self.A = A
321
- self.num_joints = A.shape[1]
322
- self.num_parts = A.shape[0]
323
- self.C_in = layers_config[0][0]
324
- self.C_out = layers_config[-1][1]
325
-
326
- self.conv_layers = nn.ModuleList([
327
- GraphConvNormRelu(
328
- C_in=c_in,
329
- C_out=c_out,
330
- A=self.A,
331
- residual=residual,
332
- local_bn=local_bn,
333
- kernel_size=k,
334
- share_weights=share_weights
335
- ) for (c_in, c_out, k) in layers_config
336
- ])
337
-
338
- self.conv_layers = nn.Sequential(*self.conv_layers)
339
-
340
- def forward(self, x):
341
- '''
342
- x: (B, C, T), C should be num_joints*D
343
- output: (B, D, T, V)
344
- '''
345
- B, C, T = x.shape
346
- x = x.view(B, self.num_joints, self.C_in, T) # (B, V, D, T); D is the per-joint feature dim; note V comes before D here
347
- x = x.permute(0, 2, 3, 1) # (B, D, T, V)
348
- assert x.shape[1] == self.C_in
349
-
350
- x_conved = self.conv_layers(x)
351
-
352
- # x_conved = x_conved.permute(0, 3, 1, 2).contiguous().view(B, self.C_out*self.num_joints, T)#(B, V*C_out, T)
353
-
354
- return x_conved
355
-
356
-
357
- class SeqEncoder2D(nn.Module):
358
- '''
359
- seq_encoder, encoding a seq to a vector
360
- (B, C, T)->(B, 2, V, T)->(B, 2, T, V) -> (B, 32, )->...->(B, C_out)
361
- '''
362
-
363
- def __init__(self,
364
- C_in, # should be 2
365
- T_in,
366
- C_out,
367
- num_joints,
368
- min_layer_num=None,
369
- residual=False
370
- ):
371
- super(SeqEncoder2D, self).__init__()
372
- self.C_in = C_in
373
- self.C_out = C_out
374
- self.T_in = T_in
375
- self.num_joints = num_joints
376
-
377
- conv_layers = nn.ModuleList([])
378
- conv_layers.append(ConvNormRelu(
379
- in_channels=C_in,
380
- out_channels=32,
381
- type='2d',
382
- residual=residual
383
- ))
384
-
385
- cur_C = 32
386
- cur_H = T_in
387
- cur_W = num_joints
388
- num_layers = 1
389
- while (cur_C < C_out) or (cur_H > 1) or (cur_W > 1):
390
- ks = [3, 3]
391
- st = [1, 1]
392
-
393
- if cur_H > 1:
394
- if cur_H > 4:
395
- ks[0] = 4
396
- st[0] = 2
397
- else:
398
- ks[0] = cur_H
399
- st[0] = cur_H
400
- if cur_W > 1:
401
- if cur_W > 4:
402
- ks[1] = 4
403
- st[1] = 2
404
- else:
405
- ks[1] = cur_W
406
- st[1] = cur_W
407
-
408
- conv_layers.append(ConvNormRelu(
409
- in_channels=cur_C,
410
- out_channels=min(C_out, cur_C * 2),
411
- type='2d',
412
- kernel_size=tuple(ks),
413
- stride=tuple(st),
414
- residual=residual
415
- ))
416
- cur_C = min(cur_C * 2, C_out)
417
- if cur_H > 1:
418
- if cur_H > 4:
419
- cur_H //= 2
420
- else:
421
- cur_H = 1
422
- if cur_W > 1:
423
- if cur_W > 4:
424
- cur_W //= 2
425
- else:
426
- cur_W = 1
427
- num_layers += 1
428
-
429
- if min_layer_num is not None and (num_layers < min_layer_num):
430
- while num_layers < min_layer_num:
431
- conv_layers.append(ConvNormRelu(
432
- in_channels=C_out,
433
- out_channels=C_out,
434
- type='2d',
435
- kernel_size=1,
436
- stride=1,
437
- residual=residual
438
- ))
439
- num_layers += 1
440
-
441
- self.conv_layers = nn.Sequential(*conv_layers)
442
- self.num_layers = num_layers
443
-
444
- def forward(self, x):
445
- B, C, T = x.shape
446
- x = x.view(B, self.num_joints, self.C_in, T) # (B, V, D, T) V in front
447
- x = x.permute(0, 2, 3, 1) # (B, D, T, V)
448
- assert x.shape[1] == self.C_in and x.shape[-1] == self.num_joints
449
-
450
- x = self.conv_layers(x)
451
- return x.squeeze()
452
-
453
-
454
- class SeqEncoder1D(nn.Module):
455
- '''
456
- (B, C, T)->(B, D)
457
- '''
458
-
459
- def __init__(self,
460
- C_in,
461
- C_out,
462
- T_in,
463
- min_layer_nums=None
464
- ):
465
- super(SeqEncoder1D, self).__init__()
466
- conv_layers = nn.ModuleList([])
467
- cur_C = C_in
468
- cur_T = T_in
469
- self.num_layers = 0
470
- while (cur_C < C_out) or (cur_T > 1):
471
- ks = 3
472
- st = 1
473
- if cur_T > 1:
474
- if cur_T > 4:
475
- ks = 4
476
- st = 2
477
- else:
478
- ks = cur_T
479
- st = cur_T
480
-
481
- conv_layers.append(ConvNormRelu(
482
- in_channels=cur_C,
483
- out_channels=min(C_out, cur_C * 2),
484
- type='1d',
485
- kernel_size=ks,
486
- stride=st
487
- ))
488
- cur_C = min(cur_C * 2, C_out)
489
- if cur_T > 1:
490
- if cur_T > 4:
491
- cur_T = cur_T // 2
492
- else:
493
- cur_T = 1
494
- self.num_layers += 1
495
-
496
- if min_layer_nums is not None and (self.num_layers < min_layer_nums):
497
- while self.num_layers < min_layer_nums:
498
- conv_layers.append(ConvNormRelu(
499
- in_channels=C_out,
500
- out_channels=C_out,
501
- type='1d',
502
- kernel_size=1,
503
- stride=1
504
- ))
505
- self.num_layers += 1
506
- self.conv_layers = nn.Sequential(*conv_layers)
507
-
508
- def forward(self, x):
509
- x = self.conv_layers(x)
510
- return x.squeeze()
511
-
512
-
513
- class SeqEncoderRNN(nn.Module):
514
- '''
515
- (B, C, T) -> (B, T, C) -> (B, D)
516
- LSTM/GRU-FC
517
- '''
518
-
519
- def __init__(self,
520
- hidden_size,
521
- in_size,
522
- num_rnn_layers,
523
- rnn_cell='gru',
524
- bidirectional=False
525
- ):
526
- super(SeqEncoderRNN, self).__init__()
527
- self.hidden_size = hidden_size
528
- self.in_size = in_size
529
- self.num_rnn_layers = num_rnn_layers
530
- self.bidirectional = bidirectional
531
-
532
- if rnn_cell == 'gru':
533
- self.cell = nn.GRU(input_size=self.in_size, hidden_size=self.hidden_size, num_layers=self.num_rnn_layers,
534
- batch_first=True, bidirectional=bidirectional)
535
- elif rnn_cell == 'lstm':
536
- self.cell = nn.LSTM(input_size=self.in_size, hidden_size=self.hidden_size, num_layers=self.num_rnn_layers,
537
- batch_first=True, bidirectional=bidirectional)
538
-
539
- def forward(self, x, state=None):
540
-
541
- x = x.permute(0, 2, 1)
542
- B, T, C = x.shape
543
- x, _ = self.cell(x, state)
544
- if self.bidirectional:
545
- out = torch.cat([x[:, -1, :self.hidden_size], x[:, 0, self.hidden_size:]], dim=-1)
546
- else:
547
- out = x[:, -1, :]
548
- assert out.shape[0] == B
549
- return out
550
-
551
-
552
- class SeqEncoderGraph(nn.Module):
553
- '''
554
- '''
555
-
556
- def __init__(self,
557
- embedding_size,
558
- layer_configs,
559
- residual,
560
- local_bn,
561
- A,
562
- T,
563
- share_weights=False
564
- ) -> None:
565
- super().__init__()
566
-
567
- self.C_in = layer_configs[0][0]
568
- self.C_out = embedding_size
569
-
570
- self.num_joints = A.shape[1]
571
-
572
- self.graph_encoder = AudioPoseEncoderGraph(
573
- layers_config=layer_configs,
574
- A=A,
575
- residual=residual,
576
- local_bn=local_bn,
577
- share_weights=share_weights
578
- )
579
-
580
- cur_C = layer_configs[-1][1]
581
- self.spatial_pool = ConvNormRelu(
582
- in_channels=cur_C,
583
- out_channels=cur_C,
584
- type='2d',
585
- kernel_size=(1, self.num_joints),
586
- stride=(1, 1),
587
- padding=(0, 0)
588
- )
589
-
590
- temporal_pool = nn.ModuleList([])
591
- cur_H = T
592
- num_layers = 0
593
- self.temporal_conv_info = []
594
- while cur_C < self.C_out or cur_H > 1:
595
- self.temporal_conv_info.append(cur_C)
596
- ks = [3, 1]
597
- st = [1, 1]
598
-
599
- if cur_H > 1:
600
- if cur_H > 4:
601
- ks[0] = 4
602
- st[0] = 2
603
- else:
604
- ks[0] = cur_H
605
- st[0] = cur_H
606
-
607
- temporal_pool.append(ConvNormRelu(
608
- in_channels=cur_C,
609
- out_channels=min(self.C_out, cur_C * 2),
610
- type='2d',
611
- kernel_size=tuple(ks),
612
- stride=tuple(st)
613
- ))
614
- cur_C = min(cur_C * 2, self.C_out)
615
-
616
- if cur_H > 1:
617
- if cur_H > 4:
618
- cur_H //= 2
619
- else:
620
- cur_H = 1
621
-
622
- num_layers += 1
623
-
624
- self.temporal_pool = nn.Sequential(*temporal_pool)
625
- print("graph seq encoder info: temporal pool:", self.temporal_conv_info)
626
- self.num_layers = num_layers
627
- # need fc?
628
-
629
- def forward(self, x):
630
- '''
631
- x: (B, C, T)
632
- '''
633
- B, C, T = x.shape
634
- x = self.graph_encoder(x)
635
- x = self.spatial_pool(x)
636
- x = self.temporal_pool(x)
637
- x = x.view(B, self.C_out)
638
-
639
- return x
640
-
641
-
642
- class SeqDecoder2D(nn.Module):
643
- '''
644
- (B, D)->(B, D, 1, 1)->(B, C_out, C, T)->(B, C_out, T)
645
- '''
646
-
647
- def __init__(self):
648
- super(SeqDecoder2D, self).__init__()
649
- raise NotImplementedError
650
-
651
-
652
- class SeqDecoder1D(nn.Module):
653
- '''
654
- (B, D)->(B, D, 1)->...->(B, C_out, T)
655
- '''
656
-
657
- def __init__(self,
658
- D_in,
659
- C_out,
660
- T_out,
661
- min_layer_num=None
662
- ):
663
- super(SeqDecoder1D, self).__init__()
664
- self.T_out = T_out
665
- self.min_layer_num = min_layer_num
666
-
667
- cur_t = 1
668
-
669
- self.pre_conv = ConvNormRelu(
670
- in_channels=D_in,
671
- out_channels=C_out,
672
- type='1d'
673
- )
674
- self.num_layers = 1
675
- self.upconv = nn.Upsample(scale_factor=2, mode='nearest')
676
- self.conv_layers = nn.ModuleList([])
677
- cur_t *= 2
678
- while cur_t <= T_out:
679
- self.conv_layers.append(ConvNormRelu(
680
- in_channels=C_out,
681
- out_channels=C_out,
682
- type='1d'
683
- ))
684
- cur_t *= 2
685
- self.num_layers += 1
686
-
687
- post_conv = nn.ModuleList([ConvNormRelu(
688
- in_channels=C_out,
689
- out_channels=C_out,
690
- type='1d'
691
- )])
692
- self.num_layers += 1
693
- if min_layer_num is not None and self.num_layers < min_layer_num:
694
- while self.num_layers < min_layer_num:
695
- post_conv.append(ConvNormRelu(
696
- in_channels=C_out,
697
- out_channels=C_out,
698
- type='1d'
699
- ))
700
- self.num_layers += 1
701
- self.post_conv = nn.Sequential(*post_conv)
702
-
703
- def forward(self, x):
704
-
705
- x = x.unsqueeze(-1)
706
- x = self.pre_conv(x)
707
- for conv in self.conv_layers:
708
- x = self.upconv(x)
709
- x = conv(x)
710
-
711
- x = torch.nn.functional.interpolate(x, size=self.T_out, mode='nearest')
712
- x = self.post_conv(x)
713
- return x
714
-
715
-
716
- class SeqDecoderRNN(nn.Module):
717
- '''
718
- (B, D)->(B, C_out, T)
719
- '''
720
-
721
- def __init__(self,
722
- hidden_size,
723
- C_out,
724
- T_out,
725
- num_layers,
726
- rnn_cell='gru'
727
- ):
728
- super(SeqDecoderRNN, self).__init__()
729
- self.num_steps = T_out
730
- if rnn_cell == 'gru':
731
- self.cell = nn.GRU(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
732
- bidirectional=False)
733
- elif rnn_cell == 'lstm':
734
- self.cell = nn.LSTM(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
735
- bidirectional=False)
736
- else:
737
- raise ValueError('invalid rnn cell:%s' % (rnn_cell))
738
-
739
- self.fc = nn.Linear(hidden_size, C_out)
740
-
741
- def forward(self, hidden, frame_0):
742
- frame_0 = frame_0.permute(0, 2, 1)
743
- dec_input = frame_0
744
- outputs = []
745
- for i in range(self.num_steps):
746
- frame_out, hidden = self.cell(dec_input, hidden)
747
- frame_out = self.fc(frame_out)
748
- dec_input = frame_out
749
- outputs.append(frame_out)
750
- output = torch.cat(outputs, dim=1)
751
- return output.permute(0, 2, 1)
752
-
753
-
754
- class SeqTranslator2D(nn.Module):
755
- '''
756
- (B, C, T)->(B, 1, C, T)-> ... -> (B, 1, C_out, T_out)
757
- '''
758
-
759
- def __init__(self,
760
- C_in=64,
761
- C_out=108,
762
- T_in=75,
763
- T_out=25,
764
- residual=True
765
- ):
766
- super(SeqTranslator2D, self).__init__()
767
- print("Warning: hard coded")
768
- self.C_in = C_in
769
- self.C_out = C_out
770
- self.T_in = T_in
771
- self.T_out = T_out
772
- self.residual = residual
773
-
774
- self.conv_layers = nn.Sequential(
775
- ConvNormRelu(1, 32, '2d', kernel_size=5, stride=1),
776
- ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual),
777
- ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual),
778
-
779
- ConvNormRelu(32, 64, '2d', kernel_size=5, stride=(4, 3)),
780
- ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual),
781
- ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual),
782
-
783
- ConvNormRelu(64, 128, '2d', kernel_size=5, stride=(4, 1)),
784
- ConvNormRelu(128, 108, '2d', kernel_size=3, stride=(4, 1)),
785
- ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual),
786
-
787
- ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual),
788
- ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1),
789
- )
790
-
791
- def forward(self, x):
792
- assert len(x.shape) == 3 and x.shape[1] == self.C_in and x.shape[2] == self.T_in
793
- x = x.view(x.shape[0], 1, x.shape[1], x.shape[2])
794
- x = self.conv_layers(x)
795
- x = x.squeeze(2)
796
- return x
797
-
798
-
799
- class SeqTranslator1D(nn.Module):
800
- '''
801
- (B, C, T)->(B, C_out, T)
802
- '''
803
-
804
- def __init__(self,
805
- C_in,
806
- C_out,
807
- kernel_size=None,
808
- stride=None,
809
- min_layers_num=None,
810
- residual=True,
811
- norm='bn'
812
- ):
813
- super(SeqTranslator1D, self).__init__()
814
-
815
- conv_layers = nn.ModuleList([])
816
- conv_layers.append(ConvNormRelu(
817
- in_channels=C_in,
818
- out_channels=C_out,
819
- type='1d',
820
- kernel_size=kernel_size,
821
- stride=stride,
822
- residual=residual,
823
- norm=norm
824
- ))
825
- self.num_layers = 1
826
- if min_layers_num is not None and self.num_layers < min_layers_num:
827
- while self.num_layers < min_layers_num:
828
- conv_layers.append(ConvNormRelu(
829
- in_channels=C_out,
830
- out_channels=C_out,
831
- type='1d',
832
- kernel_size=kernel_size,
833
- stride=stride,
834
- residual=residual,
835
- norm=norm
836
- ))
837
- self.num_layers += 1
838
- self.conv_layers = nn.Sequential(*conv_layers)
839
-
840
- def forward(self, x):
841
- return self.conv_layers(x)
842
-
843
-
844
- class SeqTranslatorRNN(nn.Module):
845
- '''
846
- (B, C, T)->(B, C_out, T)
847
- LSTM-FC
848
- '''
849
-
850
- def __init__(self,
851
- C_in,
852
- C_out,
853
- hidden_size,
854
- num_layers,
855
- rnn_cell='gru'
856
- ):
857
- super(SeqTranslatorRNN, self).__init__()
858
-
859
- if rnn_cell == 'gru':
860
- self.enc_cell = nn.GRU(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
861
- bidirectional=False)
862
- self.dec_cell = nn.GRU(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
863
- bidirectional=False)
864
- elif rnn_cell == 'lstm':
865
- self.enc_cell = nn.LSTM(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
866
- bidirectional=False)
867
- self.dec_cell = nn.LSTM(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
868
- bidirectional=False)
869
- else:
870
- raise ValueError('invalid rnn cell:%s' % (rnn_cell))
871
-
872
- self.fc = nn.Linear(hidden_size, C_out)
873
-
874
- def forward(self, x, frame_0):
875
-
876
- num_steps = x.shape[-1]
877
- x = x.permute(0, 2, 1)
878
- frame_0 = frame_0.permute(0, 2, 1)
879
- _, hidden = self.enc_cell(x, None)
880
-
881
- outputs = []
882
- for i in range(num_steps):
883
- inputs = frame_0
884
- output_frame, hidden = self.dec_cell(inputs, hidden)
885
- output_frame = self.fc(output_frame)
886
- frame_0 = output_frame
887
- outputs.append(output_frame)
888
- outputs = torch.cat(outputs, dim=1)
889
- return outputs.permute(0, 2, 1)
890
-
891
-
892
- class ResBlock(nn.Module):
893
- def __init__(self,
894
- input_dim,
895
- fc_dim,
896
- afn,
897
- nfn
898
- ):
899
- '''
900
- afn: activation fn
901
- nfn: normalization fn
902
- '''
903
- super(ResBlock, self).__init__()
904
-
905
- self.input_dim = input_dim
906
- self.fc_dim = fc_dim
907
- self.afn = afn
908
- self.nfn = nfn
909
-
910
- if self.afn != 'relu':
911
- raise ValueError('only relu activation is supported')
912
-
913
- if self.nfn == 'layer_norm':
914
- raise ValueError('layer_norm is not supported here')
915
-
916
- self.layers = nn.Sequential(
917
- nn.Linear(self.input_dim, self.fc_dim // 2),
918
- nn.ReLU(),
919
- nn.Linear(self.fc_dim // 2, self.fc_dim // 2),
920
- nn.ReLU(),
921
- nn.Linear(self.fc_dim // 2, self.fc_dim),
922
- nn.ReLU()
923
- )
924
-
925
- self.shortcut_layer = nn.Sequential(
926
- nn.Linear(self.input_dim, self.fc_dim),
927
- nn.ReLU(),
928
- )
929
-
930
- def forward(self, inputs):
931
- return self.layers(inputs) + self.shortcut_layer(inputs)
932
-
933
-
934
- class AudioEncoder(nn.Module):
935
- def __init__(self, channels, padding=3, kernel_size=8, conv_stride=2, conv_pool=None, augmentation=False):
936
- super(AudioEncoder, self).__init__()
937
- self.in_channels = channels[0]
938
- self.augmentation = augmentation
939
-
940
- model = []
941
- acti = nn.LeakyReLU(0.2)
942
-
943
- nr_layer = len(channels) - 1
944
-
945
- for i in range(nr_layer):
946
- if conv_pool is None:
947
- model.append(nn.ReflectionPad1d(padding))
948
- model.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=kernel_size, stride=conv_stride))
949
- model.append(acti)
950
- else:
951
- model.append(nn.ReflectionPad1d(padding))
952
- model.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=kernel_size, stride=conv_stride))
953
- model.append(acti)
954
- model.append(conv_pool(kernel_size=2, stride=2))
955
-
956
- if self.augmentation:
957
- model.append(
958
- nn.Conv1d(channels[-1], channels[-1], kernel_size=kernel_size, stride=conv_stride)
959
- )
960
- model.append(acti)
961
-
962
- self.model = nn.Sequential(*model)
963
-
964
- def forward(self, x):
965
-
966
- x = x[:, :self.in_channels, :]
967
- x = self.model(x)
968
- return x
969
-
970
-
971
- class AudioDecoder(nn.Module):
972
- def __init__(self, channels, kernel_size=7, ups=25):
973
- super(AudioDecoder, self).__init__()
974
-
975
- model = []
976
- pad = (kernel_size - 1) // 2
977
- acti = nn.LeakyReLU(0.2)
978
-
979
- for i in range(len(channels) - 2):
980
- model.append(nn.Upsample(scale_factor=2, mode='nearest'))
981
- model.append(nn.ReflectionPad1d(pad))
982
- model.append(nn.Conv1d(channels[i], channels[i + 1],
983
- kernel_size=kernel_size, stride=1))
984
- if i == 0 or i == 1:
985
- model.append(nn.Dropout(p=0.2))
986
- if not i == len(channels) - 2:
987
- model.append(acti)
988
-
989
- model.append(nn.Upsample(size=ups, mode='nearest'))
990
- model.append(nn.ReflectionPad1d(pad))
991
- model.append(nn.Conv1d(channels[-2], channels[-1],
992
- kernel_size=kernel_size, stride=1))
993
-
994
- self.model = nn.Sequential(*model)
995
-
996
- def forward(self, x):
997
- return self.model(x)
998
-
999
-
1000
- class Audio2Pose(nn.Module):
1001
- def __init__(self, pose_dim, embed_size, augmentation, ups=25):
1002
- super(Audio2Pose, self).__init__()
1003
- self.pose_dim = pose_dim
1004
- self.embed_size = embed_size
1005
- self.augmentation = augmentation
1006
-
1007
- self.aud_enc = AudioEncoder(channels=[13, 64, 128, 256], padding=2, kernel_size=7, conv_stride=1,
1008
- conv_pool=nn.AvgPool1d, augmentation=self.augmentation)
1009
- if self.augmentation:
1010
- self.aud_dec = AudioDecoder(channels=[512, 256, 128, pose_dim])
1011
- else:
1012
- self.aud_dec = AudioDecoder(channels=[256, 256, 128, pose_dim], ups=ups)
1013
-
1014
- if self.augmentation:
1015
- self.pose_enc = nn.Sequential(
1016
- nn.Linear(self.embed_size // 2, 256),
1017
- nn.LayerNorm(256)
1018
- )
1019
-
1020
- def forward(self, audio_feat, dec_input=None):
1021
-
1022
- B = audio_feat.shape[0]
1023
-
1024
- aud_embed = self.aud_enc.forward(audio_feat)
1025
-
1026
- if self.augmentation:
1027
- dec_input = dec_input.squeeze(0)
1028
- dec_embed = self.pose_enc(dec_input)
1029
- dec_embed = dec_embed.unsqueeze(2)
1030
- dec_embed = dec_embed.expand(dec_embed.shape[0], dec_embed.shape[1], aud_embed.shape[-1])
1031
- aud_embed = torch.cat([aud_embed, dec_embed], dim=1)
1032
-
1033
- out = self.aud_dec.forward(aud_embed)
1034
- return out
1035
-
1036
-
1037
- if __name__ == '__main__':
1038
- import numpy as np
1039
- import os
1040
- import sys
1041
-
1042
- test_model = SeqEncoder2D(
1043
- C_in=2,
1044
- T_in=25,
1045
- C_out=512,
1046
- num_joints=54,
1047
- )
1048
- print(test_model.num_layers)
1049
-
1050
- input = torch.randn((64, 108, 25))
1051
- output = test_model(input)
1052
- print(output.shape)
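The `Audio2Pose` module above simply chains the convolutional `AudioEncoder` with the upsampling `AudioDecoder`, so a 13-channel MFCC sequence comes out as a fixed-length pose sequence. A minimal smoke test, assuming the classes in this file are importable; `pose_dim=108` and the 100-frame input are illustrative values, not taken from the repo:

```python
import torch

# Sketch only: pose_dim and the input length are assumptions for the example.
model = Audio2Pose(pose_dim=108, embed_size=512, augmentation=False, ups=25)
mfcc = torch.randn(2, 13, 100)   # (batch, MFCC channels, audio frames)
poses = model(mfcc)              # AudioDecoder upsamples to a fixed ups=25 frames
print(poses.shape)               # expected: torch.Size([2, 108, 25])
```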
nets/smplx_body_pixel.py DELETED
@@ -1,326 +0,0 @@
1
- import os
2
- import sys
3
-
4
- import torch
5
- from torch.optim.lr_scheduler import StepLR
6
-
7
- sys.path.append(os.getcwd())
8
-
9
- from nets.layers import *
10
- from nets.base import TrainWrapperBaseClass
11
- from nets.spg.gated_pixelcnn_v2 import GatedPixelCNN as pixelcnn
12
- from nets.spg.vqvae_1d import VQVAE as s2g_body, Wav2VecEncoder
13
- from nets.spg.vqvae_1d import AudioEncoder
14
- from nets.utils import parse_audio, denormalize
15
- from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
16
- import numpy as np
17
- import torch.optim as optim
18
- import torch.nn.functional as F
19
- from sklearn.preprocessing import normalize
20
-
21
- from data_utils.lower_body import c_index, c_index_3d, c_index_6d
22
- from data_utils.utils import smooth_geom, get_mfcc_sepa
23
-
24
-
25
- class TrainWrapper(TrainWrapperBaseClass):
26
- '''
27
- a wrapper receiving a batch from data_utils and calculating the loss
28
- '''
29
-
30
- def __init__(self, args, config):
31
- self.args = args
32
- self.config = config
33
- self.device = torch.device(self.args.gpu)
34
- self.global_step = 0
35
-
36
- self.convert_to_6d = self.config.Data.pose.convert_to_6d
37
- self.expression = self.config.Data.pose.expression
38
- self.epoch = 0
39
- self.init_params()
40
- self.num_classes = 4
41
- self.audio = True
42
- self.composition = self.config.Model.composition
43
- self.bh_model = self.config.Model.bh_model
44
-
45
- if self.audio:
46
- self.audioencoder = AudioEncoder(in_dim=64, num_hiddens=256, num_residual_layers=2, num_residual_hiddens=256).to(self.device)
47
- else:
48
- self.audioencoder = None
49
- if self.convert_to_6d:
50
- dim, layer = 512, 10
51
- else:
52
- dim, layer = 256, 15
53
- self.generator = pixelcnn(2048, dim, layer, self.num_classes, self.audio, self.bh_model).to(self.device)
54
- self.g_body = s2g_body(self.each_dim[1], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
55
- num_residual_layers=2, num_residual_hiddens=512).to(self.device)
56
- self.g_hand = s2g_body(self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
57
- num_residual_layers=2, num_residual_hiddens=512).to(self.device)
58
-
59
- model_path = self.config.Model.vq_path
60
- model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
61
- self.g_body.load_state_dict(model_ckpt['generator']['g_body'])
62
- self.g_hand.load_state_dict(model_ckpt['generator']['g_hand'])
63
-
64
- if torch.cuda.device_count() > 1:
65
- self.g_body = torch.nn.DataParallel(self.g_body, device_ids=[0, 1])
66
- self.g_hand = torch.nn.DataParallel(self.g_hand, device_ids=[0, 1])
67
- self.generator = torch.nn.DataParallel(self.generator, device_ids=[0, 1])
68
- if self.audioencoder is not None:
69
- self.audioencoder = torch.nn.DataParallel(self.audioencoder, device_ids=[0, 1])
70
-
71
- self.discriminator = None
72
- if self.convert_to_6d:
73
- self.c_index = c_index_6d
74
- else:
75
- self.c_index = c_index_3d
76
-
77
- super().__init__(args, config)
78
-
79
- def init_optimizer(self):
80
-
81
- print('using Adam')
82
- self.generator_optimizer = optim.Adam(
83
- self.generator.parameters(),
84
- lr=self.config.Train.learning_rate.generator_learning_rate,
85
- betas=[0.9, 0.999]
86
- )
87
- if self.audioencoder is not None:
88
- opt = self.config.Model.AudioOpt
89
- if opt == 'Adam':
90
- self.audioencoder_optimizer = optim.Adam(
91
- self.audioencoder.parameters(),
92
- lr=self.config.Train.learning_rate.generator_learning_rate,
93
- betas=[0.9, 0.999]
94
- )
95
- else:
96
- print('using SGD')
97
- self.audioencoder_optimizer = optim.SGD(
98
- filter(lambda p: p.requires_grad,self.audioencoder.parameters()),
99
- lr=self.config.Train.learning_rate.generator_learning_rate*10,
100
- momentum=0.9,
101
- nesterov=False,
102
- )
103
-
104
- def state_dict(self):
105
- model_state = {
106
- 'generator': self.generator.state_dict(),
107
- 'generator_optim': self.generator_optimizer.state_dict(),
108
- 'audioencoder': self.audioencoder.state_dict() if self.audio else None,
109
- 'audioencoder_optim': self.audioencoder_optimizer.state_dict() if self.audio else None,
110
- 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
111
- 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
112
- }
113
- return model_state
114
-
115
- def load_state_dict(self, state_dict):
116
-
117
- from collections import OrderedDict
118
- new_state_dict = OrderedDict() # create new OrderedDict that does not contain `module.`
119
- for k, v in state_dict.items():
120
- sub_dict = OrderedDict()
121
- if v is not None:
122
- for k1, v1 in v.items():
123
- name = k1.replace('module.', '')
124
- sub_dict[name] = v1
125
- new_state_dict[k] = sub_dict
126
- state_dict = new_state_dict
127
- if 'generator' in state_dict:
128
- self.generator.load_state_dict(state_dict['generator'])
129
- else:
130
- self.generator.load_state_dict(state_dict)
131
-
132
- if 'generator_optim' in state_dict and self.generator_optimizer is not None:
133
- self.generator_optimizer.load_state_dict(state_dict['generator_optim'])
134
-
135
- if self.discriminator is not None:
136
- self.discriminator.load_state_dict(state_dict['discriminator'])
137
-
138
- if 'discriminator_optim' in state_dict and self.discriminator_optimizer is not None:
139
- self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optim'])
140
-
141
- if 'audioencoder' in state_dict and self.audioencoder is not None:
142
- self.audioencoder.load_state_dict(state_dict['audioencoder'])
143
-
144
- def init_params(self):
145
- if self.config.Data.pose.convert_to_6d:
146
- scale = 2
147
- else:
148
- scale = 1
149
-
150
- global_orient = round(0 * scale)
151
- leye_pose = reye_pose = round(0 * scale)
152
- jaw_pose = round(0 * scale)
153
- body_pose = round((63 - 24) * scale)
154
- left_hand_pose = right_hand_pose = round(45 * scale)
155
- if self.expression:
156
- expression = 100
157
- else:
158
- expression = 0
159
-
160
- b_j = 0
161
- jaw_dim = jaw_pose
162
- b_e = b_j + jaw_dim
163
- eye_dim = leye_pose + reye_pose
164
- b_b = b_e + eye_dim
165
- body_dim = global_orient + body_pose
166
- b_h = b_b + body_dim
167
- hand_dim = left_hand_pose + right_hand_pose
168
- b_f = b_h + hand_dim
169
- face_dim = expression
170
-
171
- self.dim_list = [b_j, b_e, b_b, b_h, b_f]
172
- self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
173
- self.pose = int(self.full_dim / round(3 * scale))
174
- self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
175
-
176
- def __call__(self, bat):
177
- # assert (not self.args.infer), "infer mode"
178
- self.global_step += 1
179
-
180
- total_loss = None
181
- loss_dict = {}
182
-
183
- aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
184
-
185
- id = bat['speaker'].to(self.device) - 20
186
- # id = F.one_hot(id, self.num_classes)
187
-
188
- poses = poses[:, self.c_index, :]
189
-
190
- aud = aud.permute(0, 2, 1)
191
- gt_poses = poses.permute(0, 2, 1)
192
-
193
- with torch.no_grad():
194
- self.g_body.eval()
195
- self.g_hand.eval()
196
- if torch.cuda.device_count() > 1:
197
- _, body_latents = self.g_body.module.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
198
- _, hand_latents = self.g_hand.module.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
199
- else:
200
- _, body_latents = self.g_body.encode(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
201
- _, hand_latents = self.g_hand.encode(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
202
- latents = torch.cat([body_latents.unsqueeze(dim=-1), hand_latents.unsqueeze(dim=-1)], dim=-1)
203
- latents = latents.detach()
204
-
205
- if self.audio:
206
- audio = self.audioencoder(aud[:, :].transpose(1, 2), frame_num=latents.shape[1]*4).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
207
- logits = self.generator(latents[:, :], id, audio)
208
- else:
209
- logits = self.generator(latents, id)
210
- logits = logits.permute(0, 2, 3, 1).contiguous()
211
-
212
- self.generator_optimizer.zero_grad()
213
- if self.audio:
214
- self.audioencoder_optimizer.zero_grad()
215
-
216
- loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), latents.view(-1))
217
- loss.backward()
218
-
219
- grad = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.config.Train.max_gradient_norm)
220
-
221
- if torch.isnan(grad).sum() > 0:
222
- print('Warning: NaN encountered in gradient norm')
223
-
224
- loss_dict['grad'] = grad.item()
225
- loss_dict['ce_loss'] = loss.item()
226
- self.generator_optimizer.step()
227
- if self.audio:
228
- self.audioencoder_optimizer.step()
229
-
230
- return total_loss, loss_dict
231
-
232
- def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, exp=None, var=None, w_pre=False, rand=None,
233
- continuity=False, id=None, fps=15, sr=22000, B=1, am=None, am_sr=None, frame=0,**kwargs):
234
- '''
235
- initial_pose: (B, C, T), normalized
236
- (aud_fn, txgfile) -> generated motion (B, T, C)
237
- '''
238
- output = []
239
-
240
- assert self.args.infer, "train mode"
241
- self.generator.eval()
242
- self.g_body.eval()
243
- self.g_hand.eval()
244
-
245
- if continuity:
246
- aud_feat, gap = get_mfcc_sepa(aud_fn, sr=sr, fps=fps)
247
- else:
248
- aud_feat = get_mfcc_ta(aud_fn, sr=sr, fps=fps, smlpx=True, type='mfcc', am=am)
249
- aud_feat = aud_feat.transpose(1, 0)
250
- aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
251
- aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
252
-
253
- if id is None:
254
- id = torch.tensor([0]).to(self.device)
255
- else:
256
- id = id.repeat(B)
257
-
258
- with torch.no_grad():
259
- aud_feat = aud_feat.permute(0, 2, 1)
260
- if continuity:
261
- self.audioencoder.eval()
262
- pre_pose = {}
263
- pre_pose['b'] = pre_pose['h'] = None
264
- pre_latents, pre_audio, body_0, hand_0 = self.infer(aud_feat[:, :gap], frame, id, B, pre_pose=pre_pose)
265
- pre_pose['b'] = body_0[:, :, -4:].transpose(1,2)
266
- pre_pose['h'] = hand_0[:, :, -4:].transpose(1,2)
267
- _, _, body_1, hand_1 = self.infer(aud_feat[:, gap:], frame, id, B, pre_latents, pre_audio, pre_pose)
268
- body = torch.cat([body_0, body_1], dim=2)
269
- hand = torch.cat([hand_0, hand_1], dim=2)
270
-
271
- else:
272
- if self.audio:
273
- self.audioencoder.eval()
274
- audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
275
- latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=B, aud_feat=audio)
276
- else:
277
- latents = self.generator.generate(id, shape=[aud_feat.shape[1]//4, 2], batch_size=B)
278
-
279
- body_latents = latents[..., 0]
280
- hand_latents = latents[..., 1]
281
-
282
- body, _ = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1], latents=body_latents)
283
- hand, _ = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1], latents=hand_latents)
284
-
285
- pred_poses = torch.cat([body, hand], dim=1).transpose(1,2).cpu().numpy()
286
-
287
- output = pred_poses
288
-
289
- return output
290
-
291
- def infer(self, aud_feat, frame, id, B, pre_latents=None, pre_audio=None, pre_pose=None):
292
- audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
293
- latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=B, aud_feat=audio,
294
- pre_latents=pre_latents, pre_audio=pre_audio)
295
-
296
- body_latents = latents[..., 0]
297
- hand_latents = latents[..., 1]
298
-
299
- body, _ = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1],
300
- latents=body_latents, pre_state=pre_pose['b'])
301
- hand, _ = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1],
302
- latents=hand_latents, pre_state=pre_pose['h'])
303
-
304
- return latents, audio, body, hand
305
-
306
- def generate(self, aud, id, frame_num=0):
307
-
308
- self.generator.eval()
309
- self.g_body.eval()
310
- self.g_hand.eval()
311
- aud_feat = aud.permute(0, 2, 1)
312
- if self.audio:
313
- self.audioencoder.eval()
314
- audio = self.audioencoder(aud_feat.transpose(1, 2), frame_num=frame_num).unsqueeze(dim=-1).repeat(1, 1, 1, 2)
315
- latents = self.generator.generate(id, shape=[audio.shape[2], 2], batch_size=aud.shape[0], aud_feat=audio)
316
- else:
317
- latents = self.generator.generate(id, shape=[aud_feat.shape[1] // 4, 2], batch_size=aud.shape[0])
318
-
319
- body_latents = latents[..., 0]
320
- hand_latents = latents[..., 1]
321
-
322
- body = self.g_body.decode(b=body_latents.shape[0], w=body_latents.shape[1], latents=body_latents)
323
- hand = self.g_hand.decode(b=hand_latents.shape[0], w=hand_latents.shape[1], latents=hand_latents)
324
-
325
- pred_poses = torch.cat([body, hand], dim=1).transpose(1, 2)
326
- return pred_poses
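`TrainWrapper.__call__` above freezes the two pretrained VQ-VAEs, looks up the codebook indices of the ground-truth body and hand motion, and trains the audio-conditioned PixelCNN to predict those indices with cross-entropy. A stripped-down sketch of that objective (2048 codes as in the `pixelcnn(2048, ...)` call above; batch and sequence sizes are illustrative):

```python
import torch
import torch.nn.functional as F

B, T, num_codes = 4, 22, 2048                   # illustrative sizes
latents = torch.randint(num_codes, (B, T, 2))   # body/hand code indices from the VQ-VAEs
logits = torch.randn(B, T, 2, num_codes)        # PixelCNN output after the permute above
loss = F.cross_entropy(logits.reshape(-1, num_codes), latents.reshape(-1))
```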
nets/smplx_body_vq.py DELETED
@@ -1,302 +0,0 @@
1
- import os
2
- import sys
3
-
4
- from torch.optim.lr_scheduler import StepLR
5
-
6
- sys.path.append(os.getcwd())
7
-
8
- from nets.layers import *
9
- from nets.base import TrainWrapperBaseClass
10
- from nets.spg.s2glayers import Generator as G_S2G, Discriminator as D_S2G
11
- from nets.spg.vqvae_1d import VQVAE as s2g_body
12
- from nets.utils import parse_audio, denormalize
13
- from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
14
- import numpy as np
15
- import torch.optim as optim
16
- import torch.nn.functional as F
17
- from sklearn.preprocessing import normalize
18
-
19
- from data_utils.lower_body import c_index, c_index_3d, c_index_6d
20
-
21
-
22
- class TrainWrapper(TrainWrapperBaseClass):
23
- '''
24
- a wrapper receiving a batch from data_utils and calculating the loss
25
- '''
26
-
27
- def __init__(self, args, config):
28
- self.args = args
29
- self.config = config
30
- self.device = torch.device(self.args.gpu)
31
- self.global_step = 0
32
-
33
- self.convert_to_6d = self.config.Data.pose.convert_to_6d
34
- self.expression = self.config.Data.pose.expression
35
- self.epoch = 0
36
- self.init_params()
37
- self.num_classes = 4
38
- self.composition = self.config.Model.composition
39
- if self.composition:
40
- self.g_body = s2g_body(self.each_dim[1], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
41
- num_residual_layers=2, num_residual_hiddens=512).to(self.device)
42
- self.g_hand = s2g_body(self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num, num_hiddens=1024,
43
- num_residual_layers=2, num_residual_hiddens=512).to(self.device)
44
- else:
45
- self.g = s2g_body(self.each_dim[1] + self.each_dim[2], embedding_dim=64, num_embeddings=config.Model.code_num,
46
- num_hiddens=1024, num_residual_layers=2, num_residual_hiddens=512).to(self.device)
47
-
48
- self.discriminator = None
49
-
50
- if self.convert_to_6d:
51
- self.c_index = c_index_6d
52
- else:
53
- self.c_index = c_index_3d
54
-
55
- super().__init__(args, config)
56
-
57
- def init_optimizer(self):
58
- print('using Adam')
59
- if self.composition:
60
- self.g_body_optimizer = optim.Adam(
61
- self.g_body.parameters(),
62
- lr=self.config.Train.learning_rate.generator_learning_rate,
63
- betas=[0.9, 0.999]
64
- )
65
- self.g_hand_optimizer = optim.Adam(
66
- self.g_hand.parameters(),
67
- lr=self.config.Train.learning_rate.generator_learning_rate,
68
- betas=[0.9, 0.999]
69
- )
70
- else:
71
- self.g_optimizer = optim.Adam(
72
- self.g.parameters(),
73
- lr=self.config.Train.learning_rate.generator_learning_rate,
74
- betas=[0.9, 0.999]
75
- )
76
-
77
- def state_dict(self):
78
- if self.composition:
79
- model_state = {
80
- 'g_body': self.g_body.state_dict(),
81
- 'g_body_optim': self.g_body_optimizer.state_dict(),
82
- 'g_hand': self.g_hand.state_dict(),
83
- 'g_hand_optim': self.g_hand_optimizer.state_dict(),
84
- 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
85
- 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
86
- }
87
- else:
88
- model_state = {
89
- 'g': self.g.state_dict(),
90
- 'g_optim': self.g_optimizer.state_dict(),
91
- 'discriminator': self.discriminator.state_dict() if self.discriminator is not None else None,
92
- 'discriminator_optim': self.discriminator_optimizer.state_dict() if self.discriminator is not None else None
93
- }
94
- return model_state
95
-
96
- def init_params(self):
97
- if self.config.Data.pose.convert_to_6d:
98
- scale = 2
99
- else:
100
- scale = 1
101
-
102
- global_orient = round(0 * scale)
103
- leye_pose = reye_pose = round(0 * scale)
104
- jaw_pose = round(0 * scale)
105
- body_pose = round((63 - 24) * scale)
106
- left_hand_pose = right_hand_pose = round(45 * scale)
107
- if self.expression:
108
- expression = 100
109
- else:
110
- expression = 0
111
-
112
- b_j = 0
113
- jaw_dim = jaw_pose
114
- b_e = b_j + jaw_dim
115
- eye_dim = leye_pose + reye_pose
116
- b_b = b_e + eye_dim
117
- body_dim = global_orient + body_pose
118
- b_h = b_b + body_dim
119
- hand_dim = left_hand_pose + right_hand_pose
120
- b_f = b_h + hand_dim
121
- face_dim = expression
122
-
123
- self.dim_list = [b_j, b_e, b_b, b_h, b_f]
124
- self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
125
- self.pose = int(self.full_dim / round(3 * scale))
126
- self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
127
-
128
- def __call__(self, bat):
129
- # assert (not self.args.infer), "infer mode"
130
- self.global_step += 1
131
-
132
- total_loss = None
133
- loss_dict = {}
134
-
135
- aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
136
-
137
- # id = bat['speaker'].to(self.device) - 20
138
- # id = F.one_hot(id, self.num_classes)
139
-
140
- poses = poses[:, self.c_index, :]
141
- gt_poses = poses.permute(0, 2, 1)
142
- b_poses = gt_poses[..., :self.each_dim[1]]
143
- h_poses = gt_poses[..., self.each_dim[1]:]
144
-
145
- if self.composition:
146
- loss = 0
147
- loss_dict, loss = self.vq_train(b_poses[:, :], 'b', self.g_body, loss_dict, loss)
148
- loss_dict, loss = self.vq_train(h_poses[:, :], 'h', self.g_hand, loss_dict, loss)
149
- else:
150
- loss = 0
151
- loss_dict, loss = self.vq_train(gt_poses[:, :], 'g', self.g, loss_dict, loss)
152
-
153
- return total_loss, loss_dict
154
-
155
- def vq_train(self, gt, name, model, dict, total_loss, pre=None):
156
- e_q_loss, x_recon = model(gt_poses=gt, pre_state=pre)
157
- loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, e_q_loss=e_q_loss, pre=pre)
158
- # total_loss = total_loss + loss
159
-
160
- if name == 'b':
161
- optimizer_name = 'g_body_optimizer'
162
- elif name == 'h':
163
- optimizer_name = 'g_hand_optimizer'
164
- elif name == 'g':
165
- optimizer_name = 'g_optimizer'
166
- else:
167
- raise ValueError("model's name must be b or h")
168
- optimizer = getattr(self, optimizer_name)
169
- optimizer.zero_grad()
170
- loss.backward()
171
- optimizer.step()
172
-
173
- for key in list(loss_dict.keys()):
174
- dict[name + key] = loss_dict.get(key, 0).item()
175
- return dict, total_loss
176
-
177
- def get_loss(self,
178
- pred_poses,
179
- gt_poses,
180
- e_q_loss,
181
- pre=None
182
- ):
183
- loss_dict = {}
184
-
185
-
186
- rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
187
- v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
188
- v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
189
- velocity_loss = torch.mean(torch.abs(v_pr - v_gt))
190
-
191
- if pre is None:
192
- f0_vel = 0
193
- else:
194
- v0_pr = pred_poses[:, 0] - pre[:, -1]
195
- v0_gt = gt_poses[:, 0] - pre[:, -1]
196
- f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))
197
-
198
- gen_loss = rec_loss + e_q_loss + velocity_loss + f0_vel
199
-
200
- loss_dict['rec_loss'] = rec_loss
201
- loss_dict['velocity_loss'] = velocity_loss
202
- # loss_dict['e_q_loss'] = e_q_loss
203
- if pre is not None:
204
- loss_dict['f0_vel'] = f0_vel
205
-
206
- return gen_loss, loss_dict
207
-
208
- def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, exp=None, var=None, w_pre=False, continuity=False,
209
- id=None, fps=15, sr=22000, smooth=False, **kwargs):
210
- '''
211
- initial_pose: (B, C, T), normalized
212
- (aud_fn, txgfile) -> generated motion (B, T, C)
213
- '''
214
- output = []
215
-
216
- assert self.args.infer, "train mode"
217
- if self.composition:
218
- self.g_body.eval()
219
- self.g_hand.eval()
220
- else:
221
- self.g.eval()
222
-
223
- if self.config.Data.pose.normalization:
224
- assert norm_stats is not None
225
- data_mean = norm_stats[0]
226
- data_std = norm_stats[1]
227
-
228
- # assert initial_pose.shape[-1] == pre_length
229
- if initial_pose is not None:
230
- gt = initial_pose[:, :, :].to(self.device).to(torch.float32)
231
- pre_poses = initial_pose[:, :, :15].permute(0, 2, 1).to(self.device).to(torch.float32)
232
- poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32)
233
- B = pre_poses.shape[0]
234
- else:
235
- gt = None
236
- pre_poses = None
237
- B = 1
238
-
239
- if type(aud_fn) == torch.Tensor:
240
- aud_feat = torch.tensor(aud_fn, dtype=torch.float32).to(self.device)
241
- num_poses_to_generate = aud_feat.shape[-1]
242
- else:
243
- aud_feat = get_mfcc_ta(aud_fn, sr=sr, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
244
- aud_feat = aud_feat[:, :]
245
- num_poses_to_generate = aud_feat.shape[-1]
246
- aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
247
- aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
248
-
249
- # pre_poses = torch.randn(pre_poses.shape).to(self.device).to(torch.float32)
250
- if id is None:
251
- id = F.one_hot(torch.tensor([[0]]), self.num_classes).to(self.device)
252
-
253
- with torch.no_grad():
254
- aud_feat = aud_feat.permute(0, 2, 1)
255
- gt_poses = gt[:, self.c_index].permute(0, 2, 1)
256
- if self.composition:
257
- if continuity:
258
- pred_poses_body = []
259
- pred_poses_hand = []
260
- pre_b = None
261
- pre_h = None
262
- for i in range(5):
263
- _, pred_body = self.g_body(gt_poses=gt_poses[:, i*60:(i+1)*60, :self.each_dim[1]], pre_state=pre_b)
264
- pre_b = pred_body[..., -1:].transpose(1,2)
265
- pred_poses_body.append(pred_body)
266
- _, pred_hand = self.g_hand(gt_poses=gt_poses[:, i*60:(i+1)*60, self.each_dim[1]:], pre_state=pre_h)
267
- pre_h = pred_hand[..., -1:].transpose(1,2)
268
- pred_poses_hand.append(pred_hand)
269
-
270
- pred_poses_body = torch.cat(pred_poses_body, dim=2)
271
- pred_poses_hand = torch.cat(pred_poses_hand, dim=2)
272
- else:
273
- _, pred_poses_body = self.g_body(gt_poses=gt_poses[..., :self.each_dim[1]], id=id)
274
- _, pred_poses_hand = self.g_hand(gt_poses=gt_poses[..., self.each_dim[1]:], id=id)
275
- pred_poses = torch.cat([pred_poses_body, pred_poses_hand], dim=1)
276
- else:
277
- _, pred_poses = self.g(gt_poses=gt_poses, id=id)
278
- pred_poses = pred_poses.transpose(1, 2).cpu().numpy()
279
- output = pred_poses
280
-
281
- if self.config.Data.pose.normalization:
282
- output = denormalize(output, data_mean, data_std)
283
-
284
- if smooth:
285
- lamda = 0.8
286
- smooth_f = 10
287
- frame = 149
288
- for i in range(smooth_f):
289
- f = frame + i
290
- l = lamda * (i + 1) / smooth_f
291
- output[0, f] = (1 - l) * output[0, f - 1] + l * output[0, f]
292
-
293
- output = np.concatenate(output, axis=1)
294
-
295
- return output
296
-
297
- def load_state_dict(self, state_dict):
298
- if self.composition:
299
- self.g_body.load_state_dict(state_dict['g_body'])
300
- self.g_hand.load_state_dict(state_dict['g_hand'])
301
- else:
302
- self.g.load_state_dict(state_dict['g'])
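`get_loss` above combines an L1 reconstruction term, an L1 velocity term on frame-to-frame differences, the codebook/commitment loss returned by the VQ-VAE and, when a previous clip is provided, a first-frame velocity term. Restated in isolation as a sketch (tensors are assumed to be `(B, T, C)` pose sequences):

```python
import torch

def vq_motion_loss(pred, gt, e_q_loss):
    # L1 on the poses plus L1 on their temporal differences, as in get_loss above
    rec = torch.mean(torch.abs(pred - gt))
    vel = torch.mean(torch.abs((pred[:, 1:] - pred[:, :-1]) - (gt[:, 1:] - gt[:, :-1])))
    return rec + vel + e_q_loss
```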
nets/smplx_face.py DELETED
@@ -1,238 +0,0 @@
1
- import os
2
- import sys
3
-
4
- sys.path.append(os.getcwd())
5
-
6
- from nets.layers import *
7
- from nets.base import TrainWrapperBaseClass
8
- # from nets.spg.faceformer import Faceformer
9
- from nets.spg.s2g_face import Generator as s2g_face
10
- from losses import KeypointLoss
11
- from nets.utils import denormalize
12
- from data_utils import get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
13
- import numpy as np
14
- import torch.optim as optim
15
- import torch.nn.functional as F
16
- from sklearn.preprocessing import normalize
17
- import smplx
18
-
19
-
20
- class TrainWrapper(TrainWrapperBaseClass):
21
- '''
22
- a wrapper receiving a batch from data_utils and calculating the loss
23
- '''
24
-
25
- def __init__(self, args, config):
26
- self.args = args
27
- self.config = config
28
- self.device = torch.device(self.args.gpu)
29
- self.global_step = 0
30
-
31
- self.convert_to_6d = self.config.Data.pose.convert_to_6d
32
- self.expression = self.config.Data.pose.expression
33
- self.epoch = 0
34
- self.init_params()
35
- self.num_classes = 4
36
-
37
- self.generator = s2g_face(
38
- n_poses=self.config.Data.pose.generate_length,
39
- each_dim=self.each_dim,
40
- dim_list=self.dim_list,
41
- training=not self.args.infer,
42
- device=self.device,
43
- identity=False if self.convert_to_6d else True,
44
- num_classes=self.num_classes,
45
- ).to(self.device)
46
-
47
- # self.generator = Faceformer().to(self.device)
48
-
49
- self.discriminator = None
50
- self.am = None
51
-
52
- self.MSELoss = KeypointLoss().to(self.device)
53
- super().__init__(args, config)
54
-
55
- def init_optimizer(self):
56
- self.generator_optimizer = optim.SGD(
57
- filter(lambda p: p.requires_grad,self.generator.parameters()),
58
- lr=0.001,
59
- momentum=0.9,
60
- nesterov=False,
61
- )
62
-
63
- def init_params(self):
64
- if self.convert_to_6d:
65
- scale = 2
66
- else:
67
- scale = 1
68
-
69
- global_orient = round(3 * scale)
70
- leye_pose = reye_pose = round(3 * scale)
71
- jaw_pose = round(3 * scale)
72
- body_pose = round(63 * scale)
73
- left_hand_pose = right_hand_pose = round(45 * scale)
74
- if self.expression:
75
- expression = 100
76
- else:
77
- expression = 0
78
-
79
- b_j = 0
80
- jaw_dim = jaw_pose
81
- b_e = b_j + jaw_dim
82
- eye_dim = leye_pose + reye_pose
83
- b_b = b_e + eye_dim
84
- body_dim = global_orient + body_pose
85
- b_h = b_b + body_dim
86
- hand_dim = left_hand_pose + right_hand_pose
87
- b_f = b_h + hand_dim
88
- face_dim = expression
89
-
90
- self.dim_list = [b_j, b_e, b_b, b_h, b_f]
91
- self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim + face_dim
92
- self.pose = int(self.full_dim / round(3 * scale))
93
- self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
94
-
95
- def __call__(self, bat):
96
- # assert (not self.args.infer), "infer mode"
97
- self.global_step += 1
98
-
99
- total_loss = None
100
- loss_dict = {}
101
-
102
- aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
103
- id = bat['speaker'].to(self.device) - 20
104
- id = F.one_hot(id, self.num_classes)
105
-
106
- aud = aud.permute(0, 2, 1)
107
- gt_poses = poses.permute(0, 2, 1)
108
-
109
- if self.expression:
110
- expression = bat['expression'].to(self.device).to(torch.float32)
111
- gt_poses = torch.cat([gt_poses, expression.permute(0, 2, 1)], dim=2)
112
-
113
- pred_poses, _ = self.generator(
114
- aud,
115
- gt_poses,
116
- id,
117
- )
118
-
119
- G_loss, G_loss_dict = self.get_loss(
120
- pred_poses=pred_poses,
121
- gt_poses=gt_poses,
122
- pre_poses=None,
123
- mode='training_G',
124
- gt_conf=None,
125
- aud=aud,
126
- )
127
-
128
- self.generator_optimizer.zero_grad()
129
- G_loss.backward()
130
- grad = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.config.Train.max_gradient_norm)
131
- loss_dict['grad'] = grad.item()
132
- self.generator_optimizer.step()
133
-
134
- for key in list(G_loss_dict.keys()):
135
- loss_dict[key] = G_loss_dict.get(key, 0).item()
136
-
137
- return total_loss, loss_dict
138
-
139
- def get_loss(self,
140
- pred_poses,
141
- gt_poses,
142
- pre_poses,
143
- aud,
144
- mode='training_G',
145
- gt_conf=None,
146
- exp=1,
147
- gt_nzero=None,
148
- pre_nzero=None,
149
- ):
150
- loss_dict = {}
151
-
152
-
153
- [b_j, b_e, b_b, b_h, b_f] = self.dim_list
154
-
155
- MSELoss = torch.mean(torch.abs(pred_poses[:, :, :6] - gt_poses[:, :, :6]))
156
- if self.expression:
157
- expl = torch.mean((pred_poses[:, :, -100:] - gt_poses[:, :, -100:])**2)
158
- else:
159
- expl = 0
160
-
161
- gen_loss = expl + MSELoss
162
-
163
- loss_dict['MSELoss'] = MSELoss
164
- if self.expression:
165
- loss_dict['exp_loss'] = expl
166
-
167
- return gen_loss, loss_dict
168
-
169
- def infer_on_audio(self, aud_fn, id=None, initial_pose=None, norm_stats=None, w_pre=False, frame=None, am=None, am_sr=16000, **kwargs):
170
- '''
171
- initial_pose: (B, C, T), normalized
172
- (aud_fn, txgfile) -> generated motion (B, T, C)
173
- '''
174
- output = []
175
-
176
- # assert self.args.infer, "train mode"
177
- self.generator.eval()
178
-
179
- if self.config.Data.pose.normalization:
180
- assert norm_stats is not None
181
- data_mean = norm_stats[0]
182
- data_std = norm_stats[1]
183
-
184
- # assert initial_pose.shape[-1] == pre_length
185
- if initial_pose is not None:
186
- gt = initial_pose[:,:,:].permute(0, 2, 1).to(self.generator.device).to(torch.float32)
187
- pre_poses = initial_pose[:,:,:15].permute(0, 2, 1).to(self.generator.device).to(torch.float32)
188
- poses = initial_pose.permute(0, 2, 1).to(self.generator.device).to(torch.float32)
189
- B = pre_poses.shape[0]
190
- else:
191
- gt = None
192
- pre_poses=None
193
- B = 1
194
-
195
- if type(aud_fn) == torch.Tensor:
196
- aud_feat = torch.tensor(aud_fn, dtype=torch.float32).to(self.generator.device)
197
- num_poses_to_generate = aud_feat.shape[-1]
198
- else:
199
- aud_feat = get_mfcc_ta(aud_fn, am=am, am_sr=am_sr, fps=30, encoder_choice='faceformer')
200
- aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
201
- aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.generator.device).transpose(1, 2)
202
- if frame is None:
203
- frame = aud_feat.shape[2]*30//16000
204
- #
205
- if id is None:
206
- id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
207
- else:
208
- id = F.one_hot(id, self.num_classes).to(self.generator.device)
209
-
210
- with torch.no_grad():
211
- pred_poses = self.generator(aud_feat, pre_poses, id, time_steps=frame)[0]
212
- pred_poses = pred_poses.cpu().numpy()
213
- output = pred_poses
214
-
215
- if self.config.Data.pose.normalization:
216
- output = denormalize(output, data_mean, data_std)
217
-
218
- return output
219
-
220
-
221
- def generate(self, wv2_feat, frame):
222
- '''
223
- initial_pose: (B, C, T), normalized
224
- (aud_fn, txgfile) -> generated motion (B, T, C)
225
- '''
226
- output = []
227
-
228
- # assert self.args.infer, "train mode"
229
- self.generator.eval()
230
-
231
- B = 1
232
-
233
- id = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32, device=self.generator.device)
234
- id = id.repeat(wv2_feat.shape[0], 1)
235
-
236
- with torch.no_grad():
237
- pred_poses = self.generator(wv2_feat, None, id, time_steps=frame)[0]
238
- return pred_poses
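The face branch above is trained with a plain combination of an L1 term on the first six pose channels and an MSE term on the 100 expression coefficients appended to each frame; as a sketch:

```python
import torch

def face_loss(pred, gt):
    pose_l1 = torch.mean(torch.abs(pred[:, :, :6] - gt[:, :, :6]))    # first six pose channels
    exp_mse = torch.mean((pred[:, :, -100:] - gt[:, :, -100:]) ** 2)  # expression coefficients
    return pose_l1 + exp_mse
```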
nets/spg/gated_pixelcnn_v2.py DELETED
@@ -1,179 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
-
6
- def weights_init(m):
7
- classname = m.__class__.__name__
8
- if classname.find('Conv') != -1:
9
- try:
10
- nn.init.xavier_uniform_(m.weight.data)
11
- m.bias.data.fill_(0)
12
- except AttributeError:
13
- print("Skipping initialization of ", classname)
14
-
15
-
16
- class GatedActivation(nn.Module):
17
- def __init__(self):
18
- super().__init__()
19
-
20
- def forward(self, x):
21
- x, y = x.chunk(2, dim=1)
22
- return F.tanh(x) * F.sigmoid(y)
23
-
24
-
25
- class GatedMaskedConv2d(nn.Module):
26
- def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10, bh_model=False):
27
- super().__init__()
28
- assert kernel % 2 == 1, print("Kernel size must be odd")
29
- self.mask_type = mask_type
30
- self.residual = residual
31
- self.bh_model = bh_model
32
-
33
- self.class_cond_embedding = nn.Embedding(n_classes, 2 * dim)
34
- self.class_cond_embedding = self.class_cond_embedding.to("cpu")
35
-
36
- kernel_shp = (kernel // 2 + 1, 3 if self.bh_model else 1) # (ceil(n/2), n)
37
- padding_shp = (kernel // 2, 1 if self.bh_model else 0)
38
- self.vert_stack = nn.Conv2d(
39
- dim, dim * 2,
40
- kernel_shp, 1, padding_shp
41
- )
42
-
43
- self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)
44
-
45
- kernel_shp = (1, 2)
46
- padding_shp = (0, 1)
47
- self.horiz_stack = nn.Conv2d(
48
- dim, dim * 2,
49
- kernel_shp, 1, padding_shp
50
- )
51
-
52
- self.horiz_resid = nn.Conv2d(dim, dim, 1)
53
-
54
- self.gate = GatedActivation()
55
-
56
- def make_causal(self):
57
- self.vert_stack.weight.data[:, :, -1].zero_() # Mask final row
58
- self.horiz_stack.weight.data[:, :, :, -1].zero_() # Mask final column
59
-
60
- def forward(self, x_v, x_h, h):
61
- if self.mask_type == 'A':
62
- self.make_causal()
63
-
64
- h = h.to(self.class_cond_embedding.weight.device)
65
- h = self.class_cond_embedding(h)
66
-
67
- h_vert = self.vert_stack(x_v)
68
- h_vert = h_vert[:, :, :x_v.size(-2), :]
69
- out_v = self.gate(h_vert + h[:, :, None, None])
70
-
71
- if self.bh_model:
72
- h_horiz = self.horiz_stack(x_h)
73
- h_horiz = h_horiz[:, :, :, :x_h.size(-1)]
74
- v2h = self.vert_to_horiz(h_vert)
75
-
76
- out = self.gate(v2h + h_horiz + h[:, :, None, None])
77
- if self.residual:
78
- out_h = self.horiz_resid(out) + x_h
79
- else:
80
- out_h = self.horiz_resid(out)
81
- else:
82
- if self.residual:
83
- out_v = self.horiz_resid(out_v) + x_v
84
- else:
85
- out_v = self.horiz_resid(out_v)
86
- out_h = out_v
87
-
88
- return out_v, out_h
89
-
90
-
91
- class GatedPixelCNN(nn.Module):
92
- def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10, audio=False, bh_model=False):
93
- super().__init__()
94
- self.dim = dim
95
- self.audio = audio
96
- self.bh_model = bh_model
97
-
98
- if self.audio:
99
- self.embedding_aud = nn.Conv2d(256, dim, 1, 1, padding=0)
100
- self.fusion_v = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
101
- self.fusion_h = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
102
-
103
- # Create embedding layer to embed input
104
- self.embedding = nn.Embedding(input_dim, dim)
105
-
106
- # Building the PixelCNN layer by layer
107
- self.layers = nn.ModuleList()
108
-
109
- # Initial block with Mask-A convolution
110
- # Rest with Mask-B convolutions
111
- for i in range(n_layers):
112
- mask_type = 'A' if i == 0 else 'B'
113
- kernel = 7 if i == 0 else 3
114
- residual = False if i == 0 else True
115
-
116
- self.layers.append(
117
- GatedMaskedConv2d(mask_type, dim, kernel, residual, n_classes, bh_model)
118
- )
119
-
120
- # Add the output layer
121
- self.output_conv = nn.Sequential(
122
- nn.Conv2d(dim, 512, 1),
123
- nn.ReLU(True),
124
- nn.Conv2d(512, input_dim, 1)
125
- )
126
-
127
- self.apply(weights_init)
128
-
129
- self.dp = nn.Dropout(0.1)
130
- self.to("cpu")
131
-
132
- def forward(self, x, label, aud=None):
133
- shp = x.size() + (-1,)
134
- x = self.embedding(x.view(-1)).view(shp) # (B, H, W, C)
135
- x = x.permute(0, 3, 1, 2) # (B, C, W, W)
136
-
137
- x_v, x_h = (x, x)
138
- for i, layer in enumerate(self.layers):
139
- if i == 1 and self.audio is True:
140
- aud = self.embedding_aud(aud)
141
- a = torch.ones(aud.shape[-2]).to(aud.device)
142
- a = self.dp(a)
143
- aud = (aud.transpose(-1, -2) * a).transpose(-1, -2)
144
- x_v = self.fusion_v(torch.cat([x_v, aud], dim=1))
145
- if self.bh_model:
146
- x_h = self.fusion_h(torch.cat([x_h, aud], dim=1))
147
- x_v, x_h = layer(x_v, x_h, label)
148
-
149
- if self.bh_model:
150
- return self.output_conv(x_h)
151
- else:
152
- return self.output_conv(x_v)
153
-
154
- def generate(self, label, shape=(8, 8), batch_size=64, aud_feat=None, pre_latents=None, pre_audio=None):
155
- param = next(self.parameters())
156
- x = torch.zeros(
157
- (batch_size, *shape),
158
- dtype=torch.int64, device=param.device
159
- )
160
- if pre_latents is not None:
161
- x = torch.cat([pre_latents, x], dim=1)
162
- aud_feat = torch.cat([pre_audio, aud_feat], dim=2)
163
- h0 = pre_latents.shape[1]
164
- h = h0 + shape[0]
165
- else:
166
- h0 = 0
167
- h = shape[0]
168
-
169
- for i in range(h0, h):
170
- for j in range(shape[1]):
171
- if self.audio:
172
- logits = self.forward(x, label, aud_feat)
173
- else:
174
- logits = self.forward(x, label)
175
- probs = F.softmax(logits[:, :, i, j], -1)
176
- x.data[:, i, j].copy_(
177
- probs.multinomial(1).squeeze().data
178
- )
179
- return x[:, h0:h]
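`generate` above samples the latent grid autoregressively, one position at a time, re-running the network after each draw; the optional `pre_latents`/`pre_audio` arguments let a new clip continue from an already generated one. A hypothetical CPU-only sampling call (2048 codes and 4 speaker classes mirror the wrapper above; the other sizes are arbitrary and kept small so the sketch runs quickly):

```python
import torch

model = GatedPixelCNN(input_dim=2048, dim=64, n_layers=4, n_classes=4,
                      audio=False, bh_model=True)
speaker = torch.tensor([0])                                   # class-conditioning label
codes = model.generate(speaker, shape=(22, 2), batch_size=1)  # (1, 22, 2) body/hand indices
```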
nets/spg/s2g_face.py DELETED
@@ -1,226 +0,0 @@
1
- '''
2
- not exactly the same as the official repo but the results are good
3
- '''
4
- import sys
5
- import os
6
-
7
- from transformers import Wav2Vec2Processor
8
-
9
- from .wav2vec import Wav2Vec2Model
10
- from torchaudio.sox_effects import apply_effects_tensor
11
-
12
- sys.path.append(os.getcwd())
13
-
14
- import numpy as np
15
- import torch
16
- import torch.nn as nn
17
- import torch.nn.functional as F
18
- import torchaudio as ta
19
- import math
20
- from nets.layers import SeqEncoder1D, SeqTranslator1D, ConvNormRelu
21
-
22
-
23
- """ from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
24
-
25
-
26
- def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
27
- """
28
- :param audio: 1 x T tensor containing a 16kHz audio signal
29
- :param frame_rate: frame rate for video (we need one audio chunk per video frame)
30
- :param chunk_size: number of audio samples per chunk
31
- :return: num_chunks x chunk_size tensor containing sliced audio
32
- """
33
- samples_per_frame = 16000 // frame_rate
34
- padding = (chunk_size - samples_per_frame) // 2
35
- audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
36
- anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
37
- audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
38
- return audio
39
-
40
-
41
- class MeshtalkEncoder(nn.Module):
42
- def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
43
- """
44
- :param latent_dim: size of the latent audio embedding
45
- :param model_name: name of the model, used to load and save the model
46
- """
47
- super().__init__()
48
-
49
- self.melspec = ta.transforms.MelSpectrogram(
50
- sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
51
- )
52
-
53
- conv_len = 5
54
- self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
55
- self.weights_init(self.convert_dimensions)
56
- self.receptive_field = conv_len
57
-
58
- convs = []
59
- for i in range(6):
60
- dilation = 2 * (i % 3 + 1)
61
- self.receptive_field += (conv_len - 1) * dilation
62
- convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
63
- self.weights_init(convs[-1])
64
- self.convs = torch.nn.ModuleList(convs)
65
- self.code = torch.nn.Linear(128, latent_dim)
66
-
67
- self.apply(lambda x: self.weights_init(x))
68
-
69
- def weights_init(self, m):
70
- if isinstance(m, torch.nn.Conv1d):
71
- torch.nn.init.xavier_uniform_(m.weight)
72
- try:
73
- torch.nn.init.constant_(m.bias, .01)
74
- except:
75
- pass
76
-
77
- def forward(self, audio: torch.Tensor):
78
- """
79
- :param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
80
- :return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
81
- """
82
- B, T = audio.shape[0], audio.shape[1]
83
- x = self.melspec(audio).squeeze(1)
84
- x = torch.log(x.clamp(min=1e-10, max=None))
85
- if T == 1:
86
- x = x.unsqueeze(1)
87
-
88
- # Convert to the right dimensionality
89
- x = x.view(-1, x.shape[2], x.shape[3])
90
- x = F.leaky_relu(self.convert_dimensions(x), .2)
91
-
92
- # Process stacks
93
- for conv in self.convs:
94
- x_ = F.leaky_relu(conv(x), .2)
95
- if self.training:
96
- x_ = F.dropout(x_, .2)
97
- l = (x.shape[2] - x_.shape[2]) // 2
98
- x = (x[:, :, l:-l] + x_) / 2
99
-
100
- x = torch.mean(x, dim=-1)
101
- x = x.view(B, T, x.shape[-1])
102
- x = self.code(x)
103
-
104
- return {"code": x}
105
-
106
-
107
- class AudioEncoder(nn.Module):
108
- def __init__(self, in_dim, out_dim, identity=False, num_classes=0):
109
- super().__init__()
110
- self.identity = identity
111
- if self.identity:
112
- in_dim = in_dim + 64
113
- self.id_mlp = nn.Conv1d(num_classes, 64, 1, 1)
114
- self.first_net = SeqTranslator1D(in_dim, out_dim,
115
- min_layers_num=3,
116
- residual=True,
117
- norm='ln'
118
- )
119
- self.grus = nn.GRU(out_dim, out_dim, 1, batch_first=True)
120
- self.dropout = nn.Dropout(0.1)
121
- # self.att = nn.MultiheadAttention(out_dim, 4, dropout=0.1, batch_first=True)
122
-
123
- def forward(self, spectrogram, pre_state=None, id=None, time_steps=None):
124
-
125
- spectrogram = spectrogram
126
- spectrogram = self.dropout(spectrogram)
127
- if self.identity:
128
- id = id.reshape(id.shape[0], -1, 1).repeat(1, 1, spectrogram.shape[2]).to(torch.float32)
129
- id = self.id_mlp(id)
130
- spectrogram = torch.cat([spectrogram, id], dim=1)
131
- x1 = self.first_net(spectrogram)# .permute(0, 2, 1)
132
- if time_steps is not None:
133
- x1 = F.interpolate(x1, size=time_steps, align_corners=False, mode='linear')
134
- # x1, _ = self.att(x1, x1, x1)
135
- # x1, hidden_state = self.grus(x1)
136
- # x1 = x1.permute(0, 2, 1)
137
- hidden_state=None
138
-
139
- return x1, hidden_state
140
-
141
-
142
- class Generator(nn.Module):
143
- def __init__(self,
144
- n_poses,
145
- each_dim: list,
146
- dim_list: list,
147
- training=False,
148
- device=None,
149
- identity=True,
150
- num_classes=0,
151
- ):
152
- super().__init__()
153
-
154
- self.training = training
155
- self.device = device
156
- self.gen_length = n_poses
157
- self.identity = identity
158
-
159
- norm = 'ln'
160
- in_dim = 256
161
- out_dim = 256
162
-
163
- self.encoder_choice = 'faceformer'
164
-
165
- if self.encoder_choice == 'meshtalk':
166
- self.audio_encoder = MeshtalkEncoder(latent_dim=in_dim)
167
- elif self.encoder_choice == 'faceformer':
168
- # wav2vec 2.0 weights initialization
169
- self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h"
170
- self.audio_encoder.feature_extractor._freeze_parameters()
171
- self.audio_feature_map = nn.Linear(768, in_dim)
172
- else:
173
- self.audio_encoder = AudioEncoder(in_dim=64, out_dim=out_dim)
174
-
175
- self.audio_middle = AudioEncoder(in_dim, out_dim, identity, num_classes)
176
-
177
- self.dim_list = dim_list
178
-
179
- self.decoder = nn.ModuleList()
180
- self.final_out = nn.ModuleList()
181
-
182
- self.decoder.append(nn.Sequential(
183
- ConvNormRelu(out_dim, 64, norm=norm),
184
- ConvNormRelu(64, 64, norm=norm),
185
- ConvNormRelu(64, 64, norm=norm),
186
- ))
187
- self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))
188
-
189
- self.decoder.append(nn.Sequential(
190
- ConvNormRelu(out_dim, out_dim, norm=norm),
191
- ConvNormRelu(out_dim, out_dim, norm=norm),
192
- ConvNormRelu(out_dim, out_dim, norm=norm),
193
- ))
194
- self.final_out.append(nn.Conv1d(out_dim, each_dim[3], 1, 1))
195
-
196
- def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
197
- if self.training:
198
- time_steps = gt_poses.shape[1]
199
-
200
- # vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
201
- if self.encoder_choice == 'meshtalk':
202
- in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
203
- feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
204
- elif self.encoder_choice == 'faceformer':
205
- hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state
206
- feature = self.audio_feature_map(hidden_states).transpose(1, 2)
207
- else:
208
- feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
209
-
210
- # hidden_states = in_spec
211
-
212
- feature, _ = self.audio_middle(feature, id=id)
213
-
214
- out = []
215
-
216
- for i in range(self.decoder.__len__()):
217
- mid = self.decoder[i](feature)
218
- mid = self.final_out[i](mid)
219
- out.append(mid)
220
-
221
- out = torch.cat(out, dim=1)
222
- out = out.transpose(1, 2)
223
-
224
- return out, None
225
-
226
-
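`audio_chunking` above pads a 16 kHz waveform and slices one `chunk_size`-sample window per video frame, which is how the MeshTalk-style encoder consumes audio. An illustrative call (the 2-second input is an assumption for the example):

```python
import torch

wav = torch.randn(1, 16000 * 2)                                # 2 s of mono 16 kHz audio
chunks = audio_chunking(wav, frame_rate=30, chunk_size=16000)
print(chunks.shape)                                            # about (60, 16000): one chunk per 30 fps frame
```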
nets/spg/s2glayers.py DELETED
@@ -1,522 +0,0 @@
1
- '''
2
- not exactly the same as the official repo but the results are good
3
- '''
4
- import sys
5
- import os
6
-
7
- sys.path.append(os.getcwd())
8
-
9
- import numpy as np
10
- import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
- import math
14
- from nets.layers import SeqEncoder1D, SeqTranslator1D
15
-
16
- """ from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
17
-
18
-
19
- class Conv2d_tf(nn.Conv2d):
20
- """
21
- Conv2d with the padding behavior from TF
22
- from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
23
- """
24
-
25
- def __init__(self, *args, **kwargs):
26
- super(Conv2d_tf, self).__init__(*args, **kwargs)
27
- self.padding = kwargs.get("padding", "SAME")
28
-
29
- def _compute_padding(self, input, dim):
30
- input_size = input.size(dim + 2)
31
- filter_size = self.weight.size(dim + 2)
32
- effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
33
- out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
34
- total_padding = max(
35
- 0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
36
- )
37
- additional_padding = int(total_padding % 2 != 0)
38
-
39
- return additional_padding, total_padding
40
-
41
- def forward(self, input):
42
- if self.padding == "VALID":
43
- return F.conv2d(
44
- input,
45
- self.weight,
46
- self.bias,
47
- self.stride,
48
- padding=0,
49
- dilation=self.dilation,
50
- groups=self.groups,
51
- )
52
- rows_odd, padding_rows = self._compute_padding(input, dim=0)
53
- cols_odd, padding_cols = self._compute_padding(input, dim=1)
54
- if rows_odd or cols_odd:
55
- input = F.pad(input, [0, cols_odd, 0, rows_odd])
56
-
57
- return F.conv2d(
58
- input,
59
- self.weight,
60
- self.bias,
61
- self.stride,
62
- padding=(padding_rows // 2, padding_cols // 2),
63
- dilation=self.dilation,
64
- groups=self.groups,
65
- )
66
-
67
-
68
- class Conv1d_tf(nn.Conv1d):
69
- """
70
- Conv1d with the padding behavior from TF
71
- modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
72
- """
73
-
74
- def __init__(self, *args, **kwargs):
75
- super(Conv1d_tf, self).__init__(*args, **kwargs)
76
- self.padding = kwargs.get("padding")
77
-
78
- def _compute_padding(self, input, dim):
79
- input_size = input.size(dim + 2)
80
- filter_size = self.weight.size(dim + 2)
81
- effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
82
- out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
83
- total_padding = max(
84
- 0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
85
- )
86
- additional_padding = int(total_padding % 2 != 0)
87
-
88
- return additional_padding, total_padding
89
-
90
- def forward(self, input):
91
- # if self.padding == "valid":
92
- # return F.conv1d(
93
- # input,
94
- # self.weight,
95
- # self.bias,
96
- # self.stride,
97
- # padding=0,
98
- # dilation=self.dilation,
99
- # groups=self.groups,
100
- # )
101
- rows_odd, padding_rows = self._compute_padding(input, dim=0)
102
- if rows_odd:
103
- input = F.pad(input, [0, rows_odd])
104
-
105
- return F.conv1d(
106
- input,
107
- self.weight,
108
- self.bias,
109
- self.stride,
110
- padding=(padding_rows // 2),
111
- dilation=self.dilation,
112
- groups=self.groups,
113
- )
114
-
115
-
116
- def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, padding='valid', groups=1,
117
- nonlinear='lrelu', bn='bn'):
118
- if k is None and s is None:
119
- if not downsample:
120
- k = 3
121
- s = 1
122
- padding = 'same'
123
- else:
124
- k = 4
125
- s = 2
126
- padding = 'valid'
127
-
128
- if type == '1d':
129
- conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups)
130
- norm_block = nn.BatchNorm1d(out_channels)
131
- elif type == '2d':
132
- conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups)
133
- norm_block = nn.BatchNorm2d(out_channels)
134
- else:
135
- assert False
136
- if bn != 'bn':
137
- if bn == 'gn':
138
- norm_block = nn.GroupNorm(1, out_channels)
139
- elif bn == 'ln':
140
- norm_block = nn.LayerNorm(out_channels)
141
- else:
142
- norm_block = nn.Identity()
143
- if nonlinear == 'lrelu':
144
- nlinear = nn.LeakyReLU(0.2, True)
145
- elif nonlinear == 'tanh':
146
- nlinear = nn.Tanh()
147
- elif nonlinear == 'none':
148
- nlinear = nn.Identity()
149
-
150
- return nn.Sequential(
151
- conv_block,
152
- norm_block,
153
- nlinear
154
- )
155
-
156
-
157
- class UnetUp(nn.Module):
158
- def __init__(self, in_ch, out_ch):
159
- super(UnetUp, self).__init__()
160
- self.conv = ConvNormRelu(in_ch, out_ch)
161
-
162
- def forward(self, x1, x2):
163
- # x1 = torch.repeat_interleave(x1, 2, dim=2)
164
- # x1 = x1[:, :, :x2.shape[2]]
165
- x1 = torch.nn.functional.interpolate(x1, size=x2.shape[2], mode='linear')
166
- x = x1 + x2
167
- x = self.conv(x)
168
- return x
169
-
170
-
171
- class UNet(nn.Module):
172
- def __init__(self, input_dim, dim):
173
- super(UNet, self).__init__()
174
- # dim = 512
175
- self.down1 = nn.Sequential(
176
- ConvNormRelu(input_dim, input_dim, '1d', False),
177
- ConvNormRelu(input_dim, dim, '1d', False),
178
- ConvNormRelu(dim, dim, '1d', False)
179
- )
180
- self.gru = nn.GRU(dim, dim, 1, batch_first=True)
181
- self.down2 = ConvNormRelu(dim, dim, '1d', True)
182
- self.down3 = ConvNormRelu(dim, dim, '1d', True)
183
- self.down4 = ConvNormRelu(dim, dim, '1d', True)
184
- self.down5 = ConvNormRelu(dim, dim, '1d', True)
185
- self.down6 = ConvNormRelu(dim, dim, '1d', True)
186
- self.up1 = UnetUp(dim, dim)
187
- self.up2 = UnetUp(dim, dim)
188
- self.up3 = UnetUp(dim, dim)
189
- self.up4 = UnetUp(dim, dim)
190
- self.up5 = UnetUp(dim, dim)
191
-
192
- def forward(self, x1, pre_pose=None, w_pre=False):
193
- x2_0 = self.down1(x1)
194
- if w_pre:
195
- i = 1
196
- x2_pre = self.gru(x2_0[:,:,0:i].permute(0,2,1), pre_pose[:,:,-1:].permute(2,0,1).contiguous())[0].permute(0,2,1)
197
- x2 = torch.cat([x2_pre, x2_0[:,:,i:]], dim=-1)
198
- # x2 = torch.cat([pre_pose, x2_0], dim=2) # [B, 512, 15]
199
- else:
200
- # x2 = self.gru(x2_0.transpose(1, 2))[0].transpose(1,2)
201
- x2 = x2_0
202
- x3 = self.down2(x2)
203
- x4 = self.down3(x3)
204
- x5 = self.down4(x4)
205
- x6 = self.down5(x5)
206
- x7 = self.down6(x6)
207
- x = self.up1(x7, x6)
208
- x = self.up2(x, x5)
209
- x = self.up3(x, x4)
210
- x = self.up4(x, x3)
211
- x = self.up5(x, x2) # [B, 512, 15]
212
- return x, x2_0
213
-
214
-
215
- class AudioEncoder(nn.Module):
-     def __init__(self, n_frames, template_length, pose=False, common_dim=512):
-         super().__init__()
-         self.n_frames = n_frames
-         self.pose = pose
-         self.step = 0
-         self.weight = 0
-         if self.pose:
-             # self.first_net = nn.Sequential(
-             #     ConvNormRelu(1, 64, '2d', False),
-             #     ConvNormRelu(64, 64, '2d', True),
-             #     ConvNormRelu(64, 128, '2d', False),
-             #     ConvNormRelu(128, 128, '2d', True),
-             #     ConvNormRelu(128, 256, '2d', False),
-             #     ConvNormRelu(256, 256, '2d', True),
-             #     ConvNormRelu(256, 256, '2d', False),
-             #     ConvNormRelu(256, 256, '2d', False, padding='VALID')
-             # )
-             # decoder_layer = nn.TransformerDecoderLayer(d_model=args.feature_dim, nhead=4,
-             #                                            dim_feedforward=2 * args.feature_dim, batch_first=True)
-             # a = nn.TransformerDecoder
-             self.first_net = SeqTranslator1D(256, 256,
-                                              min_layers_num=4,
-                                              residual=True
-                                              )
-             self.dropout_0 = nn.Dropout(0.1)
-             self.mu_fc = nn.Conv1d(256, 128, 1, 1)
-             self.var_fc = nn.Conv1d(256, 128, 1, 1)
-             self.trans_motion = SeqTranslator1D(common_dim, common_dim,
-                                                 kernel_size=1,
-                                                 stride=1,
-                                                 min_layers_num=3,
-                                                 residual=True
-                                                 )
-             # self.att = nn.MultiheadAttention(64 + template_length, 4, dropout=0.1)
-             self.unet = UNet(128 + template_length, common_dim)
-
-         else:
-             self.first_net = SeqTranslator1D(256, 256,
-                                              min_layers_num=4,
-                                              residual=True
-                                              )
-             self.dropout_0 = nn.Dropout(0.1)
-             # self.att = nn.MultiheadAttention(256, 4, dropout=0.1)
-             self.unet = UNet(256, 256)
-             self.dropout_1 = nn.Dropout(0.0)
-
-     def forward(self, spectrogram, time_steps=None, template=None, pre_pose=None, w_pre=False):
-         self.step = self.step + 1
-         if self.pose:
-             spect = spectrogram.transpose(1, 2)
-             if w_pre:
-                 spect = spect[:, :, :]
-
-             out = self.first_net(spect)
-             out = self.dropout_0(out)
-
-             mu = self.mu_fc(out)
-             var = self.var_fc(out)
-             audio = self.__reparam(mu, var)
-             # audio = out
-
-             # template = self.trans_motion(template)
-             x1 = torch.cat([audio, template], dim=1)  # .permute(2,0,1)
-             # x1 = out
-             # x1, _ = self.att(x1, x1, x1)
-             # x1 = x1.permute(1,2,0)
-             x1, x2_0 = self.unet(x1, pre_pose=pre_pose, w_pre=w_pre)
-         else:
-             spectrogram = spectrogram.transpose(1, 2)
-             x1 = self.first_net(spectrogram)  # .permute(2,0,1)
-             # out, _ = self.att(out, out, out)
-             # out = out.permute(1, 2, 0)
-             x1 = self.dropout_0(x1)
-             x1, x2_0 = self.unet(x1)
-             x1 = self.dropout_1(x1)
-             mu = None
-             var = None
-
-         return x1, (mu, var), x2_0
-
-     def __reparam(self, mu, log_var):
-         std = torch.exp(0.5 * log_var)
-         eps = torch.randn_like(std, device='cuda')
-         z = eps * std + mu
-         return z
-
-
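`__reparam` is the standard VAE reparameterisation trick: a sample z = mu + exp(0.5 · log_var) · eps with eps ~ N(0, I), written so gradients flow through mu and log_var. A self-contained illustration (on CPU for simplicity; the method above hard-codes device='cuda'):

```python
# Reparameterisation trick: differentiable sampling from N(mu, exp(log_var)).
import torch

mu = torch.zeros(8, 128, 88, requires_grad=True)
log_var = torch.zeros(8, 128, 88, requires_grad=True)

std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
z = eps * std + mu            # same formula as __reparam above

z.sum().backward()            # gradients reach both distribution parameters
print(mu.grad.shape, log_var.grad.shape)
```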
- class Generator(nn.Module):
-     def __init__(self,
-                  n_poses,
-                  pose_dim,
-                  pose,
-                  n_pre_poses,
-                  each_dim: list,
-                  dim_list: list,
-                  use_template=False,
-                  template_length=0,
-                  training=False,
-                  device=None,
-                  separate=False,
-                  expression=False
-                  ):
-         super().__init__()
-
-         self.use_template = use_template
-         self.template_length = template_length
-         self.training = training
-         self.device = device
-         self.separate = separate
-         self.pose = pose
-         self.decoderf = True
-         self.expression = expression
-
-         common_dim = 256
-
-         if self.use_template:
-             assert template_length > 0
-             # self.KLLoss = KLLoss(kl_tolerance=self.config.Train.weights.kl_tolerance).to(self.device)
-             # self.pose_encoder = SeqEncoder1D(
-             #     C_in=pose_dim,
-             #     C_out=512,
-             #     T_in=n_poses,
-             #     min_layer_nums=6
-             #
-             # )
-             self.pose_encoder = SeqTranslator1D(pose_dim - 50, common_dim,
-                                                 # kernel_size=1,
-                                                 # stride=1,
-                                                 min_layers_num=3,
-                                                 residual=True
-                                                 )
-             self.mu_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
-             self.var_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
-
-         else:
-             self.template_length = 0
-
-         self.gen_length = n_poses
-
-         self.audio_encoder = AudioEncoder(n_poses, template_length, True, common_dim)
-         self.speech_encoder = AudioEncoder(n_poses, template_length, False)
-
-         # self.pre_pose_encoder = SeqEncoder1D(
-         #     C_in=pose_dim,
-         #     C_out=128,
-         #     T_in=15,
-         #     min_layer_nums=3
-         #
-         # )
-         # self.pmu_fc = nn.Linear(128, 64)
-         # self.pvar_fc = nn.Linear(128, 64)
-
-         self.pre_pose_encoder = SeqTranslator1D(pose_dim - 50, common_dim,
-                                                 min_layers_num=5,
-                                                 residual=True
-                                                 )
-         self.decoder_in = 256 + 64
-         self.dim_list = dim_list
-
-         if self.separate:
-             self.decoder = nn.ModuleList()
-             self.final_out = nn.ModuleList()
-
-             self.decoder.append(nn.Sequential(
-                 ConvNormRelu(256, 64),
-                 ConvNormRelu(64, 64),
-                 ConvNormRelu(64, 64),
-             ))
-             self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))
-
-             self.decoder.append(nn.Sequential(
-                 ConvNormRelu(common_dim, common_dim),
-                 ConvNormRelu(common_dim, common_dim),
-                 ConvNormRelu(common_dim, common_dim),
-             ))
-             self.final_out.append(nn.Conv1d(common_dim, each_dim[1], 1, 1))
-
-             self.decoder.append(nn.Sequential(
-                 ConvNormRelu(common_dim, common_dim),
-                 ConvNormRelu(common_dim, common_dim),
-                 ConvNormRelu(common_dim, common_dim),
-             ))
-             self.final_out.append(nn.Conv1d(common_dim, each_dim[2], 1, 1))
-
-             if self.expression:
-                 self.decoder.append(nn.Sequential(
-                     ConvNormRelu(256, 256),
-                     ConvNormRelu(256, 256),
-                     ConvNormRelu(256, 256),
-                 ))
-                 self.final_out.append(nn.Conv1d(256, each_dim[3], 1, 1))
-         else:
-             self.decoder = nn.Sequential(
-                 ConvNormRelu(self.decoder_in, 512),
-                 ConvNormRelu(512, 512),
-                 ConvNormRelu(512, 512),
-                 ConvNormRelu(512, 512),
-                 ConvNormRelu(512, 512),
-                 ConvNormRelu(512, 512),
-             )
-             self.final_out = nn.Conv1d(512, pose_dim, 1, 1)
-
-     def __reparam(self, mu, log_var):
-         std = torch.exp(0.5 * log_var)
-         eps = torch.randn_like(std, device=self.device)
-         z = eps * std + mu
-         return z
-
-     def forward(self, in_spec, pre_poses, gt_poses, template=None, time_steps=None, w_pre=False, norm=True):
-         if time_steps is not None:
-             self.gen_length = time_steps
-
-         if self.use_template:
-             if self.training:
-                 if w_pre:
-                     in_spec = in_spec[:, 15:, :]
-                     pre_pose = self.pre_pose_encoder(gt_poses[:, 14:15, :-50].permute(0, 2, 1))
-                     pose_enc = self.pose_encoder(gt_poses[:, 15:, :-50].permute(0, 2, 1))
-                     mu = self.mu_fc(pose_enc)
-                     var = self.var_fc(pose_enc)
-                     template = self.__reparam(mu, var)
-                 else:
-                     pre_pose = None
-                     pose_enc = self.pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
-                     mu = self.mu_fc(pose_enc)
-                     var = self.var_fc(pose_enc)
-                     template = self.__reparam(mu, var)
-             elif pre_poses is not None:
-                 if w_pre:
-                     pre_pose = pre_poses[:, -1:, :-50]
-                     if norm:
-                         pre_pose = pre_pose.reshape(1, -1, 55, 5)
-                         pre_pose = torch.cat([F.normalize(pre_pose[..., :3], dim=-1),
-                                               F.normalize(pre_pose[..., 3:5], dim=-1)],
-                                              dim=-1).reshape(1, -1, 275)
-                     pre_pose = self.pre_pose_encoder(pre_pose.permute(0, 2, 1))
-                     template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(
-                         in_spec.device)
-                 else:
-                     pre_pose = None
-                     template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device)
-             elif gt_poses is not None:
-                 template = self.pre_pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
-             elif template is None:
-                 pre_pose = None
-                 template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device)
-         else:
-             template = None
-             mu = None
-             var = None
-
-         a_t_f, (mu2, var2), x2_0 = self.audio_encoder(in_spec, time_steps=time_steps, template=template, pre_pose=pre_pose, w_pre=w_pre)
-         s_f, _, _ = self.speech_encoder(in_spec, time_steps=time_steps)
-
-         out = []
-
-         if self.separate:
-             for i in range(self.decoder.__len__()):
-                 if i == 0 or i == 3:
-                     mid = self.decoder[i](s_f)
-                 else:
-                     mid = self.decoder[i](a_t_f)
-                 mid = self.final_out[i](mid)
-                 out.append(mid)
-             out = torch.cat(out, dim=1)
-
-         else:
-             out = self.decoder(a_t_f)
-             out = self.final_out(out)
-
-         out = out.transpose(1, 2)
-
-         if self.training:
-             if w_pre:
-                 return out, template, mu, var, (mu2, var2, x2_0, pre_pose)
-             else:
-                 return out, template, mu, var, (mu2, var2, None, None)
-         else:
-             return out
-
-
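When `separate=True`, the Generator decodes each body part with its own head and concatenates the per-part outputs along the channel axis (heads 0 and 3 read the speech features `s_f`, the others read the audio-plus-template features `a_t_f`). A minimal sketch of that per-part decoding pattern; the part dimensions below are illustrative assumptions, not values taken from the repo's config:

```python
# Separate per-part heads over shared features, concatenated along channels.
import torch
import torch.nn as nn

each_dim = [3, 63, 90, 100]                          # hypothetical jaw/body/hand/expression sizes
feat = torch.randn(2, 256, 88)                       # shared features [batch, channels, frames]

heads = nn.ModuleList(nn.Conv1d(256, d, 1, 1) for d in each_dim)
out = torch.cat([head(feat) for head in heads], dim=1).transpose(1, 2)
print(out.shape)                                     # torch.Size([2, 88, 256]); frames x summed part dims
```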
- class Discriminator(nn.Module):
-     def __init__(self, pose_dim, pose):
-         super().__init__()
-         self.net = nn.Sequential(
-             Conv1d_tf(pose_dim, 64, kernel_size=4, stride=2, padding='SAME'),
-             nn.LeakyReLU(0.2, True),
-             ConvNormRelu(64, 128, '1d', True),
-             ConvNormRelu(128, 256, '1d', k=4, s=1),
-             Conv1d_tf(256, 1, kernel_size=4, stride=1, padding='SAME'),
-         )
-
-     def forward(self, x):
-         x = x.transpose(1, 2)
-
-         out = self.net(x)
-         return out
-
-
- def main():
-     d = Discriminator(275, 55)
-     x = torch.randn([8, 60, 275])
-     result = d(x)
-
-
- if __name__ == "__main__":
-     main()
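The smoke test above feeds an [8, 60, 275] pose batch through the discriminator, which returns a per-timestep score map rather than a single scalar. A rough standalone stand-in built from plain `nn.Conv1d` (explicit padding approximates the TF-style 'SAME' padding of `Conv1d_tf`; layer widths follow the definition above, everything else is an assumption):

```python
# Approximate 1D convolutional critic producing a score per temporal patch.
import torch
import torch.nn as nn

critic = nn.Sequential(
    nn.Conv1d(275, 64, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1),   # stands in for ConvNormRelu(64, 128, '1d', True)
    nn.BatchNorm1d(128),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv1d(128, 256, kernel_size=4, stride=1, padding=2),  # stands in for ConvNormRelu(128, 256, k=4, s=1)
    nn.BatchNorm1d(256),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv1d(256, 1, kernel_size=4, stride=1, padding=2),
)

poses = torch.randn(8, 60, 275)             # [batch, frames, pose_dim], as in main()
scores = critic(poses.transpose(1, 2))      # [8, 1, T'] realism score map
print(scores.shape)
```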