multimodalart (HF Staff) committed
Commit 7758cff · verified · 1 Parent(s): cfcc2fd

Upload 247 files

This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +8 -0
  2. configs/audio2motion/inference/inference.yaml +35 -0
  3. configs/audio2motion/model/audio_processer_config.yaml +36 -0
  4. configs/audio2motion/model/config.yaml +59 -0
  5. configs/audio2motion/model/crop_config.yaml +21 -0
  6. configs/audio2motion/model/liveportrait_config.yaml +59 -0
  7. configs/audio2motion/model/models.yaml +43 -0
  8. requirements.txt +45 -0
  9. src/datasets/mean.pt +3 -0
  10. src/datasets/preprocess/__pycache__/flow_filter.cpython-310.pyc +0 -0
  11. src/datasets/preprocess/__pycache__/video_crop.cpython-310.pyc +0 -0
  12. src/datasets/preprocess/__pycache__/visualize.cpython-310.pyc +0 -0
  13. src/datasets/preprocess/extract_features/__pycache__/audio_processer.cpython-310.pyc +0 -0
  14. src/datasets/preprocess/extract_features/__pycache__/audio_processer.cpython-312.pyc +0 -0
  15. src/datasets/preprocess/extract_features/__pycache__/feature_extractor_pipeline.cpython-310.pyc +0 -0
  16. src/datasets/preprocess/extract_features/__pycache__/motion_processer.cpython-310.pyc +0 -0
  17. src/datasets/preprocess/extract_features/__pycache__/test_processer.cpython-310.pyc +0 -0
  18. src/datasets/preprocess/extract_features/audio_processer.py +471 -0
  19. src/datasets/preprocess/extract_features/face_segmentation/__init__.py +88 -0
  20. src/datasets/preprocess/extract_features/face_segmentation/__pycache__/__init__.cpython-310.pyc +0 -0
  21. src/datasets/preprocess/extract_features/face_segmentation/__pycache__/bisenet.cpython-310.pyc +0 -0
  22. src/datasets/preprocess/extract_features/face_segmentation/__pycache__/resnet.cpython-310.pyc +0 -0
  23. src/datasets/preprocess/extract_features/face_segmentation/bisenet.py +285 -0
  24. src/datasets/preprocess/extract_features/face_segmentation/resnet.py +113 -0
  25. src/datasets/preprocess/extract_features/motion_processer.py +1420 -0
  26. src/examples/driving_audios/10.wav +3 -0
  27. src/examples/driving_audios/5.wav +3 -0
  28. src/examples/driving_audios/6.wav +3 -0
  29. src/examples/driving_audios/tmp_5.wav +3 -0
  30. src/examples/reference_images/1.jpg +3 -0
  31. src/examples/reference_images/2.jpg +0 -0
  32. src/examples/reference_images/3.jpg +0 -0
  33. src/examples/reference_images/4.jpg +0 -0
  34. src/examples/reference_images/5.jpg +0 -0
  35. src/examples/reference_images/6.jpg +0 -0
  36. src/examples/reference_images/7.jpg +3 -0
  37. src/examples/silent-audio.wav +3 -0
  38. src/models/audio/__pycache__/audio_processer.cpython-310.pyc +0 -0
  39. src/models/audio/__pycache__/audio_proj.cpython-310.pyc +0 -0
  40. src/models/audio/__pycache__/hubert.cpython-310.pyc +0 -0
  41. src/models/audio/__pycache__/wav2vec.cpython-310.pyc +0 -0
  42. src/models/audio/__pycache__/wav2vec2.cpython-310.pyc +0 -0
  43. src/models/audio/__pycache__/wav2vec_modified.cpython-310.pyc +0 -0
  44. src/models/audio/audio_processer.py +407 -0
  45. src/models/audio/audio_proj.py +124 -0
  46. src/models/audio/hubert.py +120 -0
  47. src/models/audio/hubert2.py +120 -0
  48. src/models/audio/wav2vec.py +210 -0
  49. src/models/audio/wav2vec2.py +123 -0
  50. src/models/audio/wav2vec_modified.py +223 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ src/examples/driving_audios/10.wav filter=lfs diff=lfs merge=lfs -text
+ src/examples/driving_audios/5.wav filter=lfs diff=lfs merge=lfs -text
+ src/examples/driving_audios/6.wav filter=lfs diff=lfs merge=lfs -text
+ src/examples/driving_audios/tmp_5.wav filter=lfs diff=lfs merge=lfs -text
+ src/examples/reference_images/1.jpg filter=lfs diff=lfs merge=lfs -text
+ src/examples/reference_images/7.jpg filter=lfs diff=lfs merge=lfs -text
+ src/examples/silent-audio.wav filter=lfs diff=lfs merge=lfs -text
+ src/thirdparty/liveportrait/src/utils/dependencies/insightface/data/images/t1.jpg filter=lfs diff=lfs merge=lfs -text
configs/audio2motion/inference/inference.yaml ADDED
@@ -0,0 +1,35 @@
+
+ output_fps: 25
+ ## appearance and motion feature extractor
+ appearance_feature_extractor_path: pretrain_weights/decode/v1/first_stage/base_models/appearance_feature_extractor.pth
+ motion_extractor_path: pretrain_weights/decode/v1/first_stage/base_models/motion_extractor.pth
+ ## SPADEGenerator
+ spade_generator_path: pretrain_weights/decode/v1/first_stage/base_models/spade_generator.pth
+ warping_module_path: pretrain_weights/decode/v1/first_stage/base_models/warping_module.pth
+ ## stitching retargeting module
+ stitching_retargeting_module_path: pretrain_weights/decode/v1/first_stage/retargeting_models/stitching_retargeting_module.pth
+ #
+
+ # audio processer config
+ audio_model_config: configs/audio2motion/model/audio_processer_config.yaml
+
+ # motion processer config
+ motion_processer_config: configs/audio2motion/model/liveportrait_config.yaml
+
+ # motion generator model
+ motion_models_config: configs/audio2motion/model/config.yaml
+ use_ref_kp: False
+ motion_generator_path: pretrain_weights/moda/net-200.pth
+ need_normalized: True
+
+ # other configs
+ device_id: 0
+ batch_size: 100
+
+ source_max_dim: 1280 # the max dim of height and width of source image or video
+ source_division: 2 # make sure the height and width of source image or video can be divided by this number
+ input_height: 256
+ input_width: 256
+ source_fps: 25
+ min_video_length: 50
+ max_video_length: 500
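
Editor's note: the configuration files in this commit are plain YAML and are read with OmegaConf elsewhere in the diff (see AudioProcessor.__init__ in audio_processer.py below). A minimal loading sketch; the printed values simply echo the entries above:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/audio2motion/inference/inference.yaml")
    print(cfg.output_fps)             # 25
    print(cfg.motion_generator_path)  # pretrain_weights/moda/net-200.pth
    print(cfg.batch_size)             # 100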
configs/audio2motion/model/audio_processer_config.yaml ADDED
@@ -0,0 +1,36 @@
+ # models settings
+ model_params:
+   model_name: hubert # wav2vec or hubert
+   model_type: base # base large
+   is_chinese: True
+   is_original: True
+   only_last_features: False
+   use_audio_separator: False
+   audio_separator_name: Kim_Vocal_2.onnx
+
+ # model weights
+ model_weights:
+   audio_separator_path: pretrain_weights/audio/audio_separator
+   hubert_path:
+     chinese:
+       base: pretrain_weights/audio/chinese-hubert-base
+ # data settings
+ data_params:
+   sample_rate: 16000
+   max_length: 60 # seconds
+   sub_clip_length: 3000 # samples
+   fps: 25
+   sample_strategy: "presample"
+   audio_pad_mode: replicate # pad mode for audio, replicate or zero
+   save_to_cpu: True # saving gpu memory
+
+ # device settings
+ device_params:
+   device_id: 0
+   flag_force_cpu: False
+   flag_use_half_precision: False
+
+ cache_dir: preprocessed/HDTF/vocals
+ tmp_dir: src/tmp
+
+
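
Editor's note: for orientation, the timing values above relate to each other as follows (a small sketch; in audio_processer.py below, sub_clip_length is compared against a frame count, so it appears to be measured in frames rather than raw samples):

    sample_rate, fps = 16000, 25
    audio_unit = sample_rate / fps      # 640.0 audio samples per video frame
    max_length = 60                     # audio is first cut into <= 60 s pieces
    sub_clip_length = 3000              # 3000 frames -> 3000 / fps = 120 s per sub-clip
    print(audio_unit, max_length, sub_clip_length / fps)   # 640.0 60 120.0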
configs/audio2motion/model/config.yaml ADDED
@@ -0,0 +1,59 @@
+ model_name: TalkingHeadDiT-B
+ audio_projector:
+   type: MLP
+   pretrained_model_path: None
+   device: cuda
+   params:
+     model_name: MLP-S-3
+     sequence_length: 1
+     blocks: 12
+     audio_feat_dim: 768
+     keypoint_dim: 63
+     feature_dim: 512
+     output_dim: 256
+     context_tokens: 1
+     audio_embedder_type: simple
+     audio_cond_dim: 63
+ motion_generator:
+   type: DiT
+   pretrained_model_path: None
+   device: cuda
+   params:
+     model_name: DiT-S-8-8
+     architecture: decoder
+     use_emo: True
+     input_dim: 70
+     output_dim: 70
+     exp_dim: 63
+     n_prev_frames: 1
+     n_pred_frames: 80
+     use_indicator: False
+     feature_dim: 256
+     n_heads: 8
+     n_layers: 8
+     mlp_ratio: 4
+     no_use_learnable_pe: True
+     norm_type: rms_norm # [rms_norm|layer_norm]
+     qk_norm: rms_norm # [rms_norm|layer_norm|null]
+     steps: 1000
+ noise_scheduler:
+   type: flow_matching
+   sample_mode: sample
+   device: cuda
+   params:
+     time_shifting: True
+     num_train_timesteps: 1000
+     num_inference_steps: 10
+     eta: 0.2
+     beta_start: 0.0001
+     beta_end: 0.02
+     s: 0.008
+     mode: cosine
+ train:
+   audio_drop_prob: 0.3
+   cond_drop_prob: 0.2
+   motion_drop_prob: 0.3
+   audio_drop_ratio : 0.2
+   motion_drop_ratio: 0.1
+   pre_drop_ratio : 0.0
+   device_specific: True
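
Editor's note: the train block lists independent drop probabilities for the audio, condition, and motion streams. Values like these are typically consumed as classifier-free-guidance-style conditioning dropout; the sketch below only illustrates that pattern, and the drop_condition helper is hypothetical (the actual training loop is not part of this commit):

    import torch

    def drop_condition(cond: torch.Tensor, drop_prob: float) -> torch.Tensor:
        # Zero the conditioning for a random subset of the batch so the model
        # also learns an unconditional branch (illustrative helper only).
        keep = (torch.rand(cond.shape[0], device=cond.device) >= drop_prob).float()
        return cond * keep.view(-1, *([1] * (cond.ndim - 1)))

    audio_feat = torch.randn(4, 80, 256)           # (batch, n_pred_frames, feature_dim)
    audio_feat = drop_condition(audio_feat, 0.3)   # audio_drop_prob from the config above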
configs/audio2motion/model/crop_config.yaml ADDED
@@ -0,0 +1,21 @@
+ insightface_root: pretrain_weights/decode/v1/insightface
+ landmark_ckpt_path: pretrain_weights/decode/v1/first_stage/landmark.onnx
+ xpose_config_file_path: src/utils/UniPose_SwinT.py
+ device_id: 0 # gpu device id
+ flag_force_cpu: False # force cpu inference, WIP
+ det_thresh: 0.15 # detection threshold
+ ########## source image or video cropping option ##########
+ dsize: 512 # crop size
+ scale: 2.3 # scale factor
+ vx_ratio: 0 # vx ratio
+ vy_ratio: -0.125 # vy ratio +up, -down
+ max_face_num: 0 # max face number, 0 mean no limit
+ flag_do_rot: True # whether to conduct the rotation when flag_do_crop is True
+ animal_face_type: animal_face_9 # animal_face_68 -> 68 landmark points, animal_face_9 -> 9 landmarks
+ ########## driving video auto cropping option ##########
+ scale_crop_driving_video: 2.2 # 2.0 # scale factor for cropping driving video
+ vx_ratio_crop_driving_video: 0.0 # adjust x offset
+ vy_ratio_crop_driving_video: -0.1 # adjust y offset
+ direction: large-small # direction of cropping
+ source_max_dim: 1280
+ source_division: 2
configs/audio2motion/model/liveportrait_config.yaml ADDED
@@ -0,0 +1,59 @@
+ # model config
+ models_config: configs/audio2motion/model/models.yaml
+
+ # 1. face appearance feature
+ appearance_feature_extractor_path: pretrain_weights/decode/v1/first_stage/base_models/appearance_feature_extractor.pth
+
+ # 2. motion feature
+ motion_extractor_path: pretrain_weights/decode/v1/first_stage/base_models/motion_extractor.pth
+
+ # 3. stitching retargeting module
+ stitching_retargeting_module_path: pretrain_weights/decode/v1/first_stage/retargeting_models/stitching_retargeting_module.pth
+
+ # 4. feature warper
+ warping_module_path: pretrain_weights/decode/v1/first_stage/base_models/warping_module.pth
+
+ # 5. SPADEGenerator
+ spade_generator_path: pretrain_weights/decode/v1/first_stage/base_models/spade_generator.pth
+
+ # 6. cropper
+ crop_cfg: "configs/audio2motion/model/crop_config.yaml"
+
+ # 7. face parser
+ face_parser_weight_path: "pretrain_weights/face/face-parsing/79999_iter.pth"
+ resnet_weight_path: "pretrain_weights/face/face-parsing/resnet18-5c106cde.pth"
+
+ # motion template
+ need_normalized: True
+
+ # others
+ batch_size: 100
+ source_max_dim: 1920 # the max dim of height and width of source image or video
+ source_division: 2 # make sure the height and width of source image or video can be divided by this number
+ input_height: 256
+ input_width: 256
+ output_height: 512
+ output_width: 512
+ output_fps: 25
+
+ # driving params
+ flag_do_torch_compile: False
+ flag_use_half_precision: True
+ flag_relative_motion: False
+ flag_normalize_lip: False
+ flag_source_video_eye_retargeting: False
+ flag_eye_retargeting: False
+ flag_lip_retargeting: False
+ flag_stitching: True
+
+ lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip
+ source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
+ anchor_frame: 0 # TO IMPLEMENT
+
+ driving_option: "expression-friendly" # "expression-friendly" or "pose-friendly"
+ driving_multiplier: 1.0 # be used only when driving_option is "expression-friendly"
+ lib_multiplier: 1.0
+ driving_smooth_observation_variance: 3e-7 # the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
+ animation_region: "all" #["exp", "pose", "lip", "eyes", "all"], the region where the animation was performed, "exp" means the expression, "pose" means the head pose
+ mask_crop: src/utils/resources/mask_template.png
+ lip_array: src/utils/resources/lip_array.pkl
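
Editor's note: this config points to crop_config.yaml through the crop_cfg key, so the cropper settings can be resolved with a second OmegaConf load. A minimal sketch under that assumption (the actual loading code lives in motion_processer.py and the LivePortrait cropper and may differ):

    from omegaconf import OmegaConf

    lp_cfg = OmegaConf.load("configs/audio2motion/model/liveportrait_config.yaml")
    crop_cfg = OmegaConf.load(lp_cfg.crop_cfg)        # -> configs/audio2motion/model/crop_config.yaml
    print(lp_cfg.output_height, lp_cfg.output_width)  # 512 512
    print(crop_cfg.dsize, crop_cfg.scale)             # 512 2.3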
configs/audio2motion/model/models.yaml ADDED
@@ -0,0 +1,43 @@
+ model_params:
+   appearance_feature_extractor_params: # the F in the paper
+     image_channel: 3
+     block_expansion: 64
+     num_down_blocks: 2
+     max_features: 512
+     reshape_channel: 32
+     reshape_depth: 16
+     num_resblocks: 6
+   motion_extractor_params: # the M in the paper
+     num_kp: 21
+     backbone: convnextv2_tiny
+   warping_module_params: # the W in the paper
+     num_kp: 21
+     block_expansion: 64
+     max_features: 512
+     num_down_blocks: 2
+     reshape_channel: 32
+     estimate_occlusion_map: True
+     dense_motion_params:
+       block_expansion: 32
+       max_features: 1024
+       num_blocks: 5
+       reshape_depth: 16
+       compress: 4
+   spade_generator_params: # the G in the paper
+     upscale: 2 # represents upsample factor 256x256 -> 512x512
+     block_expansion: 64
+     max_features: 512
+     num_down_blocks: 2
+   stitching_retargeting_module_params: # the S in the paper
+     stitching:
+       input_size: 126 # (21*3)*2
+       hidden_sizes: [128, 128, 64]
+       output_size: 65 # (21*3)+2(tx,ty)
+     lip:
+       input_size: 65 # (21*3)+2
+       hidden_sizes: [128, 128, 64]
+       output_size: 63 # (21*3)
+     eye:
+       input_size: 66 # (21*3)+3
+       hidden_sizes: [256, 256, 128, 128, 64]
+       output_size: 63 # (21*3)
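
Editor's note: the size comments above describe MLPs that map flattened keypoint vectors (21 keypoints × 3 coordinates). The snippet below is illustrative only, an MLP with the stitching dimensions listed above (126 -> 128 -> 128 -> 64 -> 65); the real retargeting networks live in the LivePortrait third-party code, not in this commit:

    import torch
    import torch.nn as nn

    def make_mlp(input_size, hidden_sizes, output_size):
        # Plain fully connected stack with ReLU between hidden layers.
        dims = [input_size, *hidden_sizes]
        layers = []
        for i in range(len(dims) - 1):
            layers += [nn.Linear(dims[i], dims[i + 1]), nn.ReLU()]
        layers.append(nn.Linear(dims[-1], output_size))
        return nn.Sequential(*layers)

    stitching = make_mlp(126, [128, 128, 64], 65)
    out = stitching(torch.randn(1, 126))  # source + driving keypoints, 2 * 21 * 3
    print(out.shape)                      # torch.Size([1, 65]) -> 21*3 deltas + (tx, ty)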
requirements.txt ADDED
@@ -0,0 +1,45 @@
+ --find-links https://download.pytorch.org/whl/torch_stable.html
+
+ accelerate==0.28.0
+ audio-separator==0.17.2
+ av==12.1.0
+ bitsandbytes==0.43.1
+ decord==0.6.0
+ diffusers==0.27.2
+ einops==0.8.0
+ huggingface==0.0.1
+ huggingface-hub==0.25.1
+ insightface==0.7.3
+ librosa==0.10.2.post1
+ mediapipe[vision]==0.10.14
+ mlflow==2.13.1
+ moviepy==1.0.3
+ numpy==1.26.4
+ omegaconf==2.3.0
+ onnx2torch==1.5.14
+ onnx==1.16.1
+ onnxruntime-gpu==1.18.0
+ opencv-python==4.10.0.84
+ pillow==10.3.0
+ pyyaml==6.0.1
+ setuptools==70.0.0
+ torch==2.2.2+cu121
+ torchaudio==2.2.2
+ torchvision==0.17.2+cu121
+ transformers==4.39.2
+ xformers==0.0.25.post1
+ isort==5.13.2
+ pre-commit==3.7.1
+ scipy==1.13.1
+ imageio==2.34.2
+ lmdb==1.4.1
+ rich==13.7.1
+ ffmpeg-python==0.2.0
+ scikit-image==0.24.0
+ albumentations==1.4.10
+ matplotlib==3.9.0
+ imageio-ffmpeg==0.5.1
+ tyro==0.8.5
+ gradio==5.1.0
+ pykalman==0.9.7
+ tensorboardX==2.6.2.2
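
Editor's note: the +cu121 wheel suffixes together with the torch_stable --find-links line pin CUDA 12.1 builds of torch and torchvision. A quick post-install sanity check, assuming a CUDA-capable machine:

    import torch, torchvision

    print(torch.__version__)          # expected 2.2.2+cu121 per the pins above
    print(torchvision.__version__)    # expected 0.17.2+cu121
    print(torch.cuda.is_available())  # True on a working CUDA setup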
src/datasets/mean.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:db742e76a39bbf81fb5b09fcc488bad0cbab9355df509d8e91967b58d02c6dfc
+ size 2582
src/datasets/preprocess/__pycache__/flow_filter.cpython-310.pyc ADDED
Binary file (7.38 kB).
 
src/datasets/preprocess/__pycache__/video_crop.cpython-310.pyc ADDED
Binary file (7.24 kB).
 
src/datasets/preprocess/__pycache__/visualize.cpython-310.pyc ADDED
Binary file (867 Bytes).
 
src/datasets/preprocess/extract_features/__pycache__/audio_processer.cpython-310.pyc ADDED
Binary file (13.2 kB).
 
src/datasets/preprocess/extract_features/__pycache__/audio_processer.cpython-312.pyc ADDED
Binary file (25.6 kB).
 
src/datasets/preprocess/extract_features/__pycache__/feature_extractor_pipeline.cpython-310.pyc ADDED
Binary file (20.1 kB).
 
src/datasets/preprocess/extract_features/__pycache__/motion_processer.cpython-310.pyc ADDED
Binary file (36.8 kB).
 
src/datasets/preprocess/extract_features/__pycache__/test_processer.cpython-310.pyc ADDED
Binary file (6.32 kB).
 
src/datasets/preprocess/extract_features/audio_processer.py ADDED
@@ -0,0 +1,471 @@
1
+
2
+ import os
3
+ from posixpath import isfile
4
+ from re import A
5
+ import sys
6
+ import os.path as osp
7
+
8
+ from typing import List, Dict, Tuple, Optional, Union, Any
9
+
10
+ import yaml
11
+ from omegaconf import OmegaConf
12
+
13
+ import math
14
+ import librosa
15
+ import soundfile
16
+ import numpy as np
17
+
18
+ from einops import rearrange
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+
23
+ from pydub import AudioSegment
24
+ from audio_separator.separator import Separator
25
+
26
+ from transformers import Wav2Vec2FeatureExtractor, HubertModel
27
+
28
+ from src.utils.rprint import rlog as log
29
+ from src.utils.util import resample_audio
30
+
31
+ from src.models.audio.wav2vec_modified import Wav2VecModel
32
+ from src.models.audio.hubert import HubertModel_ as HubertModel
33
+
34
+
35
+ def pad_audio(audio, audio_unit=320, pad_threshold=80):
36
+ batch_size, audio_len = audio.shape
37
+ n_units = audio_len // audio_unit
38
+ side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2)
39
+ if side_len >= 0:
40
+ reflect_len = side_len // 2
41
+ replicate_len = side_len % 2
42
+ if reflect_len > 0:
43
+ audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
44
+ audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
45
+ if replicate_len > 0:
46
+ audio = F.pad(audio, (1, 1), mode='replicate')
47
+
48
+ return audio
49
+
50
+
51
+ def cut_audio(audio_path: str, save_dir: str, length=60) -> List[str]:
52
+ """Cut audio into sub-divisions and return subfile paths. Supports wav format.
53
+
54
+ Args:
55
+ audio_path (str): the source audio file path
56
+ save_dir (str): the save directory of sub-divisions
57
+ length (int, optional): The max length of each sub-division. Defaults to 60 secs.
58
+
59
+ Returns:
60
+ List[str]: the subfile paths
61
+ """
62
+ audio_name = osp.basename(audio_path).split('.')[0]
63
+ audio = AudioSegment.from_wav(audio_path)
64
+ segment_length = length * 1000. # pydub uses milliseconds
65
+ num_segments = math.ceil(len(audio) / segment_length)
66
+
67
+ os.makedirs(save_dir, exist_ok=True)
68
+ audio_list = []
69
+
70
+ if num_segments > 1:
71
+ for i in range(num_segments):
72
+ start_time = i * segment_length
73
+ end_time = min((i + 1) * segment_length, len(audio))
74
+ segment = audio[start_time:end_time]
75
+
76
+ path = osp.join(save_dir, f"{audio_name}_segment_{i+1}.wav")
77
+ audio_list.append(path)
78
+ segment.export(path, format="wav")
79
+ else:
80
+ audio_list = [audio_path]
81
+ return audio_list
82
+
83
+
84
+ class AudioProcessor(object):
85
+ def __init__(self, cfg_path: str, is_training: bool = False, device_id=0) -> None:
86
+ cfg = OmegaConf.load(cfg_path)
87
+ self.cfg = cfg
88
+ self.is_training = is_training
89
+ log("========================================= Audio Processer =========================================")
90
+ log(OmegaConf.to_yaml(cfg))
91
+
92
+ # setting device
93
+ self.device_id = device_id
94
+ self.use_half = cfg.device_params.flag_use_half_precision
95
+ if cfg.device_params.flag_force_cpu:
96
+ self.device = 'cpu'
97
+ else:
98
+ try:
99
+ if torch.backends.mps.is_available():
100
+ self.device = 'mps'
101
+ else:
102
+ self.device = 'cuda:' + str(self.device_id)
103
+ except:
104
+ self.device = 'cuda:' + str(self.device_id)
105
+
106
+ # init audio separator
107
+ self.audio_separator = None
108
+ self.cache_dir = cfg.cache_dir
109
+ self.tmp_dir = cfg.tmp_dir
110
+ self.use_audio_separator = cfg.model_params.use_audio_separator
111
+ self.audio_separator_name = cfg.model_params.audio_separator_name
112
+ self.audio_separator_path = cfg.model_weights.audio_separator_path
113
+ self.set_audio_separator(cfg.cache_dir)
114
+
115
+ # load audio encoder, wav2vec or hubert
116
+ self.model_name = cfg.model_params.model_name
117
+ self.is_chinese = cfg.model_params.is_chinese
118
+ self.audio_encoder, self.feature_extractor = self.load_model(
119
+ model_name = cfg.model_params.model_name,
120
+ model_type = cfg.model_params.model_type,
121
+ is_chinese = cfg.model_params.is_chinese,
122
+ )
123
+ self.only_last_features = cfg.model_params.only_last_features
124
+ if cfg.model_params.only_last_features:
125
+ self.feature_shape = (1, 768)
126
+ else:
127
+ self.feature_shape = (12, 768) # features of 12 blocks
128
+
129
+ # init data params
130
+ self.sample_strategy = cfg.data_params.sample_strategy
131
+ self.sample_rate = cfg.data_params.sample_rate
132
+ self.fps = cfg.data_params.fps
133
+ self.audio_unit = cfg.data_params.sample_rate / cfg.data_params.fps # num of audio samples per frame
134
+ self.max_length = cfg.data_params.max_length
135
+ self.subclip_len = cfg.data_params.sub_clip_length
136
+ self.save_to_cpu = cfg.data_params.save_to_cpu
137
+ self.pad_mode = cfg.data_params.audio_pad_mode
138
+
139
+ log("========================================= Audio Processer: Done =========================================")
140
+
141
+ def load_model(self, model_name: str="wav2vec", model_type: str="base", is_chinese: bool = False):
142
+ assert model_name in ["wav2vec", "hubert"], f"Unknown audio model {model_name}, only support wav2vec or hubert"
143
+ assert model_type in ["base", "large"], f"Unknown audio model type {model_type}, only support base or large"
144
+
145
+ if model_name == "wav2vec":
146
+ # load wav2vec model weights
147
+ if is_chinese:
148
+ if model_type == "base":
149
+ model_weight_path = self.cfg.model_weights.wav2vec_path.chinese.base
150
+ else:
151
+ model_weight_path = self.cfg.model_weights.wav2vec_path.chinese.large
152
+ else:
153
+ if model_type == "base":
154
+ model_weight_path = self.cfg.model_weights.wav2vec_path.default.base
155
+ else:
156
+ model_weight_path = self.cfg.model_weights.wav2vec_path.default.large
157
+ if model_weight_path is None:
158
+ raise ValueError(f"model_weight_path is None")
159
+ audio_encoder = Wav2VecModel.from_pretrained(model_weight_path, local_files_only=True).to(device=self.device)
160
+ else:
161
+ if is_chinese:
162
+ if model_type == "base":
163
+ model_weight_path = self.cfg.model_weights.hubert_path.chinese.base
164
+ else:
165
+ model_weight_path = self.cfg.model_weights.hubert_path.chinese.large
166
+ else:
167
+ if model_type == "base":
168
+ model_weight_path = self.cfg.model_weights.hubert_path.default.base
169
+ else:
170
+ model_weight_path = self.cfg.model_weights.hubert_path.default.large
171
+ if model_weight_path is None:
172
+ raise ValueError(f"model_weight_path is None")
173
+ audio_encoder = HubertModel.from_pretrained(model_weight_path, local_files_only=True).to(device=self.device)
174
+
175
+ log(f"{model_name}-{model_type}-chinese-{is_chinese} model has beed loaded from {model_weight_path}")
176
+ total_params = sum(p.numel() for p in audio_encoder.parameters())
177
+ print('Number of parameter: % .4fM' % (total_params / 1e6))
178
+
179
+ # weights initialization
180
+ audio_encoder.feature_extractor._freeze_parameters()
181
+ if not self.cfg.model_params.is_original:
182
+ frozen_layers = [0, 1]
183
+ for name, param in audio_encoder.named_parameters():
184
+ if name.startswith("feature_projection"):
185
+ param.requires_grad = False
186
+ if name.startswith("encoder.layers"):
187
+ layer = int(name.split(".")[2])
188
+ if layer in frozen_layers:
189
+ param.requires_grad = False
190
+
191
+ audio_encoder = audio_encoder.to(self.device)
192
+ if self.use_half:
193
+ audio_encoder = audio_encoder.half()
194
+ audio_encoder.eval()
195
+
196
+ # feature extractor
197
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_weight_path)
198
+
199
+ return audio_encoder, feature_extractor
200
+
201
+ def set_audio_separator(self, output_dir: str) -> None:
202
+ del self.audio_separator
203
+
204
+ if self.audio_separator_name is not None and self.use_audio_separator:
205
+ try:
206
+ os.makedirs(output_dir, exist_ok=True)
207
+ except OSError as _:
208
+ print("Fail to create the output cache dir.")
209
+ self.audio_separator = Separator(
210
+ output_dir=output_dir,
211
+ output_single_stem="vocals",
212
+ model_file_dir=self.audio_separator_path,
213
+ )
214
+ self.audio_separator.load_model(self.audio_separator_name)
215
+ assert self.audio_separator.model_instance is not None, "Fail to load audio separate model."
216
+ else:
217
+ self.audio_separator=None
218
+ log("Use audio directly without vocals seperator.")
219
+
220
+ def seperate_audio(self, audio_path: str, output_dir: Union[str, None] = None) -> str:
221
+ if output_dir is not None:
222
+ if output_dir != self.cache_dir:
223
+ # reload audio separator
224
+ self.set_audio_separator(output_dir)
225
+
226
+ if self.audio_separator is not None:
227
+ # 1. separate vocals
228
+ # TODO: process in memory
229
+ try:
230
+ outputs = self.audio_separator.separate(audio_path)
231
+ if len(outputs) <= 0:
232
+ raise RuntimeError("Audio separate failed.")
233
+
234
+ vocal_audio_file = outputs[0]
235
+ vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
236
+ vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
237
+ vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate)
238
+ except Exception as e:
239
+ log(f"Fail to separate vocals from {audio_path}, error info [{e}]")
240
+ vocal_audio_file=audio_path
241
+ else:
242
+ vocal_audio_file=audio_path
243
+
244
+ return vocal_audio_file
245
+
246
+ def load_audio(self, audio_path: str, mono: bool = True, duration: Optional[float] = None) -> Any:
247
+ try:
248
+ audio_data, sampling_rate = librosa.load(audio_path, sr=self.sample_rate, mono=mono, duration=duration)
249
+ except Exception as e:
250
+ raise RuntimeError(f"Fail to load audio from {audio_path}, error info [{e}]")
251
+ return audio_data, sampling_rate
252
+
253
+ def prepare_audio_data(self, audio_data: Union[np.ndarray, torch.Tensor], n_frames: Optional[int]=None) -> Tuple[List[Any], int]:
254
+ """Prepare audio data for processing.
255
+ """
256
+ #print(f"==========> Using Wav2Vec2FeatureExtractor to extract audio features")
257
+ audio_data = np.squeeze(self.feature_extractor(audio_data, sampling_rate=self.sample_rate).input_values)
258
+
259
+ clip_len = int(len(audio_data) / self.audio_unit)
260
+ if n_frames is not None:
261
+ if abs(n_frames - clip_len) > 7:
262
+ log(f"The number of frames must be close to the clip length (in 280ms), got {n_frames} and {clip_len}")
263
+ return [], n_frames
264
+ clip_len = n_frames
265
+ else:
266
+ n_frames = clip_len
267
+
268
+ if isinstance(audio_data, np.ndarray):
269
+ audio_data = torch.from_numpy(audio_data).float().to(self.device)
270
+ assert audio_data.ndim == 1, 'Audio must be 1D tensor.'
271
+
272
+ # padding
273
+ # padding audio to fit the clip length
274
+ n_audio_samples = round(self.audio_unit * clip_len)
275
+ n_padding_audio_samples = n_audio_samples - len(audio_data)
276
+ n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
277
+ if n_padding_audio_samples > 0:
278
+ if self.pad_mode == 'zero':
279
+ padding_value = 0
280
+ elif self.pad_mode == 'replicate':
281
+ padding_value = float(audio_data[-1])
282
+ else:
283
+ raise ValueError(f'Unknown pad mode: {self.pad_mode}')
284
+ audio_data = F.pad(audio_data, (0, n_padding_audio_samples), value=padding_value)
285
+
286
+ # divide audio into sub-divisions to save GPU memory
287
+ audio_segments = []
288
+ if clip_len <= self.subclip_len:
289
+ n_subdivision = 1
290
+ subclip_len = clip_len
291
+ else:
292
+ n_subdivision = math.ceil(clip_len / self.subclip_len)
293
+ subclip_len = self.subclip_len
294
+
295
+ for i in range(0, n_subdivision):
296
+ start_idx = i * subclip_len
297
+ end_idx = min(start_idx + subclip_len, clip_len)
298
+ # debug
299
+ #log(f"[{i+1}/{n_subdivision}] data index [{round(start_idx * self.audio_unit)}, {round(end_idx * self.audio_unit)})")
300
+ audio_segments.append(
301
+ {
302
+ "data": audio_data[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0),
303
+ "start_idx": start_idx,
304
+ "end_idx": end_idx,
305
+ "length": end_idx - start_idx
306
+ }
307
+ )
308
+ return audio_segments, n_frames
309
+
310
+ def get_audio_embedding(self, audio, clip_len: int) -> torch.Tensor:
311
+ if audio.ndim == 2:
312
+ # Extract audio features
313
+ assert audio.shape[1] == 16000 * clip_len / self.fps, \
314
+ f'Incorrect audio length {audio.shape[1]}'
315
+
316
+ # Extract audio features
317
+ if self.use_half:
318
+ audio = audio.half()
319
+ embeddings = self.audio_encoder(
320
+ pad_audio(audio), seq_len=clip_len, sample_strategy=self.sample_strategy, output_hidden_states=True
321
+ ) # (N, L, 768)
322
+ assert len(embeddings) > 0, "Fail to extract audio embedding"
323
+
324
+ if self.only_last_features:
325
+ audio_emb = embeddings.last_hidden_state.squeeze(0)
326
+ else:
327
+ audio_emb = torch.stack(
328
+ embeddings.hidden_states[1:], dim=1
329
+ ).squeeze(0)
330
+ audio_emb = rearrange(audio_emb, "b s d -> s b d")
331
+
332
+ elif audio.ndim == 3:
333
+ assert audio.shape[1] == clip_len, f'Incorrect audio feature length {audio.shape[1]}'
334
+ audio_emb = audio
335
+ else:
336
+ raise ValueError(f'Incorrect audio input shape {audio.shape}')
337
+
338
+ return audio_emb
339
+
340
+ def get_audio_embeddings(self, audio_segments: List[Any]) -> Optional[torch.Tensor]:
341
+ audio_embs = []
342
+ for audio_segment in audio_segments:
343
+ if self.is_training:
344
+ audio_emb = self.get_audio_embedding(audio_segment["data"], audio_segment["length"])
345
+ else:
346
+ with torch.no_grad():
347
+ audio_emb = self.get_audio_embedding(audio_segment["data"], audio_segment["length"])
348
+
349
+ audio_emb = audio_emb.cpu() if self.save_to_cpu else audio_emb
350
+ audio_embs.append(audio_emb)
351
+ #log(f"audio segment [{audio_segment['start_idx']}, {audio_segment['end_idx']}) has been processed.")
352
+
353
+ if len(audio_embs) == 0:
354
+ return None
355
+
356
+ audio_emb = torch.cat(audio_embs, dim=0)
357
+
358
+ return audio_emb
359
+
360
+ def preprocess(
361
+ self,
362
+ audio_path: str,
363
+ n_frames: Optional[int] = None,
364
+ duration: Optional[float] = None,
365
+ need_seperate: bool = False
366
+ ):
367
+ """ Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
368
+ The separated vocal track is then converted into wav2vec2 for further processing or analysis.
369
+ """
370
+ if need_seperate:
371
+ vocal_audio_file = self.seperate_audio(audio_path)
372
+ else:
373
+ vocal_audio_file = audio_path
374
+
375
+ audio_data, sampling_rate = self.load_audio(vocal_audio_file, duration=duration)
376
+
377
+ assert sampling_rate == 16000, "The sample rate of audio must be 16000"
378
+ audio_segments, n_frames = self.prepare_audio_data(audio_data, n_frames)
379
+ audio_emb = self.get_audio_embeddings(audio_segments)
380
+ if audio_emb is None:
381
+ log(f"{audio_path} has been processed, but no audio embedding, set as 'None'.")
382
+ #else:
383
+ #log(f"{audio_path} has been processed, audio embedding shape {audio_emb.shape}.")
384
+ return audio_emb, n_frames
385
+
386
+ def preprocess_long(
387
+ self,
388
+ audio_path: str,
389
+ need_seperate: bool = False
390
+ ):
391
+ audio_list = cut_audio(audio_path, self.tmp_dir, length=self.max_length)
392
+ audio_emb_list = []
393
+ l = 0
394
+
395
+ for idx, audio_path in enumerate(audio_list):
396
+ padding = (idx+1) == len(audio_list)
397
+ emb, length = self.preprocess(audio_path, need_seperate=need_seperate)
398
+ audio_emb_list.append(emb)
399
+ log(f"Processing audio {idx+1}/{len(audio_list)}, path: {audio_path} length: {length}")
400
+ l += length
401
+
402
+ audio_emb = torch.cat(audio_emb_list)
403
+ audio_length = l
404
+
405
+ # remove tmp file
406
+ if len(audio_list) > 1:
407
+ for audio_path in audio_list:
408
+ os.remove(audio_path)
409
+
410
+ return audio_emb, audio_length
411
+
412
+ def add_silent_audio(self, audio_path: str, silent_audio_path: Optional[str] = None, add_duration: float = 1., linear_fusion=False, mode="post"):
413
+ # mode, pre, post, both
414
+ assert mode in ["pre", "post", "both"], f"Unknown mode: {mode}, only support pre, post, both"
415
+ if silent_audio_path is None:
416
+ return audio_path, 0
417
+ else:
418
+ audio_dir = osp.dirname(audio_path)
419
+ audio_name = osp.basename(audio_path)
420
+ temp_audio_path = osp.join(audio_dir, f"tmp_{audio_name}")
421
+ if osp.isfile(temp_audio_path):
422
+ os.remove(temp_audio_path)
423
+
424
+ audio, sr1 = librosa.load(audio_path, mono=True, sr=16000)
425
+ # denoise
426
+ audio = librosa.effects.preemphasis(audio) # enhance voice
427
+ # load silent audio
428
+ silent_audio, sr2 = librosa.load(silent_audio_path, mono=True, sr=16000)
429
+ silent_audio = silent_audio[:int(add_duration*sr2)]
430
+
431
+ if linear_fusion:
432
+ short_len = min(len(audio), len(silent_audio))
433
+ fusion_ratio = np.linspace(0, 1.0, num=short_len)
434
+ # get pre padding audio
435
+ pre_pad_audio = fusion_ratio * silent_audio[:short_len] + (1 - fusion_ratio) * audio[:short_len]
436
+ if short_len < len(silent_audio):
437
+ pre_pad_audio = np.hstack((pre_pad_audio, silent_audio[short_len:]))
438
+ pre_pad_audio = np.flip(pre_pad_audio, axis=0)
439
+
440
+ # get post padding audio
441
+ post_pad_audio = (1 - fusion_ratio) * silent_audio[-short_len:] + fusion_ratio * audio[-short_len:]
442
+ if short_len < len(silent_audio):
443
+ post_pad_audio = np.hstack((silent_audio[:-short_len], post_pad_audio))
444
+ post_pad_audio = np.flip(post_pad_audio, axis=0)
445
+ else:
446
+ pre_pad_audio = silent_audio
447
+ post_pad_audio = silent_audio
448
+
449
+ # padding audio
450
+ if mode == "both":
451
+ combined_audio = np.hstack((pre_pad_audio, audio, post_pad_audio))
452
+ elif mode == "pre":
453
+ combined_audio = np.hstack((pre_pad_audio, audio))
454
+ else:
455
+ combined_audio = np.hstack((audio, post_pad_audio))
456
+
457
+ add_nframes = math.floor(add_duration * sr2 / self.audio_unit)
458
+ #print(f"audio length: {len(audio)}, pre_pad_audio length: {len(pre_pad_audio)}, post_pad_audio length: {len(post_pad_audio)}, combined_length: {len(combined_audio)}, total add {add_nframes*2} frames")
459
+ #print(f"audio duration: {librosa.get_duration(audio, sr=sr1)}, silent duration: {librosa.get_duration(silent_audio, sr=sr2)}, combined duration: {librosa.get_duration(combined_audio, sr=sr2)}")
460
+ soundfile.write(temp_audio_path, combined_audio, sr2)
461
+
462
+ return temp_audio_path, add_nframes
463
+
464
+ def get_long_audio_emb(self, audio_path: str) -> torch.Tensor:
465
+ audio_emb, length = self.preprocess_long(audio_path)
466
+ log(f"Load audio from {osp.realpath(audio_path)} done, audio_emb shape: {audio_emb.shape}.")
467
+ return audio_emb
468
+
469
+ def __enter__(self):
470
+ return self
471
+
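
Editor's note: a minimal usage sketch for the AudioProcessor defined above, assuming the repository root is on PYTHONPATH and the chinese-hubert-base weights referenced in audio_processer_config.yaml have been downloaded under pretrain_weights/audio/:

    from src.datasets.preprocess.extract_features.audio_processer import AudioProcessor

    processor = AudioProcessor("configs/audio2motion/model/audio_processer_config.yaml", device_id=0)
    audio_emb, n_frames = processor.preprocess("src/examples/driving_audios/5.wav")
    # With only_last_features: False, audio_emb stacks 12 hidden-state blocks,
    # giving roughly (n_frames, 12, 768) at 25 fps.
    print(audio_emb.shape, n_frames)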
src/datasets/preprocess/extract_features/face_segmentation/__init__.py ADDED
@@ -0,0 +1,88 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ import torch
5
+ from torchvision import transforms
6
+
7
+ from .bisenet import BiSeNet
8
+
9
+
10
+ def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='parsing_map_on_im2.jpg'):
11
+ # Colors for all 20 parts
12
+ part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
13
+ [255, 0, 85], [255, 0, 170],
14
+ [0, 255, 0], [85, 255, 0], [170, 255, 0],
15
+ [0, 255, 85], [0, 255, 170],
16
+ [0, 0, 255], [85, 0, 255], [170, 0, 255],
17
+ [0, 85, 255], [0, 170, 255],
18
+ [255, 255, 0], [255, 255, 85], [255, 255, 170],
19
+ [255, 0, 255], [255, 85, 255], [255, 170, 255],
20
+ [0, 255, 255], [85, 255, 255], [170, 255, 255]]
21
+
22
+ im = np.array(im)
23
+ vis_im = im.copy().astype(np.uint8)
24
+ vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
25
+ vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
26
+ vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
27
+
28
+ num_of_class = np.max(vis_parsing_anno)
29
+
30
+ for pi in range(1, num_of_class + 1):
31
+ index = np.where(vis_parsing_anno == pi)
32
+ vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
33
+
34
+ vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
35
+ # print(vis_parsing_anno_color.shape, vis_im.shape)
36
+ vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
37
+
38
+ # Save result or not
39
+ if save_im:
40
+ cv2.imwrite(save_path[:-4] +'.png', vis_parsing_anno)
41
+ cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
42
+
43
+ # return vis_im
44
+
45
+ def get_face_mask(face_parser, images, batch_size=128):
46
+ # images: Bx3xHxW
47
+ kernel = np.ones((13, 13), np.float32)
48
+ face_masks = []
49
+ for i in range(0, images.shape[0], batch_size):
50
+ images_batch = images[i:i+batch_size]
51
+ with torch.no_grad():
52
+ out = face_parser(images_batch)[0]
53
+ parsing = out.cpu().numpy().argmax(1)
54
+ masks = np.zeros_like(parsing, np.float32)
55
+ for idx in range(1, 14):
56
+ masks[parsing == idx] = 1
57
+
58
+ for mask in masks:
59
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=2)
60
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)
61
+ mask = cv2.dilate(mask, kernel, iterations=3)
62
+ face_masks.append(mask)
63
+
64
+ return face_masks
65
+
66
+
67
+ def build_face_parser(weight_path, resnet_weight_path, n_classes=19, device_id=0):
68
+ model_state_dict = torch.load(weight_path, weights_only=False)
69
+ bisenet = BiSeNet(n_classes, resnet_weight_path=resnet_weight_path)
70
+ # load model
71
+ #bisenet.load_state_dict(model_state_dict, strict=True)
72
+ bisenet_state_dict = bisenet.state_dict()
73
+ for k, v in model_state_dict.items():
74
+ if 'fc' in k: continue
75
+ bisenet_state_dict.update({k: v})
76
+ bisenet.load_state_dict(bisenet_state_dict)
77
+ bisenet.to(f"cuda:{device_id}")
78
+
79
+ to_tensor = transforms.Compose([
80
+ transforms.ToTensor(),
81
+ transforms.Resize((512, 512)),
82
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
83
+ ])
84
+
85
+ return bisenet.eval(), to_tensor
86
+
87
+
88
+
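
Editor's note: a usage sketch for the helpers above. The weight paths follow liveportrait_config.yaml and must exist locally, and a CUDA device is assumed since build_face_parser moves the network to cuda:<device_id>:

    import cv2
    from src.datasets.preprocess.extract_features.face_segmentation import build_face_parser, get_face_mask

    parser, to_tensor = build_face_parser(
        "pretrain_weights/face/face-parsing/79999_iter.pth",
        "pretrain_weights/face/face-parsing/resnet18-5c106cde.pth",
        device_id=0,
    )
    img = cv2.cvtColor(cv2.imread("src/examples/reference_images/2.jpg"), cv2.COLOR_BGR2RGB)
    batch = to_tensor(img).unsqueeze(0).to("cuda:0")  # 1x3x512x512 after the Resize transform
    masks = get_face_mask(parser, batch)              # list of 512x512 float masks (face region == 1)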
src/datasets/preprocess/extract_features/face_segmentation/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.07 kB).
 
src/datasets/preprocess/extract_features/face_segmentation/__pycache__/bisenet.cpython-310.pyc ADDED
Binary file (8.38 kB).
 
src/datasets/preprocess/extract_features/face_segmentation/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (3.77 kB).
 
src/datasets/preprocess/extract_features/face_segmentation/bisenet.py ADDED
@@ -0,0 +1,285 @@
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision
9
+
10
+ from .resnet import Resnet18
11
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
12
+
13
+
14
+ class ConvBNReLU(nn.Module):
15
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16
+ super(ConvBNReLU, self).__init__()
17
+ self.conv = nn.Conv2d(in_chan,
18
+ out_chan,
19
+ kernel_size = ks,
20
+ stride = stride,
21
+ padding = padding,
22
+ bias = False)
23
+ self.bn = nn.BatchNorm2d(out_chan)
24
+ self.init_weight()
25
+
26
+ def forward(self, x):
27
+ x = self.conv(x)
28
+ x = F.relu(self.bn(x))
29
+ return x
30
+
31
+ def init_weight(self):
32
+ for ly in self.children():
33
+ if isinstance(ly, nn.Conv2d):
34
+ nn.init.kaiming_normal_(ly.weight, a=1)
35
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36
+
37
+ class BiSeNetOutput(nn.Module):
38
+ def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39
+ super(BiSeNetOutput, self).__init__()
40
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41
+ self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42
+ self.init_weight()
43
+
44
+ def forward(self, x):
45
+ x = self.conv(x)
46
+ x = self.conv_out(x)
47
+ return x
48
+
49
+ def init_weight(self):
50
+ for ly in self.children():
51
+ if isinstance(ly, nn.Conv2d):
52
+ nn.init.kaiming_normal_(ly.weight, a=1)
53
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54
+
55
+ def get_params(self):
56
+ wd_params, nowd_params = [], []
57
+ for name, module in self.named_modules():
58
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59
+ wd_params.append(module.weight)
60
+ if not module.bias is None:
61
+ nowd_params.append(module.bias)
62
+ elif isinstance(module, nn.BatchNorm2d):
63
+ nowd_params += list(module.parameters())
64
+ return wd_params, nowd_params
65
+
66
+
67
+ class AttentionRefinementModule(nn.Module):
68
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
69
+ super(AttentionRefinementModule, self).__init__()
70
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72
+ self.bn_atten = nn.BatchNorm2d(out_chan)
73
+ self.sigmoid_atten = nn.Sigmoid()
74
+ self.init_weight()
75
+
76
+ def forward(self, x):
77
+ feat = self.conv(x)
78
+ atten = F.avg_pool2d(feat, feat.size()[2:])
79
+ atten = self.conv_atten(atten)
80
+ atten = self.bn_atten(atten)
81
+ atten = self.sigmoid_atten(atten)
82
+ out = torch.mul(feat, atten)
83
+ return out
84
+
85
+ def init_weight(self):
86
+ for ly in self.children():
87
+ if isinstance(ly, nn.Conv2d):
88
+ nn.init.kaiming_normal_(ly.weight, a=1)
89
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90
+
91
+
92
+ class ContextPath(nn.Module):
93
+ def __init__(self, *args, **kwargs):
94
+ super(ContextPath, self).__init__()
95
+ backbone_weight_path = kwargs.get("resnet_weight_path", None)
96
+ self.resnet = Resnet18(backbone_weight_path)
97
+ self.arm16 = AttentionRefinementModule(256, 128)
98
+ self.arm32 = AttentionRefinementModule(512, 128)
99
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
101
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
102
+
103
+ self.init_weight()
104
+
105
+ def forward(self, x):
106
+ H0, W0 = x.size()[2:]
107
+ feat8, feat16, feat32 = self.resnet(x)
108
+ H8, W8 = feat8.size()[2:]
109
+ H16, W16 = feat16.size()[2:]
110
+ H32, W32 = feat32.size()[2:]
111
+
112
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
113
+ avg = self.conv_avg(avg)
114
+ avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
115
+
116
+ feat32_arm = self.arm32(feat32)
117
+ feat32_sum = feat32_arm + avg_up
118
+ feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
119
+ feat32_up = self.conv_head32(feat32_up)
120
+
121
+ feat16_arm = self.arm16(feat16)
122
+ feat16_sum = feat16_arm + feat32_up
123
+ feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
124
+ feat16_up = self.conv_head16(feat16_up)
125
+
126
+ return feat8, feat16_up, feat32_up # x8, x8, x16
127
+
128
+ def init_weight(self):
129
+ for ly in self.children():
130
+ if isinstance(ly, nn.Conv2d):
131
+ nn.init.kaiming_normal_(ly.weight, a=1)
132
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
133
+
134
+ def get_params(self):
135
+ wd_params, nowd_params = [], []
136
+ for name, module in self.named_modules():
137
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
138
+ wd_params.append(module.weight)
139
+ if not module.bias is None:
140
+ nowd_params.append(module.bias)
141
+ elif isinstance(module, nn.BatchNorm2d):
142
+ nowd_params += list(module.parameters())
143
+ return wd_params, nowd_params
144
+
145
+
146
+ ### This is not used, since I replace this with the resnet feature with the same size
147
+ class SpatialPath(nn.Module):
148
+ def __init__(self, *args, **kwargs):
149
+ super(SpatialPath, self).__init__()
150
+ self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
151
+ self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152
+ self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
153
+ self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
154
+ self.init_weight()
155
+
156
+ def forward(self, x):
157
+ feat = self.conv1(x)
158
+ feat = self.conv2(feat)
159
+ feat = self.conv3(feat)
160
+ feat = self.conv_out(feat)
161
+ return feat
162
+
163
+ def init_weight(self):
164
+ for ly in self.children():
165
+ if isinstance(ly, nn.Conv2d):
166
+ nn.init.kaiming_normal_(ly.weight, a=1)
167
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
168
+
169
+ def get_params(self):
170
+ wd_params, nowd_params = [], []
171
+ for name, module in self.named_modules():
172
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
173
+ wd_params.append(module.weight)
174
+ if not module.bias is None:
175
+ nowd_params.append(module.bias)
176
+ elif isinstance(module, nn.BatchNorm2d):
177
+ nowd_params += list(module.parameters())
178
+ return wd_params, nowd_params
179
+
180
+
181
+ class FeatureFusionModule(nn.Module):
182
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
183
+ super(FeatureFusionModule, self).__init__()
184
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
185
+ self.conv1 = nn.Conv2d(out_chan,
186
+ out_chan//4,
187
+ kernel_size = 1,
188
+ stride = 1,
189
+ padding = 0,
190
+ bias = False)
191
+ self.conv2 = nn.Conv2d(out_chan//4,
192
+ out_chan,
193
+ kernel_size = 1,
194
+ stride = 1,
195
+ padding = 0,
196
+ bias = False)
197
+ self.relu = nn.ReLU(inplace=True)
198
+ self.sigmoid = nn.Sigmoid()
199
+ self.init_weight()
200
+
201
+ def forward(self, fsp, fcp):
202
+ fcat = torch.cat([fsp, fcp], dim=1)
203
+ feat = self.convblk(fcat)
204
+ atten = F.avg_pool2d(feat, feat.size()[2:])
205
+ atten = self.conv1(atten)
206
+ atten = self.relu(atten)
207
+ atten = self.conv2(atten)
208
+ atten = self.sigmoid(atten)
209
+ feat_atten = torch.mul(feat, atten)
210
+ feat_out = feat_atten + feat
211
+ return feat_out
212
+
213
+ def init_weight(self):
214
+ for ly in self.children():
215
+ if isinstance(ly, nn.Conv2d):
216
+ nn.init.kaiming_normal_(ly.weight, a=1)
217
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
218
+
219
+ def get_params(self):
220
+ wd_params, nowd_params = [], []
221
+ for name, module in self.named_modules():
222
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
223
+ wd_params.append(module.weight)
224
+ if not module.bias is None:
225
+ nowd_params.append(module.bias)
226
+ elif isinstance(module, nn.BatchNorm2d):
227
+ nowd_params += list(module.parameters())
228
+ return wd_params, nowd_params
229
+
230
+
231
+ class BiSeNet(nn.Module):
232
+ def __init__(self, n_classes, *args, **kwargs):
233
+ super(BiSeNet, self).__init__()
234
+ backbone_weight_path = kwargs.get("resnet_weight_path", None)
235
+ self.cp = ContextPath(resnet_weight_path=backbone_weight_path)
236
+ ## here self.sp is deleted
237
+ self.ffm = FeatureFusionModule(256, 256)
238
+ self.conv_out = BiSeNetOutput(256, 256, n_classes)
239
+ self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
240
+ self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
241
+ self.init_weight()
242
+
243
+ def forward(self, x):
244
+ H, W = x.size()[2:]
245
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
246
+ feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
247
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
248
+
249
+ feat_out = self.conv_out(feat_fuse)
250
+ feat_out16 = self.conv_out16(feat_cp8)
251
+ feat_out32 = self.conv_out32(feat_cp16)
252
+
253
+ feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
254
+ feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
255
+ feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
256
+ return feat_out, feat_out16, feat_out32
257
+
258
+ def init_weight(self):
259
+ for ly in self.children():
260
+ if isinstance(ly, nn.Conv2d):
261
+ nn.init.kaiming_normal_(ly.weight, a=1)
262
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
263
+
264
+ def get_params(self):
265
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
266
+ for name, child in self.named_children():
267
+ child_wd_params, child_nowd_params = child.get_params()
268
+ if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
269
+ lr_mul_wd_params += child_wd_params
270
+ lr_mul_nowd_params += child_nowd_params
271
+ else:
272
+ wd_params += child_wd_params
273
+ nowd_params += child_nowd_params
274
+ return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
275
+
276
+
277
+ if __name__ == "__main__":
278
+ net = BiSeNet(19)
279
+ net.cuda()
280
+ net.eval()
281
+ in_ten = torch.randn(16, 3, 640, 480).cuda()
282
+ out, out16, out32 = net(in_ten)
283
+ print(out.shape)
284
+
285
+ net.get_params()
src/datasets/preprocess/extract_features/face_segmentation/resnet.py ADDED
@@ -0,0 +1,113 @@
1
+
2
+ #!/usr/bin/python
3
+ # -*- encoding: utf-8 -*-
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.utils.model_zoo as modelzoo
9
+
10
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
11
+
12
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
13
+
14
+
15
+ def conv3x3(in_planes, out_planes, stride=1):
16
+ """3x3 convolution with padding"""
17
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
18
+ padding=1, bias=False)
19
+
20
+
21
+ class BasicBlock(nn.Module):
22
+ def __init__(self, in_chan, out_chan, stride=1):
23
+ super(BasicBlock, self).__init__()
24
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
25
+ self.bn1 = nn.BatchNorm2d(out_chan)
26
+ self.conv2 = conv3x3(out_chan, out_chan)
27
+ self.bn2 = nn.BatchNorm2d(out_chan)
28
+ self.relu = nn.ReLU(inplace=True)
29
+ self.downsample = None
30
+ if in_chan != out_chan or stride != 1:
31
+ self.downsample = nn.Sequential(
32
+ nn.Conv2d(in_chan, out_chan,
33
+ kernel_size=1, stride=stride, bias=False),
34
+ nn.BatchNorm2d(out_chan),
35
+ )
36
+
37
+ def forward(self, x):
38
+ residual = self.conv1(x)
39
+ residual = F.relu(self.bn1(residual))
40
+ residual = self.conv2(residual)
41
+ residual = self.bn2(residual)
42
+
43
+ shortcut = x
44
+ if self.downsample is not None:
45
+ shortcut = self.downsample(x)
46
+
47
+ out = shortcut + residual
48
+ out = self.relu(out)
49
+ return out
50
+
51
+
52
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
53
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
54
+ for i in range(bnum-1):
55
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
56
+ return nn.Sequential(*layers)
57
+
58
+
59
+ class Resnet18(nn.Module):
60
+ def __init__(self, backbone_weight_path=None):
61
+ super(Resnet18, self).__init__()
62
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
63
+ bias=False)
64
+ self.bn1 = nn.BatchNorm2d(64)
65
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
66
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
67
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
68
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
69
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
70
+ self.init_weight(backbone_weight_path)
71
+
72
+ def forward(self, x):
73
+ x = self.conv1(x)
74
+ x = F.relu(self.bn1(x))
75
+ x = self.maxpool(x)
76
+
77
+ x = self.layer1(x)
78
+ feat8 = self.layer2(x) # 1/8
79
+ feat16 = self.layer3(feat8) # 1/16
80
+ feat32 = self.layer4(feat16) # 1/32
81
+ return feat8, feat16, feat32
82
+
83
+ def init_weight(self, backbone_weight_path=None):
84
+ if backbone_weight_path is None:
85
+ state_dict = modelzoo.load_url(resnet18_url)
86
+ else:
87
+ state_dict = torch.load(backbone_weight_path, weights_only=False)
88
+ self_state_dict = self.state_dict()
89
+ for k, v in state_dict.items():
90
+ if 'fc' in k: continue
91
+ self_state_dict.update({k: v})
92
+ self.load_state_dict(self_state_dict)
93
+
94
+ def get_params(self):
95
+ wd_params, nowd_params = [], []
96
+ for name, module in self.named_modules():
97
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
98
+ wd_params.append(module.weight)
99
+ if not module.bias is None:
100
+ nowd_params.append(module.bias)
101
+ elif isinstance(module, nn.BatchNorm2d):
102
+ nowd_params += list(module.parameters())
103
+ return wd_params, nowd_params
104
+
105
+
106
+ if __name__ == "__main__":
107
+ net = Resnet18()
108
+ x = torch.randn(16, 3, 224, 224)
109
+ out = net(x)
110
+ print(out[0].size())
111
+ print(out[1].size())
112
+ print(out[2].size())
113
+ net.get_params()
src/datasets/preprocess/extract_features/motion_processer.py ADDED
@@ -0,0 +1,1420 @@
1
+ """
2
+ Motion feature extractor
3
+ """
4
+ import os
5
+ import os.path as osp
6
+ import sys
7
+ import pickle
8
+ from omegaconf import OmegaConf
9
+
10
+ import torch
11
+
12
+ from PIL import Image
13
+ import numpy as np
14
+ import cv2
15
+ import imageio
16
+ import pickle
17
+ import time
18
+ from decord import VideoReader # must after import torch
19
+
20
+ from rich.progress import track
21
+
22
+
23
+
24
+
25
+ sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))))))
26
+ from src.datasets.preprocess.extract_features.face_segmentation import build_face_parser, get_face_mask, vis_parsing_maps
27
+ from src.thirdparty.liveportrait.src.utils.helper import load_model, concat_feat
28
+ from src.thirdparty.liveportrait.src.utils.io import load_image_rgb, resize_to_limit, load_video
29
+ from src.thirdparty.liveportrait.src.utils.video import get_fps, images2video, add_audio_to_video
30
+ from src.thirdparty.liveportrait.src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
31
+
32
+ from src.thirdparty.liveportrait.src.utils.cropper import Cropper
33
+ from src.thirdparty.liveportrait.src.utils.crop import prepare_paste_back, paste_back, paste_back_with_face_mask
34
+ from src.thirdparty.liveportrait.src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
35
+ from src.thirdparty.liveportrait.src.utils.helper import mkdir, basename, dct2device, is_image, calc_motion_multiplier
36
+ from src.utils.filter import smooth as ksmooth
37
+ from src.utils.filter import smooth_
38
+
39
+ from skimage.metrics import peak_signal_noise_ratio
40
+ import warnings
41
+
42
+
43
+ def psnr(imgs1, imgs2):
44
+ psnrs = []
45
+ for img1, img2 in zip(imgs1, imgs2):
46
+ psnr = peak_signal_noise_ratio(img1, img2, data_range=255)
47
+ psnrs.append(psnr)
48
+ return psnrs
49
+
50
+
51
+ def suffix(filename):
52
+ """a.jpg -> jpg"""
53
+ pos = filename.rfind(".")
54
+ if pos == -1:
55
+ return ""
56
+ return filename[pos + 1:]
57
+
58
+ def dump(wfp, obj):
59
+ wd = osp.split(wfp)[0]
60
+ if wd != "" and not osp.exists(wd):
61
+ mkdir(wd)
62
+
63
+ _suffix = suffix(wfp)
64
+ if _suffix == "npy":
65
+ np.save(wfp, obj)
66
+ elif _suffix == "pkl":
67
+ pickle.dump(obj, open(wfp, "wb"))
68
+ else:
69
+ raise Exception("Unknown type: {}".format(_suffix))
70
+
71
+ def load(fp):
72
+ suffix_ = suffix(fp)
73
+
74
+ if suffix_ == "npy":
75
+ return np.load(fp)
76
+ elif suffix_ == "pkl":
77
+ return pickle.load(open(fp, "rb"))
78
+ else:
79
+ raise Exception(f"Unknown type: {suffix}")
80
+
81
+
82
+ def remove_suffix(filepath):
83
+ """a/b/c.jpg -> a/b/c"""
84
+ return osp.join(osp.dirname(filepath), basename(filepath))
85
+
86
+
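A quick round-trip sketch for the `dump`/`load` helpers above (paths are placeholders; `.npy` and `.pkl` are the only supported suffixes, and `dump` creates the parent directory if it is missing):

```python
import numpy as np

# Round-trip a motion dict through .pkl and an array through .npy.
motion = {"exp": np.zeros((1, 21, 3), dtype=np.float32)}
dump("tmp/motion.pkl", motion)
assert np.allclose(load("tmp/motion.pkl")["exp"], motion["exp"])

arr = np.arange(10)
dump("tmp/arr.npy", arr)
assert (load("tmp/arr.npy") == arr).all()
```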
87
+ class MotionProcesser(object):
88
+ def __init__(self, cfg_path, device_id=0) -> None:
89
+ device = f"cuda:{device_id}"
90
+ cfg = OmegaConf.load(cfg_path)
91
+ print(f"Load cfg from {osp.realpath(cfg_path)} done.")
92
+ print(f"=============================== Driven CFG ===============================")
93
+ print(OmegaConf.to_yaml(cfg))
94
+ print(f"=============================== ========== ===============================")
95
+ models_config = OmegaConf.load(cfg.models_config)
96
+
97
+ # 1. init appearance feature extractor
98
+ self.appearance_feature_extractor = load_model(
99
+ cfg.appearance_feature_extractor_path,
100
+ models_config,
101
+ device,
102
+ 'appearance_feature_extractor'
103
+ )
104
+ print(f'1. Load appearance_feature_extractor from {osp.realpath(cfg.appearance_feature_extractor_path)} done.')
105
+
106
+ # 2. init motion extractor
107
+ self.motion_extractor = load_model(
108
+ cfg.motion_extractor_path,
109
+ models_config,
110
+ device,
111
+ 'motion_extractor'
112
+ )
113
+ print(f'2. Load motion_extractor from {osp.realpath(cfg.motion_extractor_path)} done.')
114
+
115
+ # 3. init S and R
116
+ if cfg.stitching_retargeting_module_path is not None and osp.exists(cfg.stitching_retargeting_module_path):
117
+ self.stitching_retargeting_module = load_model(
118
+ cfg.stitching_retargeting_module_path,
119
+ models_config,
120
+ device,
121
+ 'stitching_retargeting_module'
122
+ )
123
+ print(f'3. Load stitching_retargeting_module from {osp.realpath(cfg.stitching_retargeting_module_path)} done.')
124
+ else:
125
+ self.stitching_retargeting_module = None
126
+
127
+ # 4. init motion warper
128
+ self.warping_module = load_model(
129
+ cfg.warping_module_path,
130
+ models_config,
131
+ device,
132
+ 'warping_module'
133
+ )
134
+ print(f"4. Load warping_module from {osp.realpath(cfg.warping_module_path)} done.")
135
+
136
+ # 5. init decoder
137
+ self.spade_generator = load_model(
138
+ cfg.spade_generator_path,
139
+ models_config,
140
+ device,
141
+ 'spade_generator'
142
+ )
143
+ print(f"Load generator from {osp.realpath(cfg.spade_generator_path)} done.")
144
+
145
+ # # Optimize for inference
146
+ self.compile = cfg.flag_do_torch_compile
147
+ if self.compile:
148
+ torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
149
+ self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
150
+ self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')
151
+
152
+ # 6. init cropper
153
+ crop_cfg = OmegaConf.load(cfg.crop_cfg)
154
+ self.cropper = Cropper(crop_cfg=crop_cfg, image_type="human_face", device_id=device_id)
155
+
156
+ self.cfg = cfg
157
+ self.models_config = models_config
158
+ self.device = device
159
+
160
+
161
+ # 7. load crop mask
162
+ self.mask_crop = cv2.imread(cfg.mask_crop, cv2.IMREAD_COLOR)
163
+ # 8. load lip array
164
+ with open(cfg.lip_array, 'rb') as f:
165
+ self.lip_array = pickle.load(f)
166
+
167
+ # 9. load face parser
168
+ self.face_parser, self.to_tensor = build_face_parser(weight_path=cfg.face_parser_weight_path, resnet_weight_path=cfg.resnet_weight_path, device_id=device_id)
169
+
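Construction sketch (the YAML path is a placeholder; the config is expected to provide the checkpoint, crop-config, mask, lip-array and face-parser paths consumed above):

```python
# Hypothetical instantiation; all model loading happens in __init__.
mp = MotionProcesser("path/to/driven_config.yaml", device_id=0)
```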
170
+ def inference_ctx(self):
171
+ ctx = torch.autocast(device_type=self.device[:4], dtype=torch.float16,
172
+ enabled=self.cfg.flag_use_half_precision)
173
+ return ctx
174
+
175
+ @torch.no_grad()
176
+ def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
177
+ """ get the appearance feature of the image by F
178
+ x: Bx3xHxW, normalized to 0~1
179
+ """
180
+ with self.inference_ctx():
181
+ feature_3d = self.appearance_feature_extractor(x)
182
+
183
+ return feature_3d.float()
184
+
185
+ @torch.no_grad()
186
+ def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
187
+ """ get the implicit keypoint information
188
+ x: Bx3xHxW, normalized to 0~1
189
+ flag_refine_info: whether to transform the pose to degrees and reshape the output dimensions
190
+ return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
191
+ """
192
+ with self.inference_ctx():
193
+ kp_info = self.motion_extractor(x)
194
+
195
+ if self.cfg.flag_use_half_precision:
196
+ # float the dict
197
+ for k, v in kp_info.items():
198
+ if isinstance(v, torch.Tensor):
199
+ kp_info[k] = v.float()
200
+
201
+ return kp_info
202
+
203
+ @torch.no_grad()
204
+ def refine_kp(self, kp_info):
205
+ bs = kp_info['exp'].shape[0]
206
+ kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None] # Bx1
207
+ kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None] # Bx1
208
+ kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None] # Bx1
209
+ kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3) # BxNx3
210
+ if 'kp' in kp_info.keys():
211
+ kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3) # BxNx3
212
+
213
+ return kp_info
214
+
215
+ @torch.no_grad()
216
+ def transform_keypoint(self, kp_info: dict):
217
+ """
218
+ transform the implicit keypoints with the pose, shift, and expression deformation
219
+ kp: BxNx3
220
+ """
221
+ kp = kp_info['kp'] # (bs, k, 3)
222
+ pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']
223
+
224
+ t, exp = kp_info['t'], kp_info['exp']
225
+ scale = kp_info['scale']
226
+
227
+ pitch = headpose_pred_to_degree(pitch)
228
+ yaw = headpose_pred_to_degree(yaw)
229
+ roll = headpose_pred_to_degree(roll)
230
+
231
+ bs = kp.shape[0]
232
+ if kp.ndim == 2:
233
+ num_kp = kp.shape[1] // 3 # Bx(num_kpx3)
234
+ else:
235
+ num_kp = kp.shape[1] # Bxnum_kpx3
236
+
237
+ rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3)
238
+
239
+ # Eqn.2: s * (R * x_c,s + exp) + t
240
+ kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
241
+ kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
242
+ kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty
243
+
244
+ return kp_transformed
245
+
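For reference, the transform implemented above can be written, in the row-vector convention the code actually uses (note the comment's `R * x_c,s` order is reversed relative to the matmul in the code), as:

$$
x_d = s \cdot \bigl(x_{c,s}\,R + \delta\bigr) + t, \qquad \text{with only } (t_x, t_y) \text{ applied (} t_z \text{ dropped)}.
$$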
246
+ @torch.no_grad()
247
+ def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
248
+ """ conduct the stitching
249
+ kp_source: Bxnum_kpx3
250
+ kp_driving: Bxnum_kpx3
251
+ """
252
+
253
+ if self.stitching_retargeting_module is not None:
254
+ bs, num_kp = kp_source.shape[:2]
255
+ kp_driving_new = kp_driving.clone()
256
+ # stitch
257
+ feat_stiching = concat_feat(kp_source, kp_driving_new)
258
+ delta = self.stitching_retargeting_module['stitching'](feat_stiching) # Bxnum_kpx3
259
+
260
+ delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3) # 1x20x3
261
+ delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2) # 1x1x2
262
+
263
+ kp_driving_new += delta_exp
264
+ kp_driving_new[..., :2] += delta_tx_ty
265
+
266
+ return kp_driving_new
267
+
268
+ return kp_driving
269
+
270
+ @torch.no_grad()
271
+ def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> dict[str, torch.Tensor]:
272
+ """ get the image after the warping of the implicit keypoints
273
+ feature_3d: Bx32x16x64x64, feature volume
274
+ kp_source: BxNx3
275
+ kp_driving: BxNx3
276
+ """
277
+ # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
278
+ with self.inference_ctx():
279
+ if self.compile:
280
+ # Mark the beginning of a new CUDA Graph step
281
+ torch.compiler.cudagraph_mark_step_begin()
282
+ # get decoder input
283
+ ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
284
+
285
+ # print(f"=============================================================================")
286
+ # for out_key, out_value in ret_dct.items():
287
+ # if isinstance(out_value, str) or isinstance(out_value, int) or isinstance(out_value, float):
288
+ # print(f"{out_key}: {out_value}")
289
+ # elif isinstance(out_value, torch.Tensor):
290
+ # print(f"{out_key}: tensor shape {out_value.shape}, min: {torch.min(out_value)}, max: {torch.max(out_value)}, mean: {torch.mean(out_value)}, std: {torch.std(out_value)}")
291
+ # else:
292
+ # print(f"{out_key}: data type {type(out_value)}")
293
+ # decode
294
+ ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
295
+
296
+ # float the dict
297
+ if self.cfg.flag_use_half_precision:
298
+ for k, v in ret_dct.items():
299
+ if isinstance(v, torch.Tensor):
300
+ ret_dct[k] = v.float()
301
+
302
+ return ret_dct
303
+
304
+ def parse_output(self, out: torch.Tensor) -> np.ndarray:
305
+ """ construct the output as standard
306
+ return: 1xHxWx3, uint8
307
+ """
308
+ out = np.transpose(out.cpu().numpy(), [0, 2, 3, 1]) # 1x3xHxW -> 1xHxWx3
309
+ out = np.clip(out, 0, 1) # clip to 0~1
310
+ out = np.clip(out * 255, 0, 255).astype(np.uint8) # 0~1 -> 0~255
311
+
312
+ return out
313
+
314
+ @torch.no_grad()
315
+ def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
316
+ c_s_eyes = calc_eye_close_ratio(source_lmk[None])
317
+ c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device)
318
+ c_d_eyes_i_tensor = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device)
319
+ # [c_s,eyes, c_d,eyes,i]
320
+ combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1)
321
+ return combined_eye_ratio_tensor
322
+
323
+ @torch.no_grad()
324
+ def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
325
+ c_s_lip = calc_lip_close_ratio(source_lmk[None])
326
+ c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device)
327
+ c_d_lip_i_tensor = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1) # 1x1
328
+ # [c_s,lip, c_d,lip,i]
329
+ combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) # 1x2
330
+ return combined_lip_ratio_tensor
331
+
332
+ def calc_ratio(self, lmk_lst):
333
+ input_eye_ratio_lst = []
334
+ input_lip_ratio_lst = []
335
+ for lmk in lmk_lst:
336
+ # for eyes retargeting
337
+ input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
338
+ # for lip retargeting
339
+ input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
340
+ return input_eye_ratio_lst, input_lip_ratio_lst
341
+
342
+ @torch.no_grad()
343
+ def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
344
+ """
345
+ kp_source: BxNx3
346
+ lip_close_ratio: Bx2
347
+ Return: BxNx3
348
+ """
349
+ feat_lip = concat_feat(kp_source, lip_close_ratio)
350
+
351
+ delta = self.stitching_retargeting_module['lip'](feat_lip)
352
+
353
+ return delta.reshape(-1, kp_source.shape[1], 3)
354
+
355
+ @torch.no_grad()
356
+ def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor:
357
+ """
358
+ kp_source: BxNx3
359
+ eye_close_ratio: Bx3
360
+ Return: BxNx3
361
+ """
362
+ feat_eye = concat_feat(kp_source, eye_close_ratio)
363
+
364
+ delta = self.stitching_retargeting_module['eye'](feat_eye)
365
+
366
+ return delta.reshape(-1, kp_source.shape[1], 3)
367
+
368
+ def crop_image(self, img, do_crop=False):
369
+ ######## process source info ########
370
+ if do_crop:
371
+ crop_info = self.cropper.crop_source_image(img, self.cropper.crop_cfg)
372
+ if crop_info is None:
373
+ raise Exception("No face detected in the source image!")
374
+ lmk = crop_info['lmk_crop']
375
+ img_crop_256x256 = crop_info['img_crop_256x256']
376
+ else:
377
+ crop_info = None
378
+ lmk = self.cropper.calc_lmk_from_cropped_image(img)
379
+ img_crop_256x256 = cv2.resize(img, (256, 256)) # force to resize to 256x256
380
+ return img_crop_256x256, lmk, crop_info
381
+
382
+ def crop_source_video(self, img_lst, do_crop=False):
383
+ if do_crop:
384
+ ret_s = self.cropper.crop_source_video(img_lst, self.cropper.crop_cfg)
385
+ print(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
386
+ img_crop_256x256_lst, lmk_crop_lst, M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
387
+ else:
388
+ M_c2o_lst = None
389
+ lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(img_lst)
390
+ img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in img_lst] # force to resize to 256x256
391
+ return img_crop_256x256_lst, lmk_crop_lst, M_c2o_lst
392
+
393
+ def crop_driving_videos(self, img_lst, do_crop=False):
394
+ if do_crop:
395
+ ret_d = self.cropper.crop_driving_video(img_lst)
396
+ print(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
397
+ img_crop_lst, lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
398
+ img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in img_crop_lst]  # resize the cropped frames, not the originals
399
+ else:
400
+ lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(img_lst)
401
+ img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in img_lst] # force to resize to 256x256
402
+ return img_crop_256x256_lst, lmk_crop_lst
403
+
404
+ def prepare_source(self, src_img):
405
+ """ construct the input as standard
406
+ img: HxWx3, uint8, 256x256
407
+ """
408
+ # processing source image to tensor
409
+ h, w = src_img.shape[:2]
410
+ if h != self.cfg.input_height or w != self.cfg.input_width:
411
+ x = cv2.resize(src_img, (self.cfg.input_width, self.cfg.input_height))
412
+ else:
413
+ x = src_img.copy()
414
+
415
+ if x.ndim == 3:
416
+ x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1
417
+ elif x.ndim == 4:
418
+ x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1
419
+ else:
420
+ raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
421
+
422
+ x = np.clip(x, 0, 1) # clip to 0~1
423
+ x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
424
+ x = x.to(self.device)
425
+
426
+ # extract features
427
+ I_s = x
428
+ f_s = self.extract_feature_3d(I_s)
429
+ x_s_info = self.get_kp_info(I_s)
430
+
431
+ return f_s, x_s_info
432
+
433
+ def process_clips(self, clips):
434
+ """ construct the input as standard
435
+ clips: NxBxHxWx3, uint8
436
+ """
437
+ # resize to 256 x 256
438
+ imgs = []
439
+ for img in clips:
440
+ h, w = img.shape[:2]
441
+ if h != self.cfg.input_height or w != self.cfg.input_width:
442
+ img = cv2.resize(img, (self.cfg.input_width, self.cfg.input_height))
443
+ else:
444
+ img = img.copy()
445
+ imgs.append(img)
446
+
447
+ # processing video frames to tensor
448
+ if isinstance(imgs, list):
449
+ _imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1
450
+ elif isinstance(imgs, np.ndarray):
451
+ _imgs = imgs
452
+ else:
453
+ raise ValueError(f'imgs type error: {type(imgs)}')
454
+
455
+ y = _imgs.astype(np.float32) / 255.
456
+ y = np.clip(y, 0, 1) # clip to 0~1
457
+ y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
458
+ y = y.to(self.device)
459
+
460
+ return y
461
+
462
+ def prepare_driving_videos(self, vid_frames, feat_type="tensor"):
463
+ """ get driving kp infos
464
+ vid_frames: image list of HxWx3, uint8
465
+ """
466
+ # extract features
467
+ total_len = len(vid_frames)
468
+ kp_infos = {"pitch": [], "yaw": [], "roll": [], "t": [], "exp": [], "scale": [], "kp": []}
469
+ for start_idx in range(0, total_len, self.cfg.batch_size):
470
+ frames = vid_frames[start_idx: min(start_idx + self.cfg.batch_size, total_len)]
471
+ frames = self.process_clips(frames).squeeze(1)
472
+ kp_info = self.get_kp_info(frames)
473
+
474
+ for k, v in kp_info.items():
475
+ kp_infos[k].append(v)
476
+
477
+ # combine the kp_infos
478
+ for k, v in kp_infos.items():
479
+ kp_infos[k] = torch.cat(v, dim=0)
480
+
481
+ if feat_type == "np":
482
+ for k, v in kp_infos.items():
483
+ kp_infos[k] = v.cpu().numpy()
484
+
485
+ return kp_infos
486
+
487
+ def get_driving_template(self, kp_infos, smooth=False, dtype="pt_tensor"):
488
+ kp_infos = self.refine_kp(kp_infos)
489
+ motion_list = []
490
+ n_frames = len(kp_infos["exp"])
491
+ for idx in range(n_frames):
492
+ exp = kp_infos["exp"][idx]
493
+ scale = kp_infos["scale"][idx]
494
+ t = kp_infos["t"][idx]
495
+ pitch = kp_infos["pitch"][idx]
496
+ yaw = kp_infos["yaw"][idx]
497
+ roll = kp_infos["roll"][idx]
498
+
499
+ R = get_rotation_matrix(pitch, yaw, roll)
500
+ R = R.reshape(1, 3, 3)
501
+
502
+ exp = exp.reshape(1, 21, 3)
503
+ scale = scale.reshape(1, 1)
504
+ t = t.reshape(1, 3)
505
+ pitch = pitch.reshape(1, 1)
506
+ yaw = yaw.reshape(1, 1)
507
+ roll = roll.reshape(1, 1)
508
+
509
+ if dtype == "np":
510
+ R = R.cpu().numpy().astype(np.float32)
511
+ exp = exp.cpu().numpy().astype(np.float32)
512
+ scale = scale.cpu().numpy().astype(np.float32)
513
+ t = t.cpu().numpy().astype(np.float32)
514
+ pitch = pitch.cpu().numpy().astype(np.float32)
515
+ yaw = yaw.cpu().numpy().astype(np.float32)
516
+ roll = roll.cpu().numpy().astype(np.float32)
517
+
518
+ motion_list.append(
519
+ {"exp": exp, "scale": scale, "R": R, "t": t, "pitch": pitch, "yaw": yaw, "roll": roll}
520
+ )
521
+ tgt_motion = {'n_frames': n_frames, 'output_fps': 25, 'motion': motion_list}
522
+
523
+ if smooth:
524
+ print("Smoothing motion sequence...")
525
+ tgt_motion = smooth_(tgt_motion, method="ema")
526
+ return tgt_motion
527
+
528
+ @torch.no_grad()
529
+ def update_delta_new_eyeball_direction(self, eyeball_direction_x, eyeball_direction_y, delta_new, **kwargs):
530
+ if eyeball_direction_x > 0:
531
+ delta_new[0, 11, 0] += eyeball_direction_x * 0.0007
532
+ delta_new[0, 15, 0] += eyeball_direction_x * 0.001
533
+ else:
534
+ delta_new[0, 11, 0] += eyeball_direction_x * 0.001
535
+ delta_new[0, 15, 0] += eyeball_direction_x * 0.0007
536
+
537
+ delta_new[0, 11, 1] += eyeball_direction_y * -0.001
538
+ delta_new[0, 15, 1] += eyeball_direction_y * -0.001
539
+ blink = -eyeball_direction_y / 2.
540
+
541
+ delta_new[0, 11, 1] += blink * -0.001
542
+ delta_new[0, 13, 1] += blink * 0.0003
543
+ delta_new[0, 15, 1] += blink * -0.001
544
+ delta_new[0, 16, 1] += blink * 0.0003
545
+
546
+ return delta_new
547
+
548
+ def driven(self, f_s, x_s_info, s_lmk, c_s_eyes_lst, kp_infos, c_d_eyes_lst=None, c_d_lip_lst=None, smooth=False):
549
+ # source kp info
550
+ x_d_i_news=[]
551
+ x_ss=[]
552
+ f_ss=[]
553
+ x_s_info = self.refine_kp(x_s_info)
554
+ R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
555
+ x_s = self.transform_keypoint(x_s_info)
556
+ x_c_s = x_s_info["kp"]
557
+
558
+ # driving kp infos
559
+ driving_template_dct = self.get_driving_template(kp_infos, smooth)
560
+ n_frames = driving_template_dct['n_frames']
561
+
562
+ # driving params
563
+ flag_normalize_lip = self.cfg.flag_normalize_lip
564
+ flag_relative_motion = self.cfg.flag_relative_motion
565
+ flag_source_video_eye_retargeting = self.cfg.flag_source_video_eye_retargeting
566
+ lip_normalize_threshold = self.cfg.lip_normalize_threshold
567
+ source_video_eye_retargeting_threshold = self.cfg.source_video_eye_retargeting_threshold
568
+ animation_region = self.cfg.animation_region
569
+ driving_option = self.cfg.driving_option
570
+ flag_stitching = self.cfg.flag_stitching
571
+ flag_eye_retargeting = self.cfg.flag_eye_retargeting
572
+ flag_lip_retargeting = self.cfg.flag_lip_retargeting
573
+ driving_multiplier = self.cfg.driving_multiplier
574
+ lib_multiplier = self.cfg.lib_multiplier
575
+
576
+ # set the lip-open scalar to 0 before animation
577
+ lip_delta_before_animation, eye_delta_before_animation = None, None
578
+ if flag_normalize_lip and flag_relative_motion and s_lmk is not None:
579
+ c_d_lip_before_animation = [0.]
580
+ combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(c_d_lip_before_animation, s_lmk)
581
+ if combined_lip_ratio_tensor_before_animation[0][0] >= lip_normalize_threshold:
582
+ lip_delta_before_animation = self.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
583
+
584
+ # keep the eye-open scalar equal to the first frame's value if that frame has open eyes
585
+ if flag_source_video_eye_retargeting and s_lmk is not None:
586
+ combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
587
+ c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
588
+ if c_d_eye_before_animation_frame_zero[0][0] < source_video_eye_retargeting_threshold:
589
+ c_d_eye_before_animation_frame_zero = [[0.39]]
590
+ combined_eye_ratio_tensor_before_animation = self.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, s_lmk)
591
+ eye_delta_before_animation = self.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation)
592
+
593
+ # animate
594
+ I_p_lst = []
595
+ for i in range(n_frames):
596
+ x_d_i_info = driving_template_dct['motion'][i]
597
+ x_d_i_info = dct2device(x_d_i_info, self.device)
598
+ # R
599
+ R_d_i = x_d_i_info['R']
600
+ if i == 0: # cache the first frame
601
+ R_d_0 = R_d_i
602
+ x_d_0_info = x_d_i_info.copy()
603
+
604
+ # enhance lip
605
+ # if i > 0:
606
+ # for lip_idx in [6, 12, 14, 17, 19, 20]:
607
+ # x_d_i_info['exp'][:, lip_idx, :] = x_d_0_info['exp'][:, lip_idx, :] + (x_d_i_info['exp'][:, lip_idx, :] - x_d_0_info['exp'][:, lip_idx, :]) * lib_multiplier
608
+
609
+ # normalize eye_ball, TODO
610
+ x_d_i_info['exp'] = self.update_delta_new_eyeball_direction(0, -5, x_d_i_info['exp'])
611
+
612
+ # debug
613
+ #print(f"frame {i:03d}, src scale {x_s_info['scale']}, 0 scale {x_d_0_info['scale']}, i scale {x_d_i_info['scale']}")
614
+ # delta
615
+ delta_new = x_s_info['exp'].clone()
616
+ if flag_relative_motion:
617
+ # R
618
+ if animation_region == "all" or animation_region == "pose":
619
+ R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
620
+ else:
621
+ R_new = R_s
622
+
623
+ # exp
624
+ if animation_region == "all" or animation_region == "exp":
625
+ delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
626
+ elif animation_region == "lip":
627
+ for lip_idx in [6, 12, 14, 17, 19, 20]:
628
+ delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
629
+ elif animation_region == "eyes":
630
+ for eyes_idx in [11, 13, 15, 16, 18]:
631
+ delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
632
+
633
+ # scale
634
+ if animation_region == "all":
635
+ scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
636
+ else:
637
+ scale_new = x_s_info['scale']
638
+
639
+ # translation
640
+ if animation_region == "all" or animation_region == "pose":
641
+ t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
642
+ else:
643
+ t_new = x_s_info['t']
644
+ else:
645
+ # R
646
+ if animation_region == "all" or animation_region == "pose":
647
+ R_new = R_d_i
648
+ else:
649
+ R_new = R_s
650
+
651
+ # exp
652
+ if animation_region == "all" or animation_region == "exp":
653
+ EYE_IDX=[1,2,6,11,12,13,14,15,16,17,18,19,20]
654
+ delta_new[:, EYE_IDX, :] = x_d_i_info['exp'][:, EYE_IDX, :]
655
+ # for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
656
+ # delta_new[:, idx, :] = x_d_i_info['exp'][:, idx, :]
657
+ delta_new[:, 3:5, 1] = x_d_i_info['exp'][:, 3:5, 1]
658
+ delta_new[:, 5, 2] = x_d_i_info['exp'][:, 5, 2]
659
+ delta_new[:, 8, 2] = x_d_i_info['exp'][:, 8, 2]
660
+ delta_new[:, 9, 1:] = x_d_i_info['exp'][:, 9, 1:]
661
+ elif animation_region == "lip":
662
+ for lip_idx in [6, 12, 14, 17, 19, 20]:
663
+ delta_new[:, lip_idx, :] = x_d_i_info['exp'][:, lip_idx, :]
664
+ elif animation_region == "eyes":
665
+ for eyes_idx in [11, 13, 15, 16, 18]:
666
+ delta_new[:, eyes_idx, :] = x_d_i_info['exp'][:, eyes_idx, :]
667
+
668
+ # scale
669
+ scale_new = x_s_info['scale']
670
+
671
+ # translation
672
+ if animation_region == "all" or animation_region == "pose":
673
+ t_new = x_d_i_info['t']
674
+ else:
675
+ t_new = x_s_info['t']
676
+
677
+ t_new[..., 2].fill_(0) # zero tz
678
+
679
+ x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
680
+
681
+ if flag_relative_motion and driving_option == "expression-friendly":
682
+ if i == 0:
683
+ x_d_0_new = x_d_i_new
684
+ motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
685
+ x_d_diff = (x_d_i_new - x_d_0_new) * motion_multiplier
686
+ x_d_i_new = x_d_diff + x_s
687
+
688
+ # Algorithm 1 in Liveportrait:
689
+ if not flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
690
+ # without stitching or retargeting
691
+ if flag_normalize_lip and lip_delta_before_animation is not None:
692
+ x_d_i_new += lip_delta_before_animation
693
+ if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
694
+ x_d_i_new += eye_delta_before_animation
695
+ else:
696
+ pass
697
+ elif flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
698
+ # with stitching and without retargeting
699
+ if flag_normalize_lip and lip_delta_before_animation is not None:
700
+ x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation
701
+ else:
702
+ x_d_i_new = self.stitching(x_s, x_d_i_new)
703
+ if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
704
+ x_d_i_new += eye_delta_before_animation
705
+ else:
706
+ eyes_delta, lip_delta = None, None
707
+ if flag_eye_retargeting and s_lmk is not None and c_d_eyes_lst is not None:
708
+ c_d_eyes_i = c_d_eyes_lst[i]
709
+ combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, s_lmk)
710
+ eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor)
711
+
712
+ if flag_lip_retargeting and s_lmk is not None and c_d_lip_lst is not None:
713
+ c_d_lip_i = c_d_lip_lst[i]
714
+ combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, s_lmk)
715
+ # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
716
+ lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor)
717
+
718
+ if flag_relative_motion: # use x_s
719
+ x_d_i_new = x_s + \
720
+ (eyes_delta if eyes_delta is not None else 0) + \
721
+ (lip_delta if lip_delta is not None else 0)
722
+ else: # use x_d,i
723
+ x_d_i_new = x_d_i_new + \
724
+ (eyes_delta if eyes_delta is not None else 0) + \
725
+ (lip_delta if lip_delta is not None else 0)
726
+
727
+ if flag_stitching:
728
+ x_d_i_new = self.stitching(x_s, x_d_i_new)
729
+
730
+ x_d_i_new = x_s + (x_d_i_new - x_s) * driving_multiplier
731
+ x_d_i_news.append(x_d_i_new)
732
+ f_s_s = f_s.expand(n_frames, *f_s.shape[1:])
733
+ x_s_s = x_s.expand(n_frames, *x_s.shape[1:])
734
+ x_d_i_new = torch.cat(x_d_i_news, dim=0)
735
+ for start in range(0, n_frames, 100):
736
+ end = min(start + 100, n_frames)
737
+ with torch.no_grad(), torch.autocast('cuda'):
738
+ out = self.warp_decode(f_s_s[start:end], x_s_s[start:end], x_d_i_new[start:end])
739
+ I_p_lst.append(out['out'])
740
+ I_p = torch.cat(I_p_lst, dim=0)
741
+ I_p_i = self.parse_output(I_p)
742
+ return I_p_i
743
+
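For readability, the per-frame relative-motion update in the loop above (with `flag_relative_motion` set and `animation_region == "all"`) can be summarized as:

$$
R_{\text{new}} = \bigl(R_{d,i}\,R_{d,0}^{\top}\bigr) R_s,\quad
\delta_{\text{new}} = \delta_s + (\delta_{d,i} - \delta_{d,0}),\quad
s_{\text{new}} = s_s \cdot \frac{s_{d,i}}{s_{d,0}},\quad
t_{\text{new}} = t_s + (t_{d,i} - t_{d,0}),\ t_z = 0,
$$

$$
x_{d,i} = s_{\text{new}}\bigl(x_{c,s}\,R_{\text{new}} + \delta_{\text{new}}\bigr) + t_{\text{new}},
$$

followed by optional stitching and the final blend $x_{d,i} \leftarrow x_s + (x_{d,i} - x_s)\cdot\text{driving\_multiplier}$.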
744
+ def driven_debug(self, f_s, x_s_info, s_lmk, c_s_eyes_lst, driving_template_dct, c_d_eyes_lst=None, c_d_lip_lst=None):
745
+ # source kp info
746
+ x_s_info = self.refine_kp(x_s_info)
747
+ x_c_s = x_s_info["kp"]
748
+ R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
749
+ x_s = self.transform_keypoint(x_s_info)
750
+
751
+ n_frames = driving_template_dct['n_frames']
752
+
753
+ # driving params
754
+ flag_normalize_lip = self.cfg.flag_normalize_lip
755
+ flag_relative_motion = self.cfg.flag_relative_motion
756
+ flag_source_video_eye_retargeting = self.cfg.flag_source_video_eye_retargeting
757
+ lip_normalize_threshold = self.cfg.lip_normalize_threshold
758
+ source_video_eye_retargeting_threshold = self.cfg.source_video_eye_retargeting_threshold
759
+ animation_region = self.cfg.animation_region
760
+ driving_option = self.cfg.driving_option
761
+ flag_stitching = self.cfg.flag_stitching
762
+ flag_eye_retargeting = self.cfg.flag_eye_retargeting
763
+ flag_lip_retargeting = self.cfg.flag_lip_retargeting
764
+ driving_multiplier = self.cfg.driving_multiplier
765
+
766
+ # set the lip-open scalar to 0 before animation
767
+ lip_delta_before_animation, eye_delta_before_animation = None, None
768
+ if flag_normalize_lip and flag_relative_motion and s_lmk is not None:
769
+ c_d_lip_before_animation = [0.]
770
+ combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(c_d_lip_before_animation, s_lmk)
771
+ if combined_lip_ratio_tensor_before_animation[0][0] >= lip_normalize_threshold:
772
+ lip_delta_before_animation = self.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
773
+
774
+ # keep the eye-open scalar equal to the first frame's value if that frame has open eyes
775
+ if flag_source_video_eye_retargeting and s_lmk is not None:
776
+ combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
777
+ c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
778
+ if c_d_eye_before_animation_frame_zero[0][0] < source_video_eye_retargeting_threshold:
779
+ c_d_eye_before_animation_frame_zero = [[0.39]]
780
+ combined_eye_ratio_tensor_before_animation = self.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, s_lmk)
781
+ eye_delta_before_animation = self.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation)
782
+
783
+ # animate
784
+ I_p_lst = []
785
+ for i in range(n_frames):
786
+ x_d_i_info = driving_template_dct['motion'][i]
787
+ x_d_i_info = dct2device(x_d_i_info, self.device)
788
+ # R
789
+ R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
790
+ if i == 0: # cache the first frame
791
+ R_d_0 = R_d_i
792
+ x_d_0_info = x_d_i_info.copy()
793
+
794
+ # debug
795
+ #print(f"frame {i:03d}, src scale {x_s_info['scale']}, 0 scale {x_d_0_info['scale']}, i scale {x_d_i_info['scale']}")
796
+ # delta
797
+ delta_new = x_s_info['exp'].clone()
798
+ if flag_relative_motion:
799
+ # R
800
+ if animation_region == "all" or animation_region == "pose":
801
+ R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
802
+ else:
803
+ R_new = R_s
804
+
805
+ # exp
806
+ if animation_region == "all" or animation_region == "exp":
807
+ delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
808
+ elif animation_region == "lip":
809
+ for lip_idx in [6, 12, 14, 17, 19, 20]:
810
+ delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
811
+ elif animation_region == "eyes":
812
+ for eyes_idx in [11, 13, 15, 16, 18]:
813
+ delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
814
+
815
+ # scale
816
+ if animation_region == "all":
817
+ scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
818
+ else:
819
+ scale_new = x_s_info['scale']
820
+
821
+ # translation
822
+ if animation_region == "all" or animation_region == "pose":
823
+ t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
824
+ else:
825
+ t_new = x_s_info['t']
826
+ else:
827
+ # R
828
+ if animation_region == "all" or animation_region == "pose":
829
+ R_new = R_d_i
830
+ else:
831
+ R_new = R_s
832
+
833
+ # exp
834
+ if animation_region == "all" or animation_region == "exp":
835
+ for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
836
+ delta_new[:, idx, :] = x_d_i_info['exp'][:, idx, :]
837
+ delta_new[:, 3:5, 1] = x_d_i_info['exp'][:, 3:5, 1]
838
+ delta_new[:, 5, 2] = x_d_i_info['exp'][:, 5, 2]
839
+ delta_new[:, 8, 2] = x_d_i_info['exp'][:, 8, 2]
840
+ delta_new[:, 9, 1:] = x_d_i_info['exp'][:, 9, 1:]
841
+ elif animation_region == "lip":
842
+ for lip_idx in [6, 12, 14, 17, 19, 20]:
843
+ delta_new[:, lip_idx, :] = x_d_i_info['exp'][:, lip_idx, :]
844
+ elif animation_region == "eyes":
845
+ for eyes_idx in [11, 13, 15, 16, 18]:
846
+ delta_new[:, eyes_idx, :] = x_d_i_info['exp'][:, eyes_idx, :]
847
+
848
+ # scale
849
+ scale_new = x_s_info['scale']
850
+
851
+ # translation
852
+ if animation_region == "all" or animation_region == "pose":
853
+ t_new = x_d_i_info['t']
854
+ else:
855
+ t_new = x_s_info['t']
856
+
857
+ t_new[..., 2].fill_(0) # zero tz
858
+
859
+ x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
860
+
861
+ if flag_relative_motion and driving_option == "expression-friendly":
862
+ if i == 0:
863
+ x_d_0_new = x_d_i_new
864
+ motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
865
+ x_d_diff = (x_d_i_new - x_d_0_new) * motion_multiplier
866
+ x_d_i_new = x_d_diff + x_s
867
+
868
+ # Algorithm 1 in Liveportrait:
869
+ if not flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
870
+ # without stitching or retargeting
871
+ if flag_normalize_lip and lip_delta_before_animation is not None:
872
+ x_d_i_new += lip_delta_before_animation
873
+ if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
874
+ x_d_i_new += eye_delta_before_animation
875
+ else:
876
+ pass
877
+ elif flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
878
+ # with stitching and without retargeting
879
+ if flag_normalize_lip and lip_delta_before_animation is not None:
880
+ x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation
881
+ else:
882
+ x_d_i_new = self.stitching(x_s, x_d_i_new)
883
+ if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
884
+ x_d_i_new += eye_delta_before_animation
885
+ else:
886
+ eyes_delta, lip_delta = None, None
887
+ if flag_eye_retargeting and s_lmk is not None and c_d_eyes_lst is not None:
888
+ c_d_eyes_i = c_d_eyes_lst[i]
889
+ combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, s_lmk)
890
+ eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor)
891
+
892
+ if flag_lip_retargeting and s_lmk is not None and c_d_lip_lst is not None:
893
+ c_d_lip_i = c_d_lip_lst[i]
894
+ combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, s_lmk)
895
+ # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
896
+ lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor)
897
+
898
+ if flag_relative_motion: # use x_s
899
+ x_d_i_new = x_s + \
900
+ (eyes_delta if eyes_delta is not None else 0) + \
901
+ (lip_delta if lip_delta is not None else 0)
902
+ else: # use x_d,i
903
+ x_d_i_new = x_d_i_new + \
904
+ (eyes_delta if eyes_delta is not None else 0) + \
905
+ (lip_delta if lip_delta is not None else 0)
906
+
907
+ if flag_stitching:
908
+ x_d_i_new = self.stitching(x_s, x_d_i_new)
909
+
910
+ x_d_i_new = x_s + (x_d_i_new - x_s) * driving_multiplier
911
+ out = self.warp_decode(f_s, x_s, x_d_i_new)
912
+ I_p_i = self.parse_output(out['out'])[0]
913
+ I_p_lst.append(I_p_i)
914
+
915
+ return I_p_lst
916
+
917
+ def read_image(self, image_path: str) -> list:
918
+ img_rgb = load_image_rgb(image_path)
919
+ img_rgb = resize_to_limit(img_rgb, self.cfg.source_max_dim, self.cfg.source_division)
920
+ source_rgb_list = [img_rgb]
921
+ print(f"Load image from {osp.realpath(image_path)} done.")
922
+ return source_rgb_list
923
+
924
+ def read_video(self, video_path: str, interval=None) -> list:
925
+ vr = VideoReader(video_path)
926
+ if interval is not None:
927
+ video_frames = vr.get_batch(np.arange(0, len(vr), interval)).numpy()
928
+ else:
929
+ video_frames = [vr[0].numpy(), vr[len(vr) // 2].numpy(), vr[-1].numpy()]
930
+ vr.seek(0)
931
+ driving_rgb_list = []
932
+ for video_frame in video_frames:
933
+ # h, w = video_frame.shape[:2]
934
+ # if h != self.cfg.output_height or w != self.cfg.output_width:
935
+ # video_frame = cv2.resize(video_frame, (self.cfg.output_height, self.cfg.output_width))
936
+ driving_rgb_list.append(video_frame)
937
+
938
+ return driving_rgb_list
939
+
940
+ def prepare_videos(self, imgs) -> torch.Tensor:
941
+ """ construct the input as standard
942
+ imgs: NxBxHxWx3, uint8
943
+ """
944
+ if isinstance(imgs, list):
945
+ _imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1
946
+ elif isinstance(imgs, np.ndarray):
947
+ _imgs = imgs
948
+ else:
949
+ raise ValueError(f'imgs type error: {type(imgs)}')
950
+
951
+ y = _imgs.astype(np.float32) / 255.
952
+ y = np.clip(y, 0, 1) # clip to 0~1
953
+ y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
954
+ y = y.to(self.device)
955
+
956
+ return y
957
+
958
+ def make_motion_template(self, I_lst, c_eyes_lst, c_lip_lst, **kwargs):
959
+ n_frames = I_lst.shape[0]
960
+ template_dct = {
961
+ 'n_frames': n_frames,
962
+ 'output_fps': kwargs.get('output_fps', 25),
963
+ 'motion': [],
964
+ 'c_eyes_lst': [],
965
+ 'c_lip_lst': [],
966
+ }
967
+
968
+ for i in track(range(n_frames), description='Making motion templates...', total=n_frames):
969
+ # collect s, R, δ and t for inference
970
+ I_i = I_lst[i]
971
+ x_i_info = self.refine_kp(self.get_kp_info(I_i))
972
+ x_s = self.transform_keypoint(x_i_info)
973
+ R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])
974
+
975
+ item_dct = {
976
+ 'scale': x_i_info['scale'].cpu().numpy().astype(np.float32),
977
+ 'R': R_i.cpu().numpy().astype(np.float32),
978
+ 'exp': x_i_info['exp'].cpu().numpy().astype(np.float32),
979
+ 't': x_i_info['t'].cpu().numpy().astype(np.float32),
980
+ 'kp': x_i_info['kp'].cpu().numpy().astype(np.float32),
981
+ 'x_s': x_s.cpu().numpy().astype(np.float32),
982
+ }
983
+
984
+ template_dct['motion'].append(item_dct)
985
+
986
+ c_eyes = c_eyes_lst[i].astype(np.float32)
987
+ template_dct['c_eyes_lst'].append(c_eyes)
988
+
989
+ c_lip = c_lip_lst[i].astype(np.float32)
990
+ template_dct['c_lip_lst'].append(c_lip)
991
+
992
+ return template_dct
993
+
994
+ def load_template(self, wfp_template):
995
+ print(f"Load from template: {wfp_template}, NOT the video, so the cropping video and audio are both NULL.")
996
+ driving_template_dct = load(wfp_template)
997
+ c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys
998
+ c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
999
+ driving_n_frames = driving_template_dct['n_frames']
1000
+ flag_is_driving_video = True if driving_n_frames > 1 else False
1001
+ n_frames = driving_n_frames
1002
+
1003
+ # set output_fps
1004
+ output_fps = driving_template_dct.get('output_fps', 25)
1005
+ print(f'The FPS of template: {output_fps}')
1006
+ return driving_template_dct
1007
+
1008
+ def reconstruction(self, src_img, dst_imgs, video_path="template"):
1009
+ # prepare source
1010
+ src_img_256x256, s_lmk, _ = self.crop_image(src_img, do_crop=False)
1011
+ #c_s_eyes_lst, c_s_lip_lst = self.calc_ratio([s_lmk])
1012
+ c_s_eyes_lst = None
1013
+ f_s, x_s_info = self.prepare_source(src_img_256x256)
1014
+
1015
+ # prepare driving videos
1016
+ dst_imgs_256x256, d_lmk_lst = self.crop_driving_videos(dst_imgs, do_crop=False)
1017
+ c_d_eyes_lst, c_d_lip_lst = self.calc_ratio(d_lmk_lst)
1018
+ kp_infos = self.prepare_driving_videos(dst_imgs_256x256)
1019
+
1020
+
1021
+ recs = self.driven(f_s, x_s_info, s_lmk, c_s_eyes_lst, kp_infos, c_d_eyes_lst, c_d_lip_lst)
1022
+ return recs
1023
+
1024
+ def save_results(self, results, save_path, audio_path=None):
1025
+ save_dir = osp.dirname(save_path)
1026
+ save_name = osp.basename(save_path)
1027
+ final_video = osp.join(save_dir, f'final_{save_name}')
1028
+
1029
+ images2video(results, wfp=save_path, fps=self.cfg.output_fps)
1030
+
1031
+ if audio_path is not None:
1032
+ add_audio_to_video(save_path, audio_path, final_video)
1033
+ os.remove(save_path)
1034
+
1035
+ def rec_score(self, video_path: str, interval=None, save_path=None):
1036
+ video_frames = self.read_video(video_path, interval=interval)
1037
+ #print(f"len frames: {len(video_frames)}, shape: {video_frames[0].shape}")
1038
+ recs = self.reconstruction(video_frames[0], video_frames[1:], video_path)
1039
+ if save_path is not None:
1040
+ self.save_results(recs, save_path)
1041
+ #print(f"len rec: {len(recs)}, shape: {recs[0].shape}")
1042
+ psnrs = psnr(video_frames[1:], recs)
1043
+ psnrs_np = np.array(psnrs)
1044
+ psnr_mean, psnr_std = np.mean(psnrs_np), np.std(psnrs_np)
1045
+ rec_score = {"mean": psnr_mean, "std": psnr_std}
1046
+ return rec_score
1047
+
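Usage sketch: `rec_score` can act as a self-reconstruction quality filter during dataset preprocessing (the 30 dB threshold is illustrative, not a value defined in this repo):

```python
score = mp.rec_score("path/to/clip.mp4", interval=30)
print(score)                        # {'mean': ..., 'std': ...}, PSNR in dB
keep_clip = score["mean"] > 30.0    # hypothetical acceptance threshold
```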
1048
+ @torch.no_grad()
1049
+ def paste_back_by_face_mask(self, result, crop_info, src_img, crop_src_image, use_laplacian=False):
1050
+ """
1051
+ paste back the result to the original image with face mask
1052
+ """
1053
+ # detect src mask
1054
+ crop_src_tensor = self.to_tensor(crop_src_image).unsqueeze(0).to(self.device)
1055
+ src_msks = get_face_mask(self.face_parser, crop_src_tensor)
1056
+ result_tensor = self.to_tensor(result).unsqueeze(0).to(self.device)
1057
+ result_msks = get_face_mask(self.face_parser, result_tensor)
1058
+ # combine masks
1059
+ masks = []
1060
+ for src_msk, result_msk in zip(src_msks, result_msks):
1061
+ mask = np.clip(src_msk + result_msk, 0, 1)
1062
+ masks.append(mask)
1063
+ result = paste_back_with_face_mask(result, crop_info, src_img, masks[0], use_laplacian=use_laplacian)
1064
+ return result
1065
+
1066
+ def driven_by_audio(self, src_img, kp_infos, save_path, audio_path=None, smooth=False):
1067
+ # prepare source
1068
+
1069
+ src_img_256x256, s_lmk, crop_info = self.crop_image(src_img, do_crop=True)
1070
+ #c_s_eyes_lst, c_s_lip_lst = self.calc_ratio([s_lmk])
1071
+ c_s_eyes_lst = None
1072
+ f_s, x_s_info = self.prepare_source(src_img_256x256)
1073
+
1074
+ mask_ori_float = prepare_paste_back(self.mask_crop, crop_info['M_c2o'], dsize=(src_img.shape[1], src_img.shape[0]))
1075
+
1076
+ # prepare driving videos
1077
+ results = self.driven(f_s, x_s_info, s_lmk, c_s_eyes_lst, kp_infos, smooth=smooth)
1078
+ frames=results.shape[0]
1079
+ results = [paste_back(results[i], crop_info['M_c2o'], src_img, mask_ori_float) for i in range(frames)]
1080
+ self.save_results(results, save_path, audio_path)
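End-to-end sketch for the audio-driven path; `kp_infos` is assumed to come from the audio-to-motion model elsewhere in this repo, and all paths are placeholders:

```python
src_img = mp.read_image("reference.jpg")[0]   # read_image returns a one-element list
mp.driven_by_audio(
    src_img,
    kp_infos,                      # dict of pitch/yaw/roll/t/exp/scale tensors
    save_path="outputs/demo.mp4",
    audio_path="speech.wav",
    smooth=True,
)
# With audio_path set, the muxed result is written as outputs/final_demo.mp4.
```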
1081
+ def mix_kp_infos(self, emo_kp_infos, lip_kp_infos, smooth=False, dtype="pt_tensor"):
1082
+ driving_emo_template_dct = self.get_driving_template(emo_kp_infos, smooth=False, dtype=dtype)
1083
+ if lip_kp_infos is not None:
1084
+ driving_lip_template_dct = self.get_driving_template(lip_kp_infos, smooth=smooth, dtype=dtype)
1085
+ driving_template_dct = {**driving_emo_template_dct}
1086
+ n_frames = min(driving_emo_template_dct['n_frames'], driving_lip_template_dct['n_frames'])
1087
+ driving_template_dct['n_frames'] = n_frames
1088
+ for i in range(n_frames):
1089
+ emo_motion = driving_emo_template_dct['motion'][i]['exp']
1090
+ lib_motion = driving_lip_template_dct['motion'][i]['exp']
1091
+ for lip_idx in [6, 12, 14, 17, 19, 20]:
1092
+ emo_motion[:, lip_idx, :] = lib_motion[:, lip_idx, :]
1093
+ driving_template_dct['motion'][i]['exp'] = emo_motion
1094
+ else:
1095
+ driving_template_dct = driving_emo_template_dct
1096
+
1097
+ return driving_template_dct
1098
+
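Sketch: mixing emotion motion from a driving video with audio-predicted lip motion (both arguments are kp_info dicts in the format produced by `prepare_driving_videos`):

```python
mixed = mp.mix_kp_infos(emo_kp_infos, lip_kp_infos, smooth=True)
# For every frame, exp rows [6, 12, 14, 17, 19, 20] (the lip keypoints) are
# taken from lip_kp_infos; all other motion comes from emo_kp_infos.
```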
1099
+ def driven_by_mix(self, src_img, driving_video_path, kp_infos, save_path, audio_path=None, smooth=False):
1100
+ # prepare source
1101
+ src_img_256x256, s_lmk, crop_info = self.crop_image(src_img, do_crop=True)
1102
+ c_s_eyes_lst, c_s_lip_lst = self.calc_ratio([s_lmk])
1103
+ f_s, x_s_info = self.prepare_source(src_img_256x256)
1104
+ mask_ori_float = prepare_paste_back(self.mask_crop, crop_info['M_c2o'], dsize=(src_img.shape[1], src_img.shape[0]))
1105
+ # prepare driving videos
1106
+ driving_imgs = self.read_video(driving_video_path, interval=1)
1107
+ dst_imgs_256x256, d_lmk_lst = self.crop_driving_videos(driving_imgs, do_crop=True)
1108
+ c_d_eyes_lst, c_d_lip_lst = self.calc_ratio(d_lmk_lst)
1109
+ emo_kp_infos = self.prepare_driving_videos(dst_imgs_256x256)
1110
+ # mix kp_infos
1111
+ driving_template_dct = self.mix_kp_infos(emo_kp_infos, kp_infos, smooth=smooth)
1112
+ # driven
1113
+ results = self.driven_debug(f_s, x_s_info, s_lmk, c_s_eyes_lst, driving_template_dct, c_d_eyes_lst=c_d_eyes_lst, c_d_lip_lst=c_d_lip_lst)
1114
+ results = [paste_back(result, crop_info['M_c2o'], src_img, mask_ori_float) for result in results]
1115
+ print(f"{len(results)} frames generated.")
1116
+ self.save_results(results, save_path, audio_path)
1117
+
1118
+ def drive_video_by_mix(self, video_path, driving_video_path, kp_infos, save_path, audio_path):
1119
+ # prepare driving videos
1120
+ driving_imgs = self.read_video(driving_video_path, interval=1)
1121
+ dst_imgs_256x256, d_lmk_lst = self.crop_driving_videos(driving_imgs, do_crop=True)
1122
+ emo_kp_infos = self.prepare_driving_videos(dst_imgs_256x256)
1123
+ # mix kp_infos
1124
+ #driving_template_dct = self.get_driving_template(emo_kp_infos, smooth=True, dtype="np")
1125
+ driving_template_dct = self.mix_kp_infos(emo_kp_infos, kp_infos, smooth=True, dtype="np")
1126
+ # driven
1127
+ self.video_lip_retargeting(
1128
+ video_path, None,
1129
+ save_path, audio_path,
1130
+ driving_template_dct=driving_template_dct, retargeting_ragion="exp"
1131
+ )
1132
+
1133
+ def load_source_video(self, video_info, n_frames=-1):
1134
+ reader = imageio.get_reader(video_info, "ffmpeg")
1135
+
1136
+ ret = []
1137
+ for idx, frame_rgb in enumerate(reader):
1138
+ if n_frames > 0 and idx >= n_frames:
1139
+ break
1140
+ ret.append(frame_rgb)
1141
+
1142
+ reader.close()
1143
+
1144
+ return ret
1145
+
1146
+ def video_lip_retargeting(self, video_path, kp_infos, save_path, audio_path, c_d_eyes_lst=None, c_d_lip_lst=None, smooth=False, driving_template_dct=None, retargeting_ragion="exp"):
1147
+ # 0. process source motion template
1148
+ source_rgb_lst = load_video(video_path)
1149
+ source_rgb_lst = [resize_to_limit(img, self.cfg.source_max_dim, self.cfg.source_division) for img in source_rgb_lst]
1150
+ img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = self.crop_source_video(source_rgb_lst, do_crop=True)
1151
+ c_s_eyes_lst, c_s_lip_lst = self.calc_ratio(source_lmk_crop_lst)
1152
+ I_s_lst = self.prepare_videos(img_crop_256x256_lst)
1153
+ source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=25)
1154
+ # 1. prepare driving template
1155
+ if driving_template_dct is None:
1156
+ driving_template_dct = self.get_driving_template(kp_infos, smooth=smooth, dtype="np")
1157
+ # 2. driving
1158
+ n_frames = min(source_template_dct['n_frames'], driving_template_dct['n_frames'])
1159
+ # driving params
1160
+ I_p_lst = []
1161
+ I_p_pstbk_lst = []
1162
+ R_d_0, x_d_0_info = None, None
1163
+ flag_normalize_lip = self.cfg.flag_normalize_lip
1164
+ flag_relative_motion = True #self.cfg.flag_relative_motion
1165
+ flag_source_video_eye_retargeting = self.cfg.flag_source_video_eye_retargeting
1166
+ lip_normalize_threshold = self.cfg.lip_normalize_threshold
1167
+ source_video_eye_retargeting_threshold = self.cfg.source_video_eye_retargeting_threshold
1168
+ animation_region = 'lip' #self.cfg.animation_region
1169
+ driving_option = self.cfg.driving_option
1170
+ flag_stitching = self.cfg.flag_stitching
1171
+ flag_eye_retargeting = self.cfg.flag_eye_retargeting
1172
+ flag_lip_retargeting = self.cfg.flag_lip_retargeting
1173
+ driving_multiplier = self.cfg.driving_multiplier
1174
+ driving_smooth_observation_variance = self.cfg.driving_smooth_observation_variance
1175
+
1176
+ key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d'
1177
+ if flag_relative_motion:
1178
+ x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
1179
+ for i in range(n_frames):
1180
+ for idx in [6, 12, 14, 17, 19, 20]:
1181
+ # lip motion use abs motion
1182
+ x_d_exp_lst[i][:, idx, :] = driving_template_dct['motion'][i]['exp'][:, idx, :]
1183
+ x_d_exp_lst_smooth = ksmooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, self.device, driving_smooth_observation_variance)
1184
+
1185
+ if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
1186
+ x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
1187
+ x_d_r_lst_smooth = ksmooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, self.device, driving_smooth_observation_variance)
1188
+ else:
1189
+ x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
1190
+ x_d_exp_lst_smooth = ksmooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, self.device, driving_smooth_observation_variance)
1191
+
1192
+ if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
1193
+ x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)]
1194
+ x_d_r_lst_smooth = ksmooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, self.device, driving_smooth_observation_variance)
1195
+
1196
+ # driving all
1197
+ for i in track(range(n_frames), description='🚀Retargeting...', total=n_frames):
1198
+ x_s_info = source_template_dct['motion'][i]
1199
+ x_s_info = dct2device(x_s_info, self.device)
1200
+
1201
+ source_lmk = source_lmk_crop_lst[i]
1202
+ img_crop_256x256 = img_crop_256x256_lst[i]
1203
+ I_s = I_s_lst[i]
1204
+ f_s = self.extract_feature_3d(I_s)
1205
+
1206
+ x_c_s = x_s_info['kp']
1207
+ R_s = x_s_info['R']
1208
+ x_s = x_s_info['x_s']
1209
+
1210
+ # set the lip-open scalar to 0 before animation if the input is a video
1211
+ lip_delta_before_animation = None
1212
+ if flag_normalize_lip and flag_relative_motion and source_lmk is not None:
1213
+ c_d_lip_before_animation = [0.]
1214
+ combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
1215
+ if combined_lip_ratio_tensor_before_animation[0][0] >= lip_normalize_threshold:
1216
+ lip_delta_before_animation = self.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
1217
+ else:
1218
+ lip_delta_before_animation = None
1219
+
1220
+ # keep the eye-open scalar equal to the first frame's value if that frame has open eyes
1221
+ eye_delta_before_animation = None
1222
+ if flag_source_video_eye_retargeting and source_lmk is not None:
1223
+ if i == 0:
1224
+ combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
1225
+ c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
1226
+ if c_d_eye_before_animation_frame_zero[0][0] < source_video_eye_retargeting_threshold:
1227
+ c_d_eye_before_animation_frame_zero = [[0.39]]
1228
+ combined_eye_ratio_tensor_before_animation = self.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, source_lmk)
1229
+ eye_delta_before_animation = self.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation)
1230
+
1231
+ if flag_stitching: # prepare for paste back
1232
+ mask_ori_float = prepare_paste_back(self.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0]))
1233
+
1234
+ x_d_i_info = driving_template_dct['motion'][i]
1235
+ x_d_i_info = dct2device(x_d_i_info, self.device)
1236
+ R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
1237
+
1238
+ if i == 0: # cache the first frame
1239
+ R_d_0 = R_d_i
1240
+ x_d_0_info = x_d_i_info.copy()
1241
+
1242
+ delta_new = x_s_info['exp'].clone()
1243
+ if flag_relative_motion:
1244
+ if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
1245
+ R_new = x_d_r_lst_smooth[i]
1246
+ else:
1247
+ R_new = R_s
1248
+ if animation_region == "all" or animation_region == "exp":
1249
+ for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
1250
+ delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
1251
+ delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
1252
+ delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
1253
+ delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
1254
+ delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
1255
+ elif animation_region == "all_wo_lip" or animation_region == "exp_wo_lip":
1256
+ for idx in [1, 2, 11, 13, 15, 16, 18]:
1257
+ delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
1258
+ delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
1259
+ delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
1260
+ delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
1261
+ delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
1262
+ elif animation_region == "lip":
1263
+ for lip_idx in [6, 12, 14, 17, 19, 20]:
1264
+ delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :]
1265
+ elif animation_region == "eyes":
1266
+ for eyes_idx in [11, 13, 15, 16, 18]:
1267
+ delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :]
1268
+
1269
+ scale_new = x_s_info['scale']
1270
+ t_new = x_s_info['t']
1271
+ else:
1272
+ if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
1273
+ R_new = x_d_r_lst_smooth[i]
1274
+ else:
1275
+ R_new = R_s
1276
+ if animation_region == "all" or animation_region == "exp":
1277
+ for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
1278
+ delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
1279
+ delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
1280
+ delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
1281
+ delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
1282
+ delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
1283
+ elif animation_region == "all_wo_lip" or animation_region == "exp_wo_lip":
1284
+ for idx in [1, 2, 11, 13, 15, 16, 18]:
1285
+ delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
1286
+ delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
1287
+ delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
1288
+ delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
1289
+ delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
1290
+ elif animation_region == "lip":
1291
+ for lip_idx in [6, 12, 14, 17, 19, 20]:
1292
+ delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :]
1293
+ elif animation_region == "eyes":
1294
+ for eyes_idx in [11, 13, 15, 16, 18]:
1295
+ delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :]
1296
+ scale_new = x_s_info['scale']
1297
+ if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
1298
+ t_new = x_d_i_info['t']
1299
+ else:
1300
+ t_new = x_s_info['t']
1301
+
1302
+ t_new[..., 2].fill_(0) # zero tz
1303
+ x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
1304
+
1305
+ # Algorithm 1:
1306
+ if not flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
1307
+ # without stitching or retargeting
1308
+ if flag_normalize_lip and lip_delta_before_animation is not None:
1309
+ x_d_i_new += lip_delta_before_animation
1310
+ if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
1311
+ x_d_i_new += eye_delta_before_animation
1312
+ else:
1313
+ pass
1314
+ elif flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
1315
+ # with stitching and without retargeting
1316
+ if flag_normalize_lip and lip_delta_before_animation is not None:
1317
+ x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation
1318
+ else:
1319
+ x_d_i_new = self.stitching(x_s, x_d_i_new)
1320
+ if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
1321
+ x_d_i_new += eye_delta_before_animation
1322
+ else:
1323
+ eyes_delta, lip_delta = None, None
1324
+ if flag_eye_retargeting and source_lmk is not None and c_d_eyes_lst is not None:
1325
+ c_d_eyes_i = c_d_eyes_lst[i]
1326
+ combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
1327
+ # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
1328
+ eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor)
1329
+ if flag_lip_retargeting and source_lmk is not None and c_d_lip_lst is not None:
1330
+ c_d_lip_i = c_d_lip_lst[i]
1331
+ combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
1332
+ # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
1333
+ lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor)
1334
+
1335
+ if flag_relative_motion: # use x_s
1336
+ x_d_i_new = x_s + \
1337
+ (eyes_delta if eyes_delta is not None else 0) + \
1338
+ (lip_delta if lip_delta is not None else 0)
1339
+ else: # use x_d,i
1340
+ x_d_i_new = x_d_i_new + \
1341
+ (eyes_delta if eyes_delta is not None else 0) + \
1342
+ (lip_delta if lip_delta is not None else 0)
1343
+
1344
+ if flag_stitching:
1345
+ x_d_i_new = self.stitching(x_s, x_d_i_new)
1346
+
1347
+ x_d_i_new = x_s + (x_d_i_new - x_s) * driving_multiplier
1348
+ out = self.warp_decode(f_s, x_s, x_d_i_new)
1349
+ I_p_i = self.parse_output(out['out'])[0]
1350
+ I_p_lst.append(I_p_i)
1351
+
1352
+ if flag_stitching:
1353
+ # TODO: the paste-back procedure is slow; consider optimizing it with multi-threading or the GPU
1354
+ #I_p_pstbk = self.paste_back_by_face_mask(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], img_crop_256x256, use_laplacian=True)
1355
+ I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_float, use_laplacian=True)
1356
+ I_p_pstbk_lst.append(I_p_pstbk)
1357
+
1358
+ if len(I_p_pstbk_lst) > 0:
1359
+ self.save_results(I_p_pstbk_lst, save_path, audio_path)
1360
+ else:
1361
+ self.save_results(I_p_lst, save_path, audio_path)
1362
+
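For reference, a minimal, self-contained sketch of the keypoint composition applied in the loop above — x_d = scale · (x_c · R + δ) + t, followed by the driving-multiplier blend toward the source keypoints. All tensors below are random or trivial placeholders and assume the 21 implicit keypoints (indices 0-20) referenced by the code; nothing here is part of the repository.

import torch

x_c_s = torch.randn(1, 21, 3)        # canonical source keypoints (illustrative values)
R_new = torch.eye(3)                 # per-frame head rotation (identity here)
delta_new = torch.zeros(1, 21, 3)    # per-keypoint expression offsets
scale_new, t_new = 1.0, torch.zeros(1, 3)
x_s = scale_new * (x_c_s @ torch.eye(3)) + t_new              # source keypoints without driving offsets
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new   # driven keypoints, composed as above
x_d_i_new = x_s + (x_d_i_new - x_s) * 1.0                     # driving_multiplier blend (1.0 = full driving)
print(x_d_i_new.shape)                                        # torch.Size([1, 21, 3])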
1363
+ @torch.no_grad()
1364
+ def video_reconstruction_test(self, video_tensor, xs, save_path):
1365
+ # video_tensor, (1, F, C, H, W), [-1, 1]
1366
+ # xs, (1, F, 63)
1367
+ result_lst = []
1368
+ #ori_videos = []
1369
+ video_tensor = video_tensor[0:1] * 0.5 + 0.5 # [-1, 1] -> [0, 1], 1xTx3xHxW
1370
+ video_tensor = torch.clip(video_tensor, 0, 1)
1371
+ video_tensor = video_tensor.permute(1, 0, 2, 3, 4) # 1xTx3xHxW -> Tx1x3xHxW
1372
+ video = video_tensor.to(self.device)
1373
+ xs = xs[0:1].permute(1, 0, 2) # 1xTx63 -> Tx1x63
1374
+ xs = xs.reshape(-1, 1, 21, 3)
1375
+ xs = xs.to(self.device)
1376
+
1377
+ x_s_0 = xs[0]
1378
+ I_s_0 = torch.nn.functional.interpolate(video[0], size=(256, 256), mode='bilinear')
1379
+ f_s_0 = self.extract_feature_3d(I_s_0)
1380
+
1381
+ for i in range(video_tensor.shape[0]):
1382
+ #I_s = video[i] # 1x3xHxW
1383
+ #ori_videos.append((I_s.squeeze(0).squeeze(0).permute(1, 2, 0).cpu().numpy()*255).astype(np.uint8))
1384
+ x_s = self.stitching(x_s_0, xs[i])
1385
+ out = self.warp_decode(f_s_0, x_s_0, x_s)
1386
+ I_p_i = self.parse_output(out['out'])[0]
1387
+ result_lst.append(I_p_i)
1388
+
1389
+ #save_dir = osp.dirname(save_path)
1390
+ #ori_path = osp.join(save_dir, "ori.mp4")
1391
+ #save_path = osp.join(save_dir, "rec.mp4")
1392
+ self.save_results(result_lst, save_path, audio_path=None)
1393
+ #self.save_results(ori_videos, ori_path, audio_path=None)
1394
+
1395
+ @torch.no_grad()
1396
+ def self_driven(self, image_tensor, xs, save_path, length):
1397
+ result_lst = []
1398
+ image_tensor = image_tensor[0:1] * 0.5 + 0.5 # [-1, 1] -> [0, 1], 1x3xHxW
1399
+ image_tensor = torch.clip(image_tensor, 0, 1)
1400
+ image = image_tensor.to(self.device)
1401
+ I_s_0 = torch.nn.functional.interpolate(image, size=(256, 256), mode='bilinear')
1402
+
1403
+ xs = xs[0:1].permute(1, 0, 2) # 1xTx63 -> Tx1x63
1404
+ xs = xs.reshape(-1, 1, 21, 3)
1405
+ xs = xs.to(self.device)
1406
+
1407
+ x_s_0 = xs[0]
1408
+ f_s_0 = self.extract_feature_3d(I_s_0)
1409
+
1410
+ for i in range(xs.shape[0]):
1411
+ x_d = self.stitching(x_s_0, xs[i])
1412
+ out = self.warp_decode(f_s_0, x_s_0, x_d)
1413
+ I_p_i = self.parse_output(out['out'])[0]
1414
+ result_lst.append(I_p_i)
1415
+
1416
+ assert len(result_lst) == length, f"length of result_lst is {len(result_lst)}, but length is {length}"
1417
+
1418
+ self.save_results(result_lst, save_path, audio_path=None)
1419
+
1420
+
src/examples/driving_audios/10.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79b53cbd91ebd7756b51f4d388a769b461a247f26acae5c362ca326e27c23626
3
+ size 2880078
src/examples/driving_audios/5.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79b53cbd91ebd7756b51f4d388a769b461a247f26acae5c362ca326e27c23626
3
+ size 2880078
src/examples/driving_audios/6.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90be6ae092eaa9be4e74e0bed56ef343a825bc2c899d2868e0e3aee494c86a04
3
+ size 1323078
src/examples/driving_audios/tmp_5.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2f615328211bb938ab7f6b603631695106d2e23ceaa4dfcd4f491bc5dc2faca
3
+ size 544044
src/examples/reference_images/1.jpg ADDED

Git LFS Details

  • SHA256: 362a14590bbfa4517e00338941f87f51fa9d6da0beaa827f6ba28a0e490888d4
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB
src/examples/reference_images/2.jpg ADDED
src/examples/reference_images/3.jpg ADDED
src/examples/reference_images/4.jpg ADDED
src/examples/reference_images/5.jpg ADDED
src/examples/reference_images/6.jpg ADDED
src/examples/reference_images/7.jpg ADDED

Git LFS Details

  • SHA256: 9f03a04f1de9055c626aa09c471115da0365d9d6c25a62c227e8eb3dfba53993
  • Pointer size: 131 Bytes
  • Size of remote file: 774 kB
src/examples/silent-audio.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:231cedffe295d0f5c8ea8569af9edc2471c262410689190bb705fb0adb62f63f
3
+ size 352878
src/models/audio/__pycache__/audio_processer.cpython-310.pyc ADDED
Binary file (12.1 kB). View file
 
src/models/audio/__pycache__/audio_proj.cpython-310.pyc ADDED
Binary file (4.66 kB). View file
 
src/models/audio/__pycache__/hubert.cpython-310.pyc ADDED
Binary file (3.45 kB). View file
 
src/models/audio/__pycache__/wav2vec.cpython-310.pyc ADDED
Binary file (6 kB). View file
 
src/models/audio/__pycache__/wav2vec2.cpython-310.pyc ADDED
Binary file (4.32 kB). View file
 
src/models/audio/__pycache__/wav2vec_modified.cpython-310.pyc ADDED
Binary file (6.78 kB). View file
 
src/models/audio/audio_processer.py ADDED
@@ -0,0 +1,407 @@
1
+ """Audio processer for talking data.
2
+ Author: linzhihui.lzh
3
+ Date: 2024-12-12
4
+ """
5
+ import os
6
+ from re import A
7
+ import sys
8
+ import os.path as osp
9
+
10
+ from typing import List, Dict, Tuple, Optional, Union, Any
11
+
12
+ import yaml
13
+ from omegaconf import OmegaConf
14
+
15
+ import math
16
+ import librosa
17
+ import numpy as np
18
+
19
+ from einops import rearrange
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+
24
+ from pydub import AudioSegment
25
+ # from audio_separator.separator import Separator
26
+
27
+ sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__))))))
28
+ from src.utils.rprint import rlog as log
29
+ from src.utils.util import resample_audio
30
+
31
+ from src.models.audio.wav2vec_modified import Wav2VecModel
32
+ from src.models.audio.hubert import HubertModel
33
+
34
+
35
+ def pad_audio(audio, audio_unit=320, pad_threshold=80):
36
+ batch_size, audio_len = audio.shape
37
+ n_units = audio_len // audio_unit
38
+ side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2)
39
+ if side_len >= 0:
40
+ reflect_len = side_len // 2
41
+ replicate_len = side_len % 2
42
+ if reflect_len > 0:
43
+ audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
44
+ audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
45
+ if replicate_len > 0:
46
+ audio = F.pad(audio, (1, 1), mode='replicate')
47
+
48
+ return audio
49
+
50
+
51
+ def cut_audio(audio_path: str, save_dir: str, length=60) -> List[str]:
52
+ """Cut audio into sub-divisions and return subfile paths. Supports wav format.
53
+
54
+ Args:
55
+ audio_path (str): the source audio file path
56
+ save_dir (str): the save directory of sub-divisions
57
+ length (int, optional): The max length of each sub-division. Defaults to 60 secs.
58
+
59
+ Returns:
60
+ List[str]: the subfile paths
61
+ """
62
+ audio_name = osp.basename(audio_path).split('.')[0]
63
+ audio = AudioSegment.from_wav(audio_path)
64
+ segment_length = length * 1000. # pydub uses milliseconds
65
+ num_segments = math.ceil(len(audio) / segment_length)
66
+
67
+ os.makedirs(save_dir, exist_ok=True)
68
+ audio_list = []
69
+
70
+ for i in range(num_segments):
71
+ start_time = i * segment_length
72
+ end_time = min((i + 1) * segment_length, len(audio))
73
+ segment = audio[start_time:end_time]
74
+
75
+ path = osp.join(save_dir, f"{audio_name}_segment_{i+1}.wav")
76
+ audio_list.append(path)
77
+ segment.export(path, format="wav")
78
+ return audio_list
79
+
80
+
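A one-line usage sketch of cut_audio; the output directory name below is illustrative, not a repository default.

segment_paths = cut_audio("src/examples/driving_audios/10.wav", save_dir="tmp_segments", length=60)
# Each returned path points to a <=60 s chunk named "<name>_segment_<i>.wav" inside save_dir.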
81
+ class AudioProcessor(object):
82
+ def __init__(self, cfg_path: str, is_training: bool = False) -> None:
83
+ cfg = OmegaConf.load(cfg_path)
84
+ self.cfg = cfg
85
+ self.is_training = is_training
86
+ log("========================================= Audio Processer =========================================")
87
+ log(OmegaConf.to_yaml(cfg))
88
+
89
+ # setting device
90
+ self.device_id = cfg.device_params.device_id
91
+ self.use_half = cfg.device_params.flag_use_half_precision
92
+ if cfg.device_params.flag_force_cpu:
93
+ self.device = 'cpu'
94
+ else:
95
+ try:
96
+ if torch.backends.mps.is_available():
97
+ self.device = 'mps'
98
+ else:
99
+ self.device = 'cuda:' + str(self.device_id)
100
+ except:
101
+ self.device = 'cuda:' + str(self.device_id)
102
+
103
+ # init audio separator
104
+ self.audio_separator = None
105
+ self.cache_dir = cfg.cache_dir
106
+ self.tmp_dir = cfg.tmp_dir
107
+ self.use_audio_separator = cfg.model_params.use_audio_separator
108
+ self.audio_separator_name = cfg.model_params.audio_separator_name
109
+ self.audio_separator_path = cfg.model_weights.audio_separator_path
110
+ self.set_audio_separator(cfg.cache_dir)
111
+
112
+ # load audio encoder, wav2vec or hubert
113
+ self.model_name = cfg.model_params.model_name
114
+ self.is_chinese = cfg.model_params.is_chinese
115
+ self.audio_encoder = self.load_model(
116
+ model_name = cfg.model_params.model_name,
117
+ model_type = cfg.model_params.model_type,
118
+ is_chinese = cfg.model_params.is_chinese,
119
+ )
120
+ self.only_last_features = cfg.model_params.only_last_features
121
+ if cfg.model_params.only_last_features:
122
+ self.feature_shape = (1, 768)
123
+ else:
124
+ self.feature_shape = (12, 768) # features of 12 blocks
125
+
126
+ # init data params
127
+ self.sample_strategy = cfg.data_params.sample_strategy
128
+ self.sample_rate = cfg.data_params.sample_rate
129
+ self.fps = cfg.data_params.fps
130
+ self.audio_unit = cfg.data_params.sample_rate / cfg.data_params.fps # num of audio samples per frame
131
+ self.max_length = cfg.data_params.max_length
132
+ self.subclip_len = cfg.data_params.sub_clip_length
133
+ self.save_to_cpu = cfg.data_params.save_to_cpu
134
+ self.pad_mode = cfg.data_params.audio_pad_mode
135
+
136
+ log("========================================= Audio Processer: Done =========================================")
137
+
138
+ def load_model(self, model_name: str="wav2vec", model_type: str="base", is_chinese: bool = False):
139
+ assert model_name in ["wav2vec", "hubert"], f"Unknown audio model {model_name}, only support wav2vec or hubert"
140
+ assert model_type in ["base", "large"], f"Unknown audio model type {model_type}, only support base or large"
141
+
142
+ if model_name == "wav2vec":
143
+ # load wav2vec model weights
144
+ if is_chinese:
145
+ if model_type == "base":
146
+ model_weight_path = self.cfg.model_weights.wav2vec_path.chinese.base
147
+ else:
148
+ model_weight_path = self.cfg.model_weights.wav2vec_path.chinese.large
149
+ else:
150
+ if model_type == "base":
151
+ model_weight_path = self.cfg.model_weights.wav2vec_path.default.base
152
+ else:
153
+ model_weight_path = self.cfg.model_weights.wav2vec_path.default.large
154
+ if model_weight_path is None:
155
+ raise ValueError(f"model_weight_path is None")
156
+ audio_encoder = Wav2VecModel.from_pretrained(model_weight_path, local_files_only=True).to(device=self.device)
157
+ else:
158
+ if is_chinese:
159
+ if model_type == "base":
160
+ model_weight_path = self.cfg.model_weights.hubert_path.chinese.base
161
+ else:
162
+ model_weight_path = self.cfg.model_weights.hubert_path.chinese.large
163
+ else:
164
+ if model_type == "base":
165
+ model_weight_path = self.cfg.model_weights.hubert_path.default.base
166
+ else:
167
+ model_weight_path = self.cfg.model_weights.hubert_path.default.large
168
+ if model_weight_path is None:
169
+ raise ValueError(f"model_weight_path is None")
170
+ audio_encoder = HubertModel.from_pretrained(model_weight_path, local_files_only=True).to(device=self.device)
171
+
172
+ log(f"{model_name}-{model_type}-chinese-{is_chinese} model has beed loaded from {model_weight_path}")
173
+ total_params = sum(p.numel() for p in audio_encoder.parameters())
174
+ print('Number of parameters: %.4fM' % (total_params / 1e6))
175
+
176
+ # freeze the CNN feature extractor and, optionally, the feature projection and early encoder layers
177
+ audio_encoder.feature_extractor._freeze_parameters()
178
+ if not self.cfg.model_params.is_original:
179
+ frozen_layers = [0, 1]
180
+ for name, param in audio_encoder.named_parameters():
181
+ if name.startswith("feature_projection"):
182
+ param.requires_grad = False
183
+ if name.startswith("encoder.layers"):
184
+ layer = int(name.split(".")[2])
185
+ if layer in frozen_layers:
186
+ param.requires_grad = False
187
+
188
+ audio_encoder = audio_encoder.to(self.device)
189
+ if self.use_half:
190
+ audio_encoder = audio_encoder.half()
191
+ audio_encoder.eval()
192
+ return audio_encoder
193
+
194
+ def set_audio_separator(self, output_dir: str) -> None:
195
+ del self.audio_separator
196
+
197
+ if self.audio_separator_name is not None and self.use_audio_separator:
198
+ try:
199
+ os.makedirs(output_dir, exist_ok=True)
200
+ except OSError as _:
201
+ print("Fail to create the output cache dir.")
202
+ self.audio_separator = Separator(
203
+ output_dir=output_dir,
204
+ output_single_stem="vocals",
205
+ model_file_dir=self.audio_separator_path,
206
+ )
207
+ self.audio_separator.load_model(self.audio_separator_name)
208
+ assert self.audio_separator.model_instance is not None, "Failed to load the audio separator model."
209
+ else:
210
+ self.audio_separator=None
211
+ log("Use audio directly without vocals seperator.")
212
+
213
+ def seperate_audio(self, audio_path: str, output_dir: Union[str, None] = None) -> str:
214
+ if output_dir is not None:
215
+ if output_dir != self.cache_dir:
216
+ # reload audio separator
217
+ self.set_audio_separator(output_dir)
218
+
219
+ if self.audio_separator is not None:
220
+ # 1. separate vocals
221
+ # TODO: process in memory
222
+ try:
223
+ outputs = self.audio_separator.separate(audio_path)
224
+ if len(outputs) <= 0:
225
+ raise RuntimeError("Audio separate failed.")
226
+
227
+ vocal_audio_file = outputs[0]
228
+ vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
229
+ vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
230
+ vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate)
231
+ except Exception as e:
232
+ log(f"Fail to separate vocals from {audio_path}, error info [{e}]")
233
+ vocal_audio_file=audio_path
234
+ else:
235
+ vocal_audio_file=audio_path
236
+
237
+ return vocal_audio_file
238
+
239
+ def load_audio(self, audio_path: str, mono: bool = True, duration: Optional[float] = None) -> Any:
240
+ try:
241
+ audio_data, sampling_rate = librosa.load(audio_path, sr=self.sample_rate, mono=mono, duration=duration)
242
+ except Exception as e:
243
+ raise RuntimeError(f"Fail to load audio from {audio_path}, error info [{e}]")
244
+ return audio_data, sampling_rate
245
+
246
+ def prepare_audio_data(self, audio_data: Union[np.ndarray, torch.Tensor], n_frames: Optional[int]=None) -> Tuple[List[Any], int]:
247
+ """Prepare audio data for processing.
248
+ """
249
+ clip_len = int(len(audio_data) / self.audio_unit)
250
+ if n_frames is not None:
251
+ if abs(n_frames - clip_len) > 2:
252
+ log(f"The number of frames must be close to the clip length (in 80ms), got {n_frames} and {clip_len}")
253
+ return [], n_frames
254
+ clip_len = n_frames
255
+ else:
256
+ n_frames = clip_len
257
+
258
+ # normalize audio, replace Wav2Vec2FeatureExtractor
259
+ if isinstance(audio_data, np.ndarray):
260
+ audio_data = torch.from_numpy(audio_data).to(self.device)
261
+ assert audio_data.ndim == 1, 'Audio must be 1D tensor.'
262
+ audio_data = (audio_data - torch.mean(audio_data)) / (torch.std(audio_data) + 1e-7)
263
+ #log(f"audio loaded! {audio_data.shape}")
264
+
265
+ # padding
266
+ # padding audio to fit the clip length
267
+ n_audio_samples = round(self.audio_unit * clip_len)
268
+ n_padding_audio_samples = n_audio_samples - len(audio_data)
269
+ n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
270
+ if n_padding_audio_samples > 0:
271
+ if self.pad_mode == 'zero':
272
+ padding_value = 0
273
+ elif self.pad_mode == 'replicate':
274
+ padding_value = float(audio_data[-1])
275
+ else:
276
+ raise ValueError(f'Unknown pad mode: {self.pad_mode}')
277
+ audio_data = F.pad(audio_data, (0, n_padding_audio_samples), value=padding_value)
278
+
279
+ # divide the audio into sub-clips to save GPU memory
280
+ audio_segments = []
281
+ if clip_len <= self.subclip_len:
282
+ n_subdivision = 1
283
+ subclip_len = clip_len
284
+ else:
285
+ n_subdivision = math.ceil(clip_len / self.subclip_len)
286
+ subclip_len = self.subclip_len
287
+
288
+ for i in range(0, n_subdivision):
289
+ start_idx = i * subclip_len
290
+ end_idx = min(start_idx + subclip_len, clip_len)
291
+ # debug
292
+ #log(f"[{i+1}/{n_subdivision}] data index [{round(start_idx * self.audio_unit)}, {round(end_idx * self.audio_unit)})")
293
+ audio_segments.append(
294
+ {
295
+ "data": audio_data[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0),
296
+ "start_idx": start_idx,
297
+ "end_idx": end_idx,
298
+ "length": end_idx - start_idx
299
+ }
300
+ )
301
+ return audio_segments, n_frames
302
+
303
+ def get_audio_embedding(self, audio, clip_len: int) -> torch.Tensor:
304
+ if audio.ndim == 2:
305
+ # Extract audio features
306
+ assert audio.shape[1] == 16000 * clip_len / self.fps, \
307
+ f'Incorrect audio length {audio.shape[1]}'
308
+
309
+ # Extract audio features
310
+ if self.use_half:
311
+ audio = audio.half()
312
+ embeddings = self.audio_encoder(
313
+ pad_audio(audio), seq_len=clip_len, sample_strategy=self.sample_strategy, output_hidden_states=True
314
+ ) # (N, L, 768)
315
+ assert len(embeddings) > 0, "Fail to extract audio embedding"
316
+
317
+ if self.only_last_features:
318
+ audio_emb = embeddings.last_hidden_state.squeeze(0)
319
+ else:
320
+ audio_emb = torch.stack(
321
+ embeddings.hidden_states[1:], dim=1
322
+ ).squeeze(0)
323
+ audio_emb = rearrange(audio_emb, "b s d -> s b d")
324
+
325
+ elif audio.ndim == 3:
326
+ assert audio.shape[1] == clip_len, f'Incorrect audio feature length {audio.shape[1]}'
327
+ audio_emb = audio
328
+ else:
329
+ raise ValueError(f'Incorrect audio input shape {audio.shape}')
330
+
331
+ return audio_emb
332
+
333
+ def get_audio_embeddings(self, audio_segments: List[Any]) -> Optional[torch.Tensor]:
334
+ audio_embs = []
335
+ for audio_segment in audio_segments:
336
+ if self.is_training:
337
+ audio_emb = self.get_audio_embedding(audio_segment["data"], audio_segment["length"])
338
+ else:
339
+ with torch.no_grad():
340
+ audio_emb = self.get_audio_embedding(audio_segment["data"], audio_segment["length"])
341
+
342
+ audio_emb = audio_emb.cpu() if self.save_to_cpu else audio_emb
343
+ audio_embs.append(audio_emb)
344
+ #log(f"audio segment [{audio_segment['start_idx']}, {audio_segment['end_idx']}) has been processed.")
345
+
346
+ if len(audio_embs) == 0:
347
+ return None
348
+
349
+ audio_emb = torch.cat(audio_embs, dim=0)
350
+
351
+ return audio_emb
352
+
353
+ def preprocess(
354
+ self,
355
+ audio_path: str,
356
+ n_frames: Optional[int] = None,
357
+ duration: Optional[float] = None,
358
+ need_seperate: bool = False
359
+ ):
360
+ """ Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
361
+ The separated vocal track is then converted into wav2vec2 embeddings for further processing or analysis.
362
+ """
363
+ if need_seperate:
364
+ vocal_audio_file = self.seperate_audio(audio_path)
365
+ else:
366
+ vocal_audio_file = audio_path
367
+
368
+ audio_data, sampling_rate = self.load_audio(vocal_audio_file, duration=duration)
369
+
370
+ assert sampling_rate == 16000, "The sample rate of audio must be 16000"
371
+ audio_segments, n_frames = self.prepare_audio_data(audio_data, n_frames)
372
+ audio_emb = self.get_audio_embeddings(audio_segments)
373
+ if audio_emb is None:
374
+ log(f"{audio_path} has been processed, but no audio embedding, set as 'None'.")
375
+ #else:
376
+ #log(f"{audio_path} has been processed, audio embedding shape {audio_emb.shape}.")
377
+ return audio_emb, n_frames
378
+
379
+ def preprocess_long(
380
+ self,
381
+ audio_path: str,
382
+ need_seperate: bool = False
383
+ ):
384
+ audio_list = cut_audio(audio_path, self.tmp_dir, length=self.max_length)
385
+ audio_emb_list = []
386
+ l = 0
387
+
388
+ for idx, audio_path in enumerate(audio_list):
389
+ padding = (idx+1) == len(audio_list)
390
+ emb, length = self.preprocess(audio_path, need_seperate=need_seperate)
391
+ audio_emb_list.append(emb)
392
+ log(f"Processing audio {idx+1}/{len(audio_list)}, path: {audio_path} length: {length}")
393
+ l += length
394
+
395
+ audio_emb = torch.cat(audio_emb_list)
396
+ audio_length = l
397
+
398
+ # remove tmp file
399
+ for audio_path in audio_list:
400
+ os.remove(audio_path)
401
+
402
+ return audio_emb, audio_length
403
+
404
+ def __enter__(self):
405
+ return self
406
+
407
+
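A minimal usage sketch of the class above; the config path matches the YAML shipped in this commit, but the flags it enables depend on that file, so treat the snippet as illustrative rather than canonical.

processor = AudioProcessor("configs/audio2motion/model/audio_processer_config.yaml")
audio_emb, n_frames = processor.preprocess("src/examples/driving_audios/5.wav")
# With only_last_features disabled, audio_emb has shape (n_frames, 12, 768): one 12-block
# hidden-state stack per video frame at the configured fps. preprocess_long handles clips
# longer than max_length by cutting, encoding, and concatenating the segments.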
src/models/audio/audio_proj.py ADDED
@@ -0,0 +1,124 @@
1
+ """
2
+ This module provides the implementation of an Audio Projection Model, which is designed for
3
+ audio processing tasks. The model takes audio embeddings as input and outputs context tokens
4
+ that can be used for various downstream applications, such as audio analysis or synthesis.
5
+
6
+ The AudioProjModel class is based on the ModelMixin class from the diffusers library, which
7
+ provides a foundation for building custom models. This implementation includes multiple linear
8
+ layers with ReLU activation functions and a LayerNorm for normalization.
9
+
10
+ Key Features:
11
+ - Audio embedding input with flexible sequence length and block structure.
12
+ - Multiple linear layers for feature transformation.
13
+ - ReLU activation for non-linear transformation.
14
+ - LayerNorm for stabilizing and speeding up training.
15
+ - Rearrangement of input embeddings to match the model's expected input shape.
16
+ - Customizable number of blocks, channels, and context tokens for adaptability.
17
+
18
+ The module is structured to be easily integrated into larger systems or used as a standalone
19
+ component for audio feature extraction and processing.
20
+
21
+ Classes:
22
+ - AudioProjModel: A class representing the audio projection model with configurable parameters.
23
+
24
+ Functions:
25
+ - (none)
26
+
27
+ Dependencies:
28
+ - torch: For tensor operations and neural network components.
29
+ - diffusers: For the ModelMixin base class.
30
+ - einops: For tensor rearrangement operations.
31
+
32
+ """
33
+
34
+ import torch
35
+ from diffusers import ModelMixin
36
+ from einops import rearrange
37
+ from torch import nn
38
+
39
+
40
+ class AudioProjModel(ModelMixin):
41
+ """Audio Projection Model
42
+
43
+ This class defines an audio projection model that takes audio embeddings as input
44
+ and produces context tokens as output. The model is based on the ModelMixin class
45
+ and consists of multiple linear layers and activation functions. It can be used
46
+ for various audio processing tasks.
47
+
48
+ Attributes:
49
+ seq_len (int): The length of the audio sequence.
50
+ blocks (int): The number of blocks in the audio projection model.
51
+ channels (int): The number of channels in the audio projection model.
52
+ intermediate_dim (int): The intermediate dimension of the model.
53
+ context_tokens (int): The number of context tokens in the output.
54
+ output_dim (int): The output dimension of the context tokens.
55
+
56
+ Methods:
57
+ __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768):
58
+ Initializes the AudioProjModel with the given parameters.
59
+ forward(self, audio_embeds):
60
+ Defines the forward pass for the AudioProjModel.
61
+ Parameters:
62
+ audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels).
63
+ Returns:
64
+ context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
65
+
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ seq_len=5,
71
+ blocks=12, # add a new parameter blocks
72
+ channels=768, # add a new parameter channels
73
+ intermediate_dim=512,
74
+ output_dim=768,
75
+ context_tokens=32,
76
+ ):
77
+ super().__init__()
78
+
79
+ self.seq_len = seq_len
80
+ self.blocks = blocks
81
+ self.channels = channels
82
+ self.input_dim = (
83
+ seq_len * blocks * channels
84
+ ) # update input_dim to be the product of blocks and channels.
85
+ self.intermediate_dim = intermediate_dim
86
+ self.context_tokens = context_tokens
87
+ self.output_dim = output_dim
88
+
89
+ # define multiple linear layers
90
+ self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
91
+ self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
92
+ self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
93
+
94
+ self.norm = nn.LayerNorm(output_dim)
95
+
96
+ def forward(self, audio_embeds):
97
+ """
98
+ Defines the forward pass for the AudioProjModel.
99
+
100
+ Parameters:
101
+ audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels).
102
+
103
+ Returns:
104
+ context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
105
+ """
106
+ # merge
107
+ video_length = audio_embeds.shape[1]
108
+ audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
109
+ batch_size, window_size, blocks, channels = audio_embeds.shape
110
+ audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
111
+
112
+ audio_embeds = torch.relu(self.proj1(audio_embeds))
113
+ audio_embeds = torch.relu(self.proj2(audio_embeds))
114
+
115
+ context_tokens = self.proj3(audio_embeds).reshape(
116
+ batch_size, self.context_tokens, self.output_dim
117
+ )
118
+
119
+ context_tokens = self.norm(context_tokens)
120
+ context_tokens = rearrange(
121
+ context_tokens, "(bz f) m c -> bz f m c", f=video_length
122
+ )
123
+
124
+ return context_tokens
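A shape trace for the projection above, using the default constructor arguments; the tensors are random and purely illustrative.

import torch

model = AudioProjModel(seq_len=5, blocks=12, channels=768, intermediate_dim=512, output_dim=768, context_tokens=32)
audio_embeds = torch.randn(2, 25, 5, 12, 768)   # (batch, video_length, window, blocks, channels)
context_tokens = model(audio_embeds)
print(context_tokens.shape)                     # torch.Size([2, 25, 32, 768])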
src/models/audio/hubert.py ADDED
@@ -0,0 +1,120 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from transformers import HubertModel
7
+ from transformers.modeling_outputs import BaseModelOutput
8
+
9
+
10
+ _CONFIG_FOR_DOC = 'HubertConfig'
11
+
12
+
13
+ def linear_interpolation(features, seq_len):
14
+ """
15
+ Transpose the features to interpolate linearly.
16
+
17
+ Args:
18
+ features (torch.Tensor): The extracted features to be interpolated.
19
+ seq_len (torch.Tensor): The sequence lengths of the features.
20
+
21
+ Returns:
22
+ torch.Tensor: The interpolated features.
23
+ """
24
+ features = features.transpose(1, 2)
25
+ output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
26
+ return output_features.transpose(1, 2)
27
+
28
+
29
+ class HubertModel_(HubertModel):
30
+ def __init__(self, config):
31
+ super().__init__(config)
32
+
33
+ def forward(
34
+ self,
35
+ input_values: Optional[torch.Tensor],
36
+ seq_len: Optional[int],
37
+ sample_strategy: Optional[str] = "presample",
38
+ attention_mask: Optional[torch.LongTensor] = None,
39
+ mask_time_indices: Optional[torch.FloatTensor] = None,
40
+ output_attentions: Optional[bool] = None,
41
+ output_hidden_states: Optional[bool] = None,
42
+ return_dict: Optional[bool] = None,
43
+ ):
44
+ """
45
+ Forward pass of the HuBERT model.
46
+
47
+ Args:
48
+ self: The instance of the model.
49
+ input_values: The input values (waveform) to the model.
50
+ seq_len: The sequence length of the input values.
51
+ sample_strategy: The sample strategy to align features and seq_len, supports ['presample', 'postsample'].
52
+ attention_mask: Attention mask to be used for the model.
53
+ mask_time_indices: Mask indices to be used for the model.
54
+ output_attentions: If set to True, returns attentions.
55
+ output_hidden_states: If set to True, returns hidden states.
56
+ return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
57
+
58
+ Returns:
59
+ The output of the HuBERT model.
60
+ """
61
+ # output_fps=25,
62
+ # attention_mask=None,
63
+ # output_attentions=None,
64
+ # output_hidden_states=None,
65
+ # return_dict=None,
66
+ # frame_num=None
67
+ assert sample_strategy in ["presample", "postsample"], "sample_strategy must be one of ['presample', 'postsample']"
68
+ self.config.output_attentions = True
69
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
70
+
71
+ output_hidden_states = (
72
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
73
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
74
+
75
+ extract_features = self.feature_extractor(input_values) # (N, C, L)
76
+ extract_features = extract_features.transpose(1, 2)
77
+ if sample_strategy == "presample":
78
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
79
+
80
+ # # Resample the audio feature @ 50 fps to `output_fps`.
81
+ # if frame_num is not None:
82
+ # extract_features_len = round(frame_num * 50 / output_fps)
83
+ # extract_features = extract_features[:, :, :extract_features_len]
84
+ # extract_features = linear_interpolation(extract_features, 50, output_fps, output_len=frame_num)
85
+ # extract_features = extract_features.transpose(1, 2) # (N, L, C)
86
+
87
+ if attention_mask is not None:
88
+ # compute reduced attention_mask corresponding to feature vectors
89
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
90
+
91
+ hidden_states = self.feature_projection(extract_features)
92
+ hidden_states = self._mask_hidden_states(
93
+ hidden_states,
94
+ mask_time_indices=mask_time_indices,
95
+ attention_mask=attention_mask
96
+ )
97
+
98
+ encoder_outputs = self.encoder(
99
+ hidden_states,
100
+ attention_mask=attention_mask,
101
+ output_attentions=output_attentions,
102
+ output_hidden_states=output_hidden_states,
103
+ return_dict=return_dict,
104
+ )
105
+
106
+ hidden_states = encoder_outputs[0]
107
+
108
+ if sample_strategy == "postsample":
109
+ hidden_states = linear_interpolation(hidden_states, seq_len=seq_len)
110
+ for i in range(len(encoder_outputs.hidden_states)):
111
+ encoder_outputs.hidden_states[i] = linear_interpolation(encoder_outputs.hidden_states[i], seq_len=seq_len)
112
+
113
+ if not return_dict:
114
+ return (hidden_states,) + encoder_outputs[1:]
115
+
116
+ return BaseModelOutput(
117
+ last_hidden_state=hidden_states,
118
+ hidden_states=encoder_outputs.hidden_states,
119
+ attentions=encoder_outputs.attentions,
120
+ )
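A minimal sketch of the frame-alignment step shared by the 'presample' and 'postsample' strategies: 50 Hz HuBERT features are linearly resampled so that one feature row lines up with each video frame. The sizes are illustrative.

import torch

feats_50hz = torch.randn(1, 100, 768)                       # (N, L, C): ~2 s of encoder features at 50 Hz
feats_25fps = linear_interpolation(feats_50hz, seq_len=50)  # -> (1, 50, 768), one row per frame at 25 fps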
src/models/audio/hubert2.py ADDED
@@ -0,0 +1,120 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from transformers import HubertModel
7
+ from transformers.modeling_outputs import BaseModelOutput
8
+
9
+
10
+ _CONFIG_FOR_DOC = 'HubertConfig'
11
+
12
+
13
+ def linear_interpolation(features, seq_len):
14
+ """
15
+ Transpose the features to interpolate linearly.
16
+
17
+ Args:
18
+ features (torch.Tensor): The extracted features to be interpolated.
19
+ seq_len (torch.Tensor): The sequence lengths of the features.
20
+
21
+ Returns:
22
+ torch.Tensor: The interpolated features.
23
+ """
24
+ features = features.transpose(1, 2)
25
+ output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
26
+ return output_features.transpose(1, 2)
27
+
28
+
29
+ class HubertModel(HubertModel):
30
+ def __init__(self, config):
31
+ super().__init__(config)
32
+
33
+ def forward(
34
+ self,
35
+ input_values: Optional[torch.Tensor],
36
+ seq_len: Optional[int],
37
+ sample_strategy: Optional[str] = "presample",
38
+ attention_mask: Optional[torch.LongTensor] = None,
39
+ mask_time_indices: Optional[torch.FloatTensor] = None,
40
+ output_attentions: Optional[bool] = None,
41
+ output_hidden_states: Optional[bool] = None,
42
+ return_dict: Optional[bool] = None,
43
+ ):
44
+ """
45
+ Forward pass of the HuBERT model.
46
+
47
+ Args:
48
+ self: The instance of the model.
49
+ input_values: The input values (waveform) to the model.
50
+ seq_len: The sequence length of the input values.
51
+ sample_strategy: The sample strategy to align features and seq_len, supports ['presample', 'postsample'].
52
+ attention_mask: Attention mask to be used for the model.
53
+ mask_time_indices: Mask indices to be used for the model.
54
+ output_attentions: If set to True, returns attentions.
55
+ output_hidden_states: If set to True, returns hidden states.
56
+ return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
57
+
58
+ Returns:
59
+ The output of the HuBERT model.
60
+ """
61
+ # output_fps=25,
62
+ # attention_mask=None,
63
+ # output_attentions=None,
64
+ # output_hidden_states=None,
65
+ # return_dict=None,
66
+ # frame_num=None
67
+ assert sample_strategy in ["presample", "postsample"], "sample_strategy must be one of ['presample', 'postsample']"
68
+ self.config.output_attentions = True
69
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
70
+
71
+ output_hidden_states = (
72
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
73
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
74
+
75
+ extract_features = self.feature_extractor(input_values) # (N, C, L)
76
+ extract_features = extract_features.transpose(1, 2)
77
+ if sample_strategy == "presample":
78
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
79
+
80
+ # # Resample the audio feature @ 50 fps to `output_fps`.
81
+ # if frame_num is not None:
82
+ # extract_features_len = round(frame_num * 50 / output_fps)
83
+ # extract_features = extract_features[:, :, :extract_features_len]
84
+ # extract_features = linear_interpolation(extract_features, 50, output_fps, output_len=frame_num)
85
+ # extract_features = extract_features.transpose(1, 2) # (N, L, C)
86
+
87
+ if attention_mask is not None:
88
+ # compute reduced attention_mask corresponding to feature vectors
89
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
90
+
91
+ hidden_states = self.feature_projection(extract_features)
92
+ hidden_states = self._mask_hidden_states(
93
+ hidden_states,
94
+ mask_time_indices=mask_time_indices,
95
+ attention_mask=attention_mask
96
+ )
97
+
98
+ encoder_outputs = self.encoder(
99
+ hidden_states,
100
+ attention_mask=attention_mask,
101
+ output_attentions=output_attentions,
102
+ output_hidden_states=output_hidden_states,
103
+ return_dict=return_dict,
104
+ )
105
+
106
+ hidden_states = encoder_outputs[0]
107
+
108
+ if sample_strategy == "postsample":
109
+ hidden_states = linear_interpolation(hidden_states, seq_len=seq_len)
110
+ for i in range(len(encoder_outputs.hidden_states)):
111
+ encoder_outputs.hidden_states[i] = linear_interpolation(encoder_outputs.hidden_states[i], seq_len=seq_len)
112
+
113
+ if not return_dict:
114
+ return (hidden_states,) + encoder_outputs[1:]
115
+
116
+ return BaseModelOutput(
117
+ last_hidden_state=hidden_states,
118
+ hidden_states=encoder_outputs.hidden_states,
119
+ attentions=encoder_outputs.attentions,
120
+ )
src/models/audio/wav2vec.py ADDED
@@ -0,0 +1,210 @@
1
+
2
+
3
+ """
4
+ This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
5
+ It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
6
+ such as feature extraction and encoding.
7
+
8
+ Classes:
9
+ Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
10
+
11
+ Functions:
12
+ linear_interpolation: Interpolates the features based on the sequence length.
13
+ """
14
+
15
+ from typing import Optional, Tuple, Union
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from transformers import Wav2Vec2Model
19
+ from transformers.modeling_outputs import BaseModelOutput
20
+
21
+
22
+ class Wav2VecModel(Wav2Vec2Model):
23
+ """
24
+ Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
25
+ It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
26
+ ...
27
+
28
+ Attributes:
29
+ base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
30
+
31
+ Methods:
32
+ forward(input_values, seq_len, attention_mask=None, mask_time_indices=None
33
+ , output_attentions=None, output_hidden_states=None, return_dict=None):
34
+ Forward pass of the Wav2VecModel.
35
+ It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
36
+
37
+ feature_extract(input_values, seq_len):
38
+ Extracts features from the input_values using the base model.
39
+
40
+ encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
41
+ Encodes the extracted features using the base model and returns the encoded features.
42
+ """
43
+ def forward(
44
+ self,
45
+ input_values,
46
+ seq_len,
47
+ attention_mask=None,
48
+ mask_time_indices=None,
49
+ output_attentions=None,
50
+ output_hidden_states=None,
51
+ return_dict=None,
52
+ ):
53
+ """
54
+ Forward pass of the Wav2Vec model.
55
+
56
+ Args:
57
+ self: The instance of the model.
58
+ input_values: The input values (waveform) to the model.
59
+ seq_len: The sequence length of the input values.
60
+ attention_mask: Attention mask to be used for the model.
61
+ mask_time_indices: Mask indices to be used for the model.
62
+ output_attentions: If set to True, returns attentions.
63
+ output_hidden_states: If set to True, returns hidden states.
64
+ return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
65
+
66
+ Returns:
67
+ The output of the Wav2Vec model.
68
+ """
69
+ self.config.output_attentions = True
70
+
71
+ output_hidden_states = (
72
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
73
+ )
74
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
75
+
76
+ extract_features = self.feature_extractor(input_values)
77
+ extract_features = extract_features.transpose(1, 2)
78
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
79
+
80
+ if attention_mask is not None:
81
+ # compute reduced attention_mask corresponding to feature vectors
82
+ attention_mask = self._get_feature_vector_attention_mask(
83
+ extract_features.shape[1], attention_mask, add_adapter=False
84
+ )
85
+
86
+ hidden_states, extract_features = self.feature_projection(extract_features)
87
+ hidden_states = self._mask_hidden_states(
88
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
89
+ )
90
+
91
+ encoder_outputs = self.encoder(
92
+ hidden_states,
93
+ attention_mask=attention_mask,
94
+ output_attentions=output_attentions,
95
+ output_hidden_states=output_hidden_states,
96
+ return_dict=return_dict,
97
+ )
98
+
99
+ hidden_states = encoder_outputs[0]
100
+
101
+ if self.adapter is not None:
102
+ hidden_states = self.adapter(hidden_states)
103
+
104
+ if not return_dict:
105
+ return (hidden_states, ) + encoder_outputs[1:]
106
+ return BaseModelOutput(
107
+ last_hidden_state=hidden_states,
108
+ hidden_states=encoder_outputs.hidden_states,
109
+ attentions=encoder_outputs.attentions,
110
+ )
111
+
112
+
113
+ def feature_extract(
114
+ self,
115
+ input_values,
116
+ seq_len,
117
+ ):
118
+ """
119
+ Extracts features from the input values and returns the extracted features.
120
+
121
+ Parameters:
122
+ input_values (torch.Tensor): The input values to be processed.
123
+ seq_len (torch.Tensor): The sequence lengths of the input values.
124
+
125
+ Returns:
126
+ extracted_features (torch.Tensor): The extracted features from the input values.
127
+ """
128
+ extract_features = self.feature_extractor(input_values)
129
+ extract_features = extract_features.transpose(1, 2)
130
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
131
+
132
+ return extract_features
133
+
134
+ def encode(
135
+ self,
136
+ extract_features,
137
+ attention_mask=None,
138
+ mask_time_indices=None,
139
+ output_attentions=None,
140
+ output_hidden_states=None,
141
+ return_dict=None,
142
+ ):
143
+ """
144
+ Encodes the input features into the output space.
145
+
146
+ Args:
147
+ extract_features (torch.Tensor): The extracted features from the audio signal.
148
+ attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
149
+ mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
150
+ output_attentions (bool, optional): If set to True, returns the attention weights.
151
+ output_hidden_states (bool, optional): If set to True, returns all hidden states.
152
+ return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple.
153
+
154
+ Returns:
155
+ The encoded output features.
156
+ """
157
+ self.config.output_attentions = True
158
+
159
+ output_hidden_states = (
160
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
161
+ )
162
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
163
+
164
+ if attention_mask is not None:
165
+ # compute reduced attention_mask corresponding to feature vectors
166
+ attention_mask = self._get_feature_vector_attention_mask(
167
+ extract_features.shape[1], attention_mask, add_adapter=False
168
+ )
169
+
170
+ hidden_states, extract_features = self.feature_projection(extract_features)
171
+ hidden_states = self._mask_hidden_states(
172
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
173
+ )
174
+
175
+ encoder_outputs = self.encoder(
176
+ hidden_states,
177
+ attention_mask=attention_mask,
178
+ output_attentions=output_attentions,
179
+ output_hidden_states=output_hidden_states,
180
+ return_dict=return_dict,
181
+ )
182
+
183
+ hidden_states = encoder_outputs[0]
184
+
185
+ if self.adapter is not None:
186
+ hidden_states = self.adapter(hidden_states)
187
+
188
+ if not return_dict:
189
+ return (hidden_states, ) + encoder_outputs[1:]
190
+ return BaseModelOutput(
191
+ last_hidden_state=hidden_states,
192
+ hidden_states=encoder_outputs.hidden_states,
193
+ attentions=encoder_outputs.attentions,
194
+ )
195
+
196
+
197
+ def linear_interpolation(features, seq_len):
198
+ """
199
+ Transpose the features to interpolate linearly.
200
+
201
+ Args:
202
+ features (torch.Tensor): The extracted features to be interpolated.
203
+ seq_len (torch.Tensor): The sequence lengths of the features.
204
+
205
+ Returns:
206
+ torch.Tensor: The interpolated features.
207
+ """
208
+ features = features.transpose(1, 2)
209
+ output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
210
+ return output_features.transpose(1, 2)
src/models/audio/wav2vec2.py ADDED
@@ -0,0 +1,123 @@
1
+ from packaging import version
2
+ from typing import Optional, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import transformers
8
+ from transformers import Wav2Vec2Model
9
+ from transformers.modeling_outputs import BaseModelOutput
10
+
11
+ _CONFIG_FOR_DOC = 'Wav2Vec2Config'
12
+
13
+
14
+ # the implementation of Wav2Vec2Model is borrowed from
15
+ # https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
16
+ # initialize our encoder with the pre-trained wav2vec 2.0 weights.
17
+ def _compute_mask_indices(shape: Tuple[int, int], mask_prob: float, mask_length: int,
18
+ attention_mask: Optional[torch.Tensor] = None, min_masks: int = 0, ) -> np.ndarray:
19
+ bsz, all_sz = shape
20
+ mask = np.full((bsz, all_sz), False)
21
+
22
+ all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand())
23
+ all_num_mask = max(min_masks, all_num_mask)
24
+ mask_idcs = []
25
+ padding_mask = attention_mask.ne(1) if attention_mask is not None else None
26
+ for i in range(bsz):
27
+ if padding_mask is not None:
28
+ sz = all_sz - padding_mask[i].long().sum().item()
29
+ num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand())
30
+ num_mask = max(min_masks, num_mask)
31
+ else:
32
+ sz = all_sz
33
+ num_mask = all_num_mask
34
+
35
+ lengths = np.full(num_mask, mask_length)
36
+
37
+ if sum(lengths) == 0:
38
+ lengths[0] = min(mask_length, sz - 1)
39
+
40
+ min_len = min(lengths)
41
+ if sz - min_len <= num_mask:
42
+ min_len = sz - num_mask - 1
43
+
44
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
45
+ mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
46
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
47
+
48
+ min_len = min([len(m) for m in mask_idcs])
49
+ for i, mask_idc in enumerate(mask_idcs):
50
+ if len(mask_idc) > min_len:
51
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
52
+ mask[i, mask_idc] = True
53
+ return mask
54
+
55
+
56
+ # linear interpolation layer
57
+ def linear_interpolation(features, input_fps, output_fps, output_len=None):
58
+ # features: (N, C, L)
59
+ seq_len = features.shape[2] / float(input_fps)
60
+ if output_len is None:
61
+ output_len = int(seq_len * output_fps)
62
+ output_features = F.interpolate(features, size=output_len, align_corners=False, mode='linear')
63
+ return output_features
64
+
65
+
66
+ class Wav2Vec2Model(Wav2Vec2Model):
67
+ def __init__(self, config):
68
+ super().__init__(config)
69
+ self.is_old_version = version.parse(transformers.__version__) < version.parse('4.7.0')
70
+
71
+ def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
72
+ output_hidden_states=None, return_dict=None, frame_num=None):
73
+ self.config.output_attentions = True
74
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
75
+ output_hidden_states = (
76
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
77
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
78
+ print(f"data shape before feature extractor: {input_values.shape}")
79
+ hidden_states = self.feature_extractor(input_values) # (N, C, L)
80
+ print(f"data shape after feature extractor: {hidden_states.shape}")
81
+ # Resample the audio feature @ 50 fps to `output_fps`.
82
+ if frame_num is not None:
83
+ hidden_states_len = round(frame_num * 50 / output_fps)
84
+ hidden_states = hidden_states[:, :, :hidden_states_len]
85
+ hidden_states = linear_interpolation(hidden_states, 50, output_fps, output_len=frame_num)
86
+ hidden_states = hidden_states.transpose(1, 2) # (N, L, C)
87
+
88
+ if attention_mask is not None:
89
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
90
+ attention_mask = torch.zeros(hidden_states.shape[:2], dtype=hidden_states.dtype,
91
+ device=hidden_states.device)
92
+ attention_mask[(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)] = 1
93
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
94
+
95
+ if self.is_old_version:
96
+ hidden_states = self.feature_projection(hidden_states)
97
+ else:
98
+ hidden_states = self.feature_projection(hidden_states)[0]
99
+
100
+ if self.config.apply_spec_augment and self.training:
101
+ batch_size, sequence_length, hidden_size = hidden_states.size()
102
+ if self.config.mask_time_prob > 0:
103
+ mask_time_indices = _compute_mask_indices((batch_size, sequence_length), self.config.mask_time_prob,
104
+ self.config.mask_time_length, attention_mask=attention_mask,
105
+ min_masks=2, )
106
+ hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
107
+ if self.config.mask_feature_prob > 0:
108
+ mask_feature_indices = _compute_mask_indices((batch_size, hidden_size), self.config.mask_feature_prob,
109
+ self.config.mask_feature_length, )
110
+ mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
111
+ hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
112
+ encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask,
113
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
114
+ return_dict=return_dict, )
115
+ hidden_states = encoder_outputs[0]
116
+ if not return_dict:
117
+ return (hidden_states,) + encoder_outputs[1:]
118
+
119
+ for i in range(len(encoder_outputs.hidden_states)):
120
+ print(f"hidden states {i} after encoder: {encoder_outputs.hidden_states[i].shape}")
121
+
122
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
123
+ attentions=encoder_outputs.attentions, )
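A minimal sketch of the 50 Hz-to-output_fps resampling performed in this forward pass, using the module-level linear_interpolation defined above; the tensor sizes are illustrative.

import torch

feats = torch.randn(1, 512, 210)                  # (N, C, L): CNN feature-extractor output (~4.2 s at 50 Hz; channel count depends on the model size)
frame_num, output_fps = 100, 25                   # 4 s of video at 25 fps
feats = feats[:, :, :round(frame_num * 50 / output_fps)]                   # truncate to 200 steps
feats = linear_interpolation(feats, 50, output_fps, output_len=frame_num)  # -> (1, 512, 100)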
src/models/audio/wav2vec_modified.py ADDED
@@ -0,0 +1,223 @@
+"""
+This module defines the Wav2Vec model, a pre-trained model for speech recognition and understanding.
+It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionality
+such as feature extraction and encoding.
+
+Classes:
+    Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
+
+Functions:
+    linear_interpolation: Interpolates the features to a target sequence length.
+"""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import Wav2Vec2Model
+from transformers.modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput
+
+
+class Wav2VecModel(Wav2Vec2Model):
+    """
+    Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
+    It inherits all the functionality of Wav2Vec2Model and adds methods for feature extraction and encoding.
+
+    Attributes:
+        Inherits all attributes of Wav2Vec2Model; no new attributes are defined.
+
+    Methods:
+        forward(input_values, seq_len, sample_strategy="presample", attention_mask=None, mask_time_indices=None,
+                output_attentions=None, output_hidden_states=None, return_dict=None):
+            Forward pass of the Wav2VecModel. Takes input_values, seq_len, and other optional parameters
+            and returns the output of the base model.
+
+        feature_extract(input_values, seq_len):
+            Extracts features from the input_values using the base model.
+
+        encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None,
+               output_hidden_states=None, return_dict=None):
+            Encodes the extracted features using the base model and returns the encoded features.
+    """
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        seq_len: Optional[int],
+        sample_strategy: Optional[str] = "presample",
+        attention_mask: Optional[torch.LongTensor] = None,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+        """
+        Forward pass of the Wav2Vec model.
+
+        Args:
+            input_values: The input values (waveform) to the model.
+            seq_len: The target sequence length the audio features are aligned to.
+            sample_strategy: When to align the features to seq_len; one of ['presample', 'postsample'].
+            attention_mask: Attention mask to be used for the model.
+            mask_time_indices: Mask indices to be used for the model.
+            output_attentions: If set to True, returns attentions.
+            output_hidden_states: If set to True, returns hidden states.
+            return_dict: If set to True, returns a Wav2Vec2BaseModelOutput instead of a tuple.
+
+        Returns:
+            The output of the Wav2Vec model.
+        """
+        assert sample_strategy in ["presample", "postsample"], "sample_strategy must be in ['presample', 'postsample']"
+
+        self.config.output_attentions = True
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        extract_features = self.feature_extractor(input_values)
+        extract_features = extract_features.transpose(1, 2)
+        if sample_strategy == "presample":
+            # Align the CNN features to seq_len before they enter the transformer encoder.
+            extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
+
+        hidden_states, extract_features = self.feature_projection(extract_features)
+        hidden_states = self._mask_hidden_states(
+            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+        )
+
+        encoder_outputs = self.encoder(
+            hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = encoder_outputs[0]
+
+        if self.adapter is not None:
+            hidden_states = self.adapter(hidden_states)
+
+        if sample_strategy == "postsample":
+            # Align the encoder outputs to seq_len after the transformer encoder.
+            hidden_states = linear_interpolation(hidden_states, seq_len=seq_len)
+            if return_dict and encoder_outputs.hidden_states is not None:
+                # hidden_states is a tuple, so rebuild it rather than assigning to its items.
+                encoder_outputs.hidden_states = tuple(
+                    linear_interpolation(h, seq_len=seq_len) for h in encoder_outputs.hidden_states
+                )
+
+        if not return_dict:
+            return (hidden_states, extract_features) + encoder_outputs[1:]
+
+        return Wav2Vec2BaseModelOutput(
+            last_hidden_state=hidden_states,
+            extract_features=extract_features,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def feature_extract(
+        self,
+        input_values,
+        seq_len,
+    ):
+        """
+        Extracts features from the input values and returns the extracted features.
+
+        Parameters:
+            input_values (torch.Tensor): The input values (waveform) to be processed.
+            seq_len (int): The target sequence length the features are interpolated to.
+
+        Returns:
+            extract_features (torch.Tensor): The extracted features from the input values.
+        """
+        extract_features = self.feature_extractor(input_values)
+        extract_features = extract_features.transpose(1, 2)
+        extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+        return extract_features
+
+    def encode(
+        self,
+        extract_features,
+        attention_mask=None,
+        mask_time_indices=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        """
+        Encodes the input features into the output space.
+
+        Args:
+            extract_features (torch.Tensor): The extracted features from the audio signal.
+            attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
+            mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
+            output_attentions (bool, optional): If set to True, returns the attention weights.
+            output_hidden_states (bool, optional): If set to True, returns all hidden states.
+            return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of a tuple.
+
+        Returns:
+            The encoded output features.
+        """
+        self.config.output_attentions = True
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
+
+        hidden_states, extract_features = self.feature_projection(extract_features)
+        hidden_states = self._mask_hidden_states(
+            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+        )
+
+        encoder_outputs = self.encoder(
+            hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = encoder_outputs[0]
+
+        if self.adapter is not None:
+            hidden_states = self.adapter(hidden_states)
+
+        if not return_dict:
+            return (hidden_states,) + encoder_outputs[1:]
+
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+def linear_interpolation(features, seq_len):
+    """
+    Linearly interpolate the features along the time dimension to the target sequence length.
+
+    Args:
+        features (torch.Tensor): Features of shape (batch, time, channels) to be interpolated.
+        seq_len (int): The target sequence length.
+
+    Returns:
+        torch.Tensor: The interpolated features of shape (batch, seq_len, channels).
+    """
+    features = features.transpose(1, 2)
+    output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
+    return output_features.transpose(1, 2)
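Taken together, the forward pass above aligns the roughly 50 Hz wav2vec 2.0 feature rate to an arbitrary target length, either before the encoder (`presample`) or after it (`postsample`). Below is a minimal usage sketch, assuming the package is importable as `src.models.audio.wav2vec_modified` and using the public `facebook/wav2vec2-base-960h` checkpoint as a stand-in for whatever weights this repository actually loads; the 25-frame target is likewise illustrative.

# Usage sketch only: checkpoint name, import path, and the 25-frame target are
# illustrative assumptions, not taken from this repository's configs.
import torch
from src.models.audio.wav2vec_modified import Wav2VecModel

model = Wav2VecModel.from_pretrained("facebook/wav2vec2-base-960h")
model.eval()

waveform = torch.randn(1, 16000)   # 1 s of 16 kHz audio (already preprocessed/normalized)
num_frames = 25                    # e.g. one second of video at 25 fps

with torch.no_grad():
    # "presample": interpolate the CNN features to num_frames before the transformer encoder.
    out_pre = model(waveform, seq_len=num_frames, sample_strategy="presample", output_hidden_states=True)
    # "postsample": run the encoder at its native ~50 Hz rate, then interpolate its outputs.
    out_post = model(waveform, seq_len=num_frames, sample_strategy="postsample", output_hidden_states=True)

print(out_pre.last_hidden_state.shape)    # torch.Size([1, 25, 768])
print(out_post.last_hidden_state.shape)   # torch.Size([1, 25, 768])
print(len(out_pre.hidden_states))         # 13 per-layer states for the base model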