Upload 247 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +8 -0
- configs/audio2motion/inference/inference.yaml +35 -0
- configs/audio2motion/model/audio_processer_config.yaml +36 -0
- configs/audio2motion/model/config.yaml +59 -0
- configs/audio2motion/model/crop_config.yaml +21 -0
- configs/audio2motion/model/liveportrait_config.yaml +59 -0
- configs/audio2motion/model/models.yaml +43 -0
- requirements.txt +45 -0
- src/datasets/mean.pt +3 -0
- src/datasets/preprocess/__pycache__/flow_filter.cpython-310.pyc +0 -0
- src/datasets/preprocess/__pycache__/video_crop.cpython-310.pyc +0 -0
- src/datasets/preprocess/__pycache__/visualize.cpython-310.pyc +0 -0
- src/datasets/preprocess/extract_features/__pycache__/audio_processer.cpython-310.pyc +0 -0
- src/datasets/preprocess/extract_features/__pycache__/audio_processer.cpython-312.pyc +0 -0
- src/datasets/preprocess/extract_features/__pycache__/feature_extractor_pipeline.cpython-310.pyc +0 -0
- src/datasets/preprocess/extract_features/__pycache__/motion_processer.cpython-310.pyc +0 -0
- src/datasets/preprocess/extract_features/__pycache__/test_processer.cpython-310.pyc +0 -0
- src/datasets/preprocess/extract_features/audio_processer.py +471 -0
- src/datasets/preprocess/extract_features/face_segmentation/__init__.py +88 -0
- src/datasets/preprocess/extract_features/face_segmentation/__pycache__/__init__.cpython-310.pyc +0 -0
- src/datasets/preprocess/extract_features/face_segmentation/__pycache__/bisenet.cpython-310.pyc +0 -0
- src/datasets/preprocess/extract_features/face_segmentation/__pycache__/resnet.cpython-310.pyc +0 -0
- src/datasets/preprocess/extract_features/face_segmentation/bisenet.py +285 -0
- src/datasets/preprocess/extract_features/face_segmentation/resnet.py +113 -0
- src/datasets/preprocess/extract_features/motion_processer.py +1420 -0
- src/examples/driving_audios/10.wav +3 -0
- src/examples/driving_audios/5.wav +3 -0
- src/examples/driving_audios/6.wav +3 -0
- src/examples/driving_audios/tmp_5.wav +3 -0
- src/examples/reference_images/1.jpg +3 -0
- src/examples/reference_images/2.jpg +0 -0
- src/examples/reference_images/3.jpg +0 -0
- src/examples/reference_images/4.jpg +0 -0
- src/examples/reference_images/5.jpg +0 -0
- src/examples/reference_images/6.jpg +0 -0
- src/examples/reference_images/7.jpg +3 -0
- src/examples/silent-audio.wav +3 -0
- src/models/audio/__pycache__/audio_processer.cpython-310.pyc +0 -0
- src/models/audio/__pycache__/audio_proj.cpython-310.pyc +0 -0
- src/models/audio/__pycache__/hubert.cpython-310.pyc +0 -0
- src/models/audio/__pycache__/wav2vec.cpython-310.pyc +0 -0
- src/models/audio/__pycache__/wav2vec2.cpython-310.pyc +0 -0
- src/models/audio/__pycache__/wav2vec_modified.cpython-310.pyc +0 -0
- src/models/audio/audio_processer.py +407 -0
- src/models/audio/audio_proj.py +124 -0
- src/models/audio/hubert.py +120 -0
- src/models/audio/hubert2.py +120 -0
- src/models/audio/wav2vec.py +210 -0
- src/models/audio/wav2vec2.py +123 -0
- src/models/audio/wav2vec_modified.py +223 -0
.gitattributes
CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+src/examples/driving_audios/10.wav filter=lfs diff=lfs merge=lfs -text
+src/examples/driving_audios/5.wav filter=lfs diff=lfs merge=lfs -text
+src/examples/driving_audios/6.wav filter=lfs diff=lfs merge=lfs -text
+src/examples/driving_audios/tmp_5.wav filter=lfs diff=lfs merge=lfs -text
+src/examples/reference_images/1.jpg filter=lfs diff=lfs merge=lfs -text
+src/examples/reference_images/7.jpg filter=lfs diff=lfs merge=lfs -text
+src/examples/silent-audio.wav filter=lfs diff=lfs merge=lfs -text
+src/thirdparty/liveportrait/src/utils/dependencies/insightface/data/images/t1.jpg filter=lfs diff=lfs merge=lfs -text
configs/audio2motion/inference/inference.yaml
ADDED
@@ -0,0 +1,35 @@
+
+output_fps: 25
+## appearance and motion feature extractor
+appearance_feature_extractor_path: pretrain_weights/decode/v1/first_stage/base_models/appearance_feature_extractor.pth
+motion_extractor_path: pretrain_weights/decode/v1/first_stage/base_models/motion_extractor.pth
+## SPADEGenerator
+spade_generator_path: pretrain_weights/decode/v1/first_stage/base_models/spade_generator.pth
+warping_module_path: pretrain_weights/decode/v1/first_stage/base_models/warping_module.pth
+## stitching retargeting module
+stitching_retargeting_module_path: pretrain_weights/decode/v1/first_stage/retargeting_models/stitching_retargeting_module.pth
+#
+
+# audio processer config
+audio_model_config: configs/audio2motion/model/audio_processer_config.yaml
+
+# motion processer config
+motion_processer_config: configs/audio2motion/model/liveportrait_config.yaml
+
+# motion generator model
+motion_models_config: configs/audio2motion/model/config.yaml
+use_ref_kp: False
+motion_generator_path: pretrain_weights/moda/net-200.pth
+need_normalized: True
+
+# other configs
+device_id: 0
+batch_size: 100
+
+source_max_dim: 1280 # the max dim of height and width of source image or video
+source_division: 2 # make sure the height and width of source image or video can be divided by this number
+input_height: 256
+input_width: 256
+source_fps: 25
+min_video_length: 50
+max_video_length: 500
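For orientation, a minimal sketch (not a file in this upload) of how a config like the one above can be read with OmegaConf, the library pinned in requirements.txt and used elsewhere in this commit; keys are the ones shown in the YAML:

# Illustrative only -- not part of the upload.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/audio2motion/inference/inference.yaml")
print(cfg.output_fps)               # 25
print(cfg.motion_generator_path)    # pretrain_weights/moda/net-200.pth
# Nested configs are referenced by path and loaded the same way:
audio_cfg = OmegaConf.load(cfg.audio_model_config)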
configs/audio2motion/model/audio_processer_config.yaml
ADDED
@@ -0,0 +1,36 @@
+# models settings
+model_params:
+  model_name: hubert # wav2vec or hubert
+  model_type: base # base large
+  is_chinese: True
+  is_original: True
+  only_last_features: False
+  use_audio_separator: False
+  audio_separator_name: Kim_Vocal_2.onnx
+
+# model weights
+model_weights:
+  audio_separator_path: pretrain_weights/audio/audio_separator
+  hubert_path:
+    chinese:
+      base: pretrain_weights/audio/chinese-hubert-base
+# data settings
+data_params:
+  sample_rate: 16000
+  max_length: 60 # seconds
+  sub_clip_length: 3000 # samples
+  fps: 25
+  sample_strategy: "presample"
+  audio_pad_mode: replicate # pad mode for audio, replicate or zero
+  save_to_cpu: True # saving gpu memory
+
+# device settings
+device_params:
+  device_id: 0
+  flag_force_cpu: False
+  flag_use_half_precision: False
+
+cache_dir: preprocessed/HDTF/vocals
+tmp_dir: src/tmp
+
+
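A small illustrative calculation (my own, based only on the data_params above): at a 16 kHz sample rate and 25 fps, each video frame corresponds to 640 audio samples, which is the `audio_unit` the audio processor below computes.

# Illustrative arithmetic only.
sample_rate = 16000
fps = 25
audio_unit = sample_rate / fps   # 640.0 audio samples per video frame
max_frames = 60 * fps            # a 60 s clip (max_length) maps to 1500 frames
print(audio_unit, max_frames)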
configs/audio2motion/model/config.yaml
ADDED
@@ -0,0 +1,59 @@
+model_name: TalkingHeadDiT-B
+audio_projector:
+  type: MLP
+  pretrained_model_path: None
+  device: cuda
+  params:
+    model_name: MLP-S-3
+    sequence_length: 1
+    blocks: 12
+    audio_feat_dim: 768
+    keypoint_dim: 63
+    feature_dim: 512
+    output_dim: 256
+    context_tokens: 1
+    audio_embedder_type: simple
+    audio_cond_dim: 63
+motion_generator:
+  type: DiT
+  pretrained_model_path: None
+  device: cuda
+  params:
+    model_name: DiT-S-8-8
+    architecture: decoder
+    use_emo: True
+    input_dim: 70
+    output_dim: 70
+    exp_dim: 63
+    n_prev_frames: 1
+    n_pred_frames: 80
+    use_indicator: False
+    feature_dim: 256
+    n_heads: 8
+    n_layers: 8
+    mlp_ratio: 4
+    no_use_learnable_pe: True
+    norm_type: rms_norm # [rms_norm|layer_norm]
+    qk_norm: rms_norm # [rms_norm|layer_norm|null]
+    steps: 1000
+noise_scheduler:
+  type: flow_matching
+  sample_mode: sample
+  device: cuda
+  params:
+    time_shifting: True
+    num_train_timesteps: 1000
+    num_inference_steps: 10
+    eta: 0.2
+    beta_start: 0.0001
+    beta_end: 0.02
+    s: 0.008
+    mode: cosine
+train:
+  audio_drop_prob: 0.3
+  cond_drop_prob: 0.2
+  motion_drop_prob: 0.3
+  audio_drop_ratio : 0.2
+  motion_drop_ratio: 0.1
+  pre_drop_ratio : 0.0
+  device_specific: True
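The noise_scheduler block selects flow matching with 10 inference steps. As a generic illustration only (this is not the repo's scheduler code, and `velocity_fn` is a placeholder for whatever the DiT motion generator predicts), a flow-matching sampler integrates a learned velocity field from noise to data with a plain Euler loop:

# Generic flow-matching sampling sketch -- illustrative, not this repo's implementation.
import torch

def euler_flow_matching_sample(velocity_fn, shape, num_inference_steps=10, device="cpu"):
    x = torch.randn(shape, device=device)                 # start from Gaussian noise at t=0
    ts = torch.linspace(0.0, 1.0, num_inference_steps + 1, device=device)
    for i in range(num_inference_steps):
        t, t_next = ts[i], ts[i + 1]
        v = velocity_fn(x, t)                             # predicted velocity toward the data
        x = x + (t_next - t) * v                          # explicit Euler update
    return x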
configs/audio2motion/model/crop_config.yaml
ADDED
@@ -0,0 +1,21 @@
+insightface_root: pretrain_weights/decode/v1/insightface
+landmark_ckpt_path: pretrain_weights/decode/v1/first_stage/landmark.onnx
+xpose_config_file_path: src/utils/UniPose_SwinT.py
+device_id: 0 # gpu device id
+flag_force_cpu: False # force cpu inference, WIP
+det_thresh: 0.15 # detection threshold
+########## source image or video cropping option ##########
+dsize: 512 # crop size
+scale: 2.3 # scale factor
+vx_ratio: 0 # vx ratio
+vy_ratio: -0.125 # vy ratio +up, -down
+max_face_num: 0 # max face number, 0 mean no limit
+flag_do_rot: True # whether to conduct the rotation when flag_do_crop is True
+animal_face_type: animal_face_9 # animal_face_68 -> 68 landmark points, animal_face_9 -> 9 landmarks
+########## driving video auto cropping option ##########
+scale_crop_driving_video: 2.2 # 2.0 # scale factor for cropping driving video
+vx_ratio_crop_driving_video: 0.0 # adjust x offset
+vy_ratio_crop_driving_video: -0.1 # adjust y offset
+direction: large-small # direction of cropping
+source_max_dim: 1280
+source_division: 2
configs/audio2motion/model/liveportrait_config.yaml
ADDED
@@ -0,0 +1,59 @@
+# model config
+models_config: configs/audio2motion/model/models.yaml
+
+# 1. face appearance feature
+appearance_feature_extractor_path: pretrain_weights/decode/v1/first_stage/base_models/appearance_feature_extractor.pth
+
+# 2. motion feature
+motion_extractor_path: pretrain_weights/decode/v1/first_stage/base_models/motion_extractor.pth
+
+# 3. stitching retargeting module
+stitching_retargeting_module_path: pretrain_weights/decode/v1/first_stage/retargeting_models/stitching_retargeting_module.pth
+
+# 4. feature warper
+warping_module_path: pretrain_weights/decode/v1/first_stage/base_models/warping_module.pth
+
+# 5. SPADEGenerator
+spade_generator_path: pretrain_weights/decode/v1/first_stage/base_models/spade_generator.pth
+
+# 6. cropper
+crop_cfg: "configs/audio2motion/model/crop_config.yaml"
+
+# 7. face parser
+face_parser_weight_path: "pretrain_weights/face/face-parsing/79999_iter.pth"
+resnet_weight_path: "pretrain_weights/face/face-parsing/resnet18-5c106cde.pth"
+
+# motion template
+need_normalized: True
+
+# others
+batch_size: 100
+source_max_dim: 1920 # the max dim of height and width of source image or video
+source_division: 2 # make sure the height and width of source image or video can be divided by this number
+input_height: 256
+input_width: 256
+output_height: 512
+output_width: 512
+output_fps: 25
+
+# driving params
+flag_do_torch_compile: False
+flag_use_half_precision: True
+flag_relative_motion: False
+flag_normalize_lip: False
+flag_source_video_eye_retargeting: False
+flag_eye_retargeting: False
+flag_lip_retargeting: False
+flag_stitching: True
+
+lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip
+source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
+anchor_frame: 0 # TO IMPLEMENT
+
+driving_option: "expression-friendly" # "expression-friendly" or "pose-friendly"
+driving_multiplier: 1.0 # be used only when driving_option is "expression-friendly"
+lib_multiplier: 1.0
+driving_smooth_observation_variance: 3e-7 # the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
+animation_region: "all" #["exp", "pose", "lip", "eyes", "all"], the region where the animation was performed, "exp" means the expression, "pose" means the head pose
+mask_crop: src/utils/resources/mask_template.png
+lip_array: src/utils/resources/lip_array.pkl
configs/audio2motion/model/models.yaml
ADDED
@@ -0,0 +1,43 @@
+model_params:
+  appearance_feature_extractor_params: # the F in the paper
+    image_channel: 3
+    block_expansion: 64
+    num_down_blocks: 2
+    max_features: 512
+    reshape_channel: 32
+    reshape_depth: 16
+    num_resblocks: 6
+  motion_extractor_params: # the M in the paper
+    num_kp: 21
+    backbone: convnextv2_tiny
+  warping_module_params: # the W in the paper
+    num_kp: 21
+    block_expansion: 64
+    max_features: 512
+    num_down_blocks: 2
+    reshape_channel: 32
+    estimate_occlusion_map: True
+    dense_motion_params:
+      block_expansion: 32
+      max_features: 1024
+      num_blocks: 5
+      reshape_depth: 16
+      compress: 4
+  spade_generator_params: # the G in the paper
+    upscale: 2 # represents upsample factor 256x256 -> 512x512
+    block_expansion: 64
+    max_features: 512
+    num_down_blocks: 2
+  stitching_retargeting_module_params: # the S in the paper
+    stitching:
+      input_size: 126 # (21*3)*2
+      hidden_sizes: [128, 128, 64]
+      output_size: 65 # (21*3)+2(tx,ty)
+    lip:
+      input_size: 65 # (21*3)+2
+      hidden_sizes: [128, 128, 64]
+      output_size: 63 # (21*3)
+    eye:
+      input_size: 66 # (21*3)+3
+      hidden_sizes: [256, 256, 128, 128, 64]
+      output_size: 63 # (21*3)
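The comments in the stitching block encode the dimension bookkeeping: 21 keypoints times 3 coordinates for source and driving give the 126-d input, and the output appends a 2-d (tx, ty) shift to the 63-d keypoint delta. A hypothetical MLP matching those sizes (illustration only, not the repo's retargeting network class) could be built as:

# Hypothetical MLPs matching the sizes above -- for illustration only.
import torch.nn as nn

def make_mlp(input_size, hidden_sizes, output_size):
    layers, in_dim = [], input_size
    for h in hidden_sizes:
        layers += [nn.Linear(in_dim, h), nn.ReLU()]
        in_dim = h
    layers.append(nn.Linear(in_dim, output_size))
    return nn.Sequential(*layers)

stitching = make_mlp(126, [128, 128, 64], 65)            # (21*3)*2 -> (21*3)+2
lip       = make_mlp(65,  [128, 128, 64], 63)            # (21*3)+2 -> 21*3
eye       = make_mlp(66,  [256, 256, 128, 128, 64], 63)  # (21*3)+3 -> 21*3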
requirements.txt
ADDED
@@ -0,0 +1,45 @@
+--find-links https://download.pytorch.org/whl/torch_stable.html
+
+accelerate==0.28.0
+audio-separator==0.17.2
+av==12.1.0
+bitsandbytes==0.43.1
+decord==0.6.0
+diffusers==0.27.2
+einops==0.8.0
+huggingface==0.0.1
+huggingface-hub==0.25.1
+insightface==0.7.3
+librosa==0.10.2.post1
+mediapipe[vision]==0.10.14
+mlflow==2.13.1
+moviepy==1.0.3
+numpy==1.26.4
+omegaconf==2.3.0
+onnx2torch==1.5.14
+onnx==1.16.1
+onnxruntime-gpu==1.18.0
+opencv-python==4.10.0.84
+pillow==10.3.0
+pyyaml==6.0.1
+setuptools==70.0.0
+torch==2.2.2+cu121
+torchaudio==2.2.2
+torchvision==0.17.2+cu121
+transformers==4.39.2
+xformers==0.0.25.post1
+isort==5.13.2
+pre-commit==3.7.1
+scipy==1.13.1
+imageio==2.34.2
+lmdb==1.4.1
+rich==13.7.1
+ffmpeg-python==0.2.0
+scikit-image==0.24.0
+albumentations==1.4.10
+matplotlib==3.9.0
+imageio-ffmpeg==0.5.1
+tyro==0.8.5
+gradio==5.1.0
+pykalman==0.9.7
+tensorboardX==2.6.2.2
src/datasets/mean.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db742e76a39bbf81fb5b09fcc488bad0cbab9355df509d8e91967b58d02c6dfc
+size 2582
src/datasets/preprocess/__pycache__/flow_filter.cpython-310.pyc
ADDED
Binary file (7.38 kB)
src/datasets/preprocess/__pycache__/video_crop.cpython-310.pyc
ADDED
Binary file (7.24 kB)
src/datasets/preprocess/__pycache__/visualize.cpython-310.pyc
ADDED
Binary file (867 Bytes)
src/datasets/preprocess/extract_features/__pycache__/audio_processer.cpython-310.pyc
ADDED
Binary file (13.2 kB)
src/datasets/preprocess/extract_features/__pycache__/audio_processer.cpython-312.pyc
ADDED
Binary file (25.6 kB)
src/datasets/preprocess/extract_features/__pycache__/feature_extractor_pipeline.cpython-310.pyc
ADDED
Binary file (20.1 kB)
src/datasets/preprocess/extract_features/__pycache__/motion_processer.cpython-310.pyc
ADDED
Binary file (36.8 kB)
src/datasets/preprocess/extract_features/__pycache__/test_processer.cpython-310.pyc
ADDED
Binary file (6.32 kB)
src/datasets/preprocess/extract_features/audio_processer.py
ADDED
@@ -0,0 +1,471 @@
+
+import os
+from posixpath import isfile
+from re import A
+import sys
+import os.path as osp
+
+from typing import List, Dict, Tuple, Optional, Union, Any
+
+import yaml
+from omegaconf import OmegaConf
+
+import math
+import librosa
+import soundfile
+import numpy as np
+
+from einops import rearrange
+
+import torch
+import torch.nn.functional as F
+
+from pydub import AudioSegment
+from audio_separator.separator import Separator
+
+from transformers import Wav2Vec2FeatureExtractor, HubertModel
+
+from src.utils.rprint import rlog as log
+from src.utils.util import resample_audio
+
+from src.models.audio.wav2vec_modified import Wav2VecModel
+from src.models.audio.hubert import HubertModel_ as HubertModel
+
+
+def pad_audio(audio, audio_unit=320, pad_threshold=80):
+    batch_size, audio_len = audio.shape
+    n_units = audio_len // audio_unit
+    side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2)
+    if side_len >= 0:
+        reflect_len = side_len // 2
+        replicate_len = side_len % 2
+        if reflect_len > 0:
+            audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
+            audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
+        if replicate_len > 0:
+            audio = F.pad(audio, (1, 1), mode='replicate')
+
+    return audio
+
+
+def cut_audio(audio_path: str, save_dir: str, length=60) -> List[str]:
+    """Cut audio into sub-divisions and return subfile paths. Supports wav format.
+
+    Args:
+        audio_path (str): the source audio file path
+        save_dir (str): the save directory of sub-divisions
+        length (int, optional): The max length of each sub-division. Defaults to 60 secs.
+
+    Returns:
+        List[str]: the subfile paths
+    """
+    audio_name = osp.basename(audio_path).split('.')[0]
+    audio = AudioSegment.from_wav(audio_path)
+    segment_length = length * 1000. # pydub uses milliseconds
+    num_segments = math.ceil(len(audio) / segment_length)
+
+    os.makedirs(save_dir, exist_ok=True)
+    audio_list = []
+
+    if num_segments > 1:
+        for i in range(num_segments):
+            start_time = i * segment_length
+            end_time = min((i + 1) * segment_length, len(audio))
+            segment = audio[start_time:end_time]
+
+            path = osp.join(save_dir, f"{audio_name}_segment_{i+1}.wav")
+            audio_list.append(path)
+            segment.export(path, format="wav")
+    else:
+        audio_list = [audio_path]
+    return audio_list
+
+
+class AudioProcessor(object):
+    def __init__(self, cfg_path: str, is_training: bool = False, device_id=0) -> None:
+        cfg = OmegaConf.load(cfg_path)
+        self.cfg = cfg
+        self.is_training = is_training
+        log("========================================= Audio Processer =========================================")
+        log(OmegaConf.to_yaml(cfg))
+
+        # setting device
+        self.device_id = device_id
+        self.use_half = cfg.device_params.flag_use_half_precision
+        if cfg.device_params.flag_force_cpu:
+            self.device = 'cpu'
+        else:
+            try:
+                if torch.backends.mps.is_available():
+                    self.device = 'mps'
+                else:
+                    self.device = 'cuda:' + str(self.device_id)
+            except:
+                self.device = 'cuda:' + str(self.device_id)
+
+        # init audio separator
+        self.audio_separator = None
+        self.cache_dir = cfg.cache_dir
+        self.tmp_dir = cfg.tmp_dir
+        self.use_audio_separator = cfg.model_params.use_audio_separator
+        self.audio_separator_name = cfg.model_params.audio_separator_name
+        self.audio_separator_path = cfg.model_weights.audio_separator_path
+        self.set_audio_separator(cfg.cache_dir)
+
+        # load audio encoder, wav2vec or hubert
+        self.model_name = cfg.model_params.model_name
+        self.is_chinese = cfg.model_params.is_chinese
+        self.audio_encoder, self.feature_extractor = self.load_model(
+            model_name = cfg.model_params.model_name,
+            model_type = cfg.model_params.model_type,
+            is_chinese = cfg.model_params.is_chinese,
+        )
+        self.only_last_features = cfg.model_params.only_last_features
+        if cfg.model_params.only_last_features:
+            self.feature_shape = (1, 768)
+        else:
+            self.feature_shape = (12, 768) # features of 12 blocks
+
+        # init data params
+        self.sample_strategy = cfg.data_params.sample_strategy
+        self.sample_rate = cfg.data_params.sample_rate
+        self.fps = cfg.data_params.fps
+        self.audio_unit = cfg.data_params.sample_rate / cfg.data_params.fps # num of audio samples per frame
+        self.max_length = cfg.data_params.max_length
+        self.subclip_len = cfg.data_params.sub_clip_length
+        self.save_to_cpu = cfg.data_params.save_to_cpu
+        self.pad_mode = cfg.data_params.audio_pad_mode
+
+        log("========================================= Audio Processer: Done =========================================")
+
+    def load_model(self, model_name: str="wav2vec", model_type: str="base", is_chinese: bool = False):
+        assert model_name in ["wav2vec", "hubert"], f"Unknown audio model {model_name}, only support wav2vec or hubert"
+        assert model_type in ["base", "large"], f"Unknown audio model type {model_type}, only support base or large"
+
+        if model_name == "wav2vec":
+            # load wav2vec model weights
+            if is_chinese:
+                if model_type == "base":
+                    model_weight_path = self.cfg.model_weights.wav2vec_path.chinese.base
+                else:
+                    model_weight_path = self.cfg.model_weights.wav2vec_path.chinese.large
+            else:
+                if model_type == "base":
+                    model_weight_path = self.cfg.model_weights.wav2vec_path.default.base
+                else:
+                    model_weight_path = self.cfg.model_weights.wav2vec_path.default.large
+            if model_weight_path is None:
+                raise ValueError(f"model_weight_path is None")
+            audio_encoder = Wav2VecModel.from_pretrained(model_weight_path, local_files_only=True).to(device=self.device)
+        else:
+            if is_chinese:
+                if model_type == "base":
+                    model_weight_path = self.cfg.model_weights.hubert_path.chinese.base
+                else:
+                    model_weight_path = self.cfg.model_weights.hubert_path.chinese.large
+            else:
+                if model_type == "base":
+                    model_weight_path = self.cfg.model_weights.hubert_path.default.base
+                else:
+                    model_weight_path = self.cfg.model_weights.hubert_path.default.large
+            if model_weight_path is None:
+                raise ValueError(f"model_weight_path is None")
+            audio_encoder = HubertModel.from_pretrained(model_weight_path, local_files_only=True).to(device=self.device)
+
+        log(f"{model_name}-{model_type}-chinese-{is_chinese} model has beed loaded from {model_weight_path}")
+        total_params = sum(p.numel() for p in audio_encoder.parameters())
+        print('Number of parameter: % .4fM' % (total_params / 1e6))
+
+        # weights initialization
+        audio_encoder.feature_extractor._freeze_parameters()
+        if not self.cfg.model_params.is_original:
+            frozen_layers = [0, 1]
+            for name, param in audio_encoder.named_parameters():
+                if name.startswith("feature_projection"):
+                    param.requires_grad = False
+                if name.startswith("encoder.layers"):
+                    layer = int(name.split(".")[2])
+                    if layer in frozen_layers:
+                        param.requires_grad = False
+
+        audio_encoder = audio_encoder.to(self.device)
+        if self.use_half:
+            audio_encoder = audio_encoder.half()
+        audio_encoder.eval()
+
+        # feature extractor
+        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_weight_path)
+
+        return audio_encoder, feature_extractor
+
+    def set_audio_separator(self, output_dir: str) -> None:
+        del self.audio_separator
+
+        if self.audio_separator_name is not None and self.use_audio_separator:
+            try:
+                os.makedirs(output_dir, exist_ok=True)
+            except OSError as _:
+                print("Fail to create the output cache dir.")
+            self.audio_separator = Separator(
+                output_dir=output_dir,
+                output_single_stem="vocals",
+                model_file_dir=self.audio_separator_path,
+            )
+            self.audio_separator.load_model(self.audio_separator_name)
+            assert self.audio_separator.model_instance is not None, "Fail to load audio separate model."
+        else:
+            self.audio_separator=None
+            log("Use audio directly without vocals seperator.")
+
+    def seperate_audio(self, audio_path: str, output_dir: Union[str, None] = None) -> str:
+        if output_dir is not None:
+            if output_dir != self.cache_dir:
+                # reload audio separator
+                self.set_audio_separator(output_dir)
+
+        if self.audio_separator is not None:
+            # 1. separate vocals
+            # TODO: process in memory
+            try:
+                outputs = self.audio_separator.separate(audio_path)
+                if len(outputs) <= 0:
+                    raise RuntimeError("Audio separate failed.")
+
+                vocal_audio_file = outputs[0]
+                vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
+                vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
+                vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate)
+            except Exception as e:
+                log(f"Fail to separate vocals from {audio_path}, error info [{e}]")
+                vocal_audio_file=audio_path
+        else:
+            vocal_audio_file=audio_path
+
+        return vocal_audio_file
+
+    def load_audio(self, audio_path: str, mono: bool = True, duration: Optional[float] = None) -> Any:
+        try:
+            audio_data, sampling_rate = librosa.load(audio_path, sr=self.sample_rate, mono=mono, duration=duration)
+        except Exception as e:
+            raise RuntimeError(f"Fail to load audio from {audio_path}, error info [{e}]")
+        return audio_data, sampling_rate
+
+    def prepare_audio_data(self, audio_data: Union[np.ndarray, torch.Tensor], n_frames: Optional[int]=None) -> Tuple[List[Any], int]:
+        """Prepare audio data for processing.
+        """
+        #print(f"==========> Using Wav2Vec2FeatureExtractor to extract audio features")
+        audio_data = np.squeeze(self.feature_extractor(audio_data, sampling_rate=self.sample_rate).input_values)
+
+        clip_len = int(len(audio_data) / self.audio_unit)
+        if n_frames is not None:
+            if abs(n_frames - clip_len) > 7:
+                log(f"The number of frames must be close to the clip length (in 280ms), got {n_frames} and {clip_len}")
+                return [], n_frames
+            clip_len = n_frames
+        else:
+            n_frames = clip_len
+
+        if isinstance(audio_data, np.ndarray):
+            audio_data = torch.from_numpy(audio_data).float().to(self.device)
+        assert audio_data.ndim == 1, 'Audio must be 1D tensor.'
+
+        # padding
+        # padding audio to fit the clip length
+        n_audio_samples = round(self.audio_unit * clip_len)
+        n_padding_audio_samples = n_audio_samples - len(audio_data)
+        n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
+        if n_padding_audio_samples > 0:
+            if self.pad_mode == 'zero':
+                padding_value = 0
+            elif self.pad_mode == 'replicate':
+                padding_value = float(audio_data[-1])
+            else:
+                raise ValueError(f'Unknown pad mode: {self.pad_mode}')
+            audio_data = F.pad(audio_data, (0, n_padding_audio_samples), value=padding_value)
+
+        # devide audio into sub-divisions for saving GPU memory
+        audio_segments = []
+        if clip_len <= self.subclip_len:
+            n_subdivision = 1
+            subclip_len = clip_len
+        else:
+            n_subdivision = math.ceil(clip_len / self.subclip_len)
+            subclip_len = self.subclip_len
+
+        for i in range(0, n_subdivision):
+            start_idx = i * subclip_len
+            end_idx = min(start_idx + subclip_len, clip_len)
+            # debug
+            #log(f"[{i+1}/{n_subdivision}] data index [{round(start_idx * self.audio_unit)}, {round(end_idx * self.audio_unit)})")
+            audio_segments.append(
+                {
+                    "data": audio_data[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0),
+                    "start_idx": start_idx,
+                    "end_idx": end_idx,
+                    "length": end_idx - start_idx
+                }
+            )
+        return audio_segments, n_frames
+
+    def get_audio_embedding(self, audio, clip_len: int) -> torch.Tensor:
+        if audio.ndim == 2:
+            # Extract audio features
+            assert audio.shape[1] == 16000 * clip_len / self.fps, \
+                f'Incorrect audio length {audio.shape[1]}'
+
+            # Extract audio features
+            if self.use_half:
+                audio = audio.half()
+            embeddings = self.audio_encoder(
+                pad_audio(audio), seq_len=clip_len, sample_strategy=self.sample_strategy, output_hidden_states=True
+            ) # (N, L, 768)
+            assert len(embeddings) > 0, "Fail to extract audio embedding"
+
+            if self.only_last_features:
+                audio_emb = embeddings.last_hidden_state.squeeze(0)
+            else:
+                audio_emb = torch.stack(
+                    embeddings.hidden_states[1:], dim=1
+                ).squeeze(0)
+            audio_emb = rearrange(audio_emb, "b s d -> s b d")
+
+        elif audio.ndim == 3:
+            assert audio.shape[1] == clip_len, f'Incorrect audio feature length {audio.shape[1]}'
+            audio_emb = audio
+        else:
+            raise ValueError(f'Incorrect audio input shape {audio.shape}')
+
+        return audio_emb
+
+    def get_audio_embeddings(self, audio_segments: List[Any]) -> Optional[torch.Tensor]:
+        audio_embs = []
+        for audio_segment in audio_segments:
+            if self.is_training:
+                audio_emb = self.get_audio_embedding(audio_segment["data"], audio_segment["length"])
+            else:
+                with torch.no_grad():
+                    audio_emb = self.get_audio_embedding(audio_segment["data"], audio_segment["length"])
+
+            audio_emb = audio_emb.cpu() if self.save_to_cpu else audio_emb
+            audio_embs.append(audio_emb)
+            #log(f"audio segment [{audio_segment['start_idx']}, {audio_segment['end_idx']}) has been processed.")
+
+        if len(audio_embs) == 0:
+            return None
+
+        audio_emb = torch.cat(audio_embs, dim=0)
+
+        return audio_emb
+
+    def preprocess(
+        self,
+        audio_path: str,
+        n_frames: Optional[int] = None,
+        duration: Optional[float] = None,
+        need_seperate: bool = False
+    ):
+        """ Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
+        The separated vocal track is then converted into wav2vec2 for further processing or analysis.
+        """
+        if need_seperate:
+            vocal_audio_file = self.seperate_audio(audio_path)
+        else:
+            vocal_audio_file = audio_path
+
+        audio_data, sampling_rate = self.load_audio(vocal_audio_file, duration=duration)
+
+        assert sampling_rate == 16000, "The sample rate of audio must be 16000"
+        audio_segments, n_frames = self.prepare_audio_data(audio_data, n_frames)
+        audio_emb = self.get_audio_embeddings(audio_segments)
+        if audio_emb is None:
+            log(f"{audio_path} has been processed, but no audio embedding, set as 'None'.")
+        #else:
+            #log(f"{audio_path} has been processed, audio embedding shape {audio_emb.shape}.")
+        return audio_emb, n_frames
+
+    def preprocess_long(
+        self,
+        audio_path: str,
+        need_seperate: bool = False
+    ):
+        audio_list = cut_audio(audio_path, self.tmp_dir, length=self.max_length)
+        audio_emb_list = []
+        l = 0
+
+        for idx, audio_path in enumerate(audio_list):
+            padding = (idx+1) == len(audio_list)
+            emb, length = self.preprocess(audio_path, need_seperate=need_seperate)
+            audio_emb_list.append(emb)
+            log(f"Processing audio {idx+1}/{len(audio_list)}, path: {audio_path} length: {length}")
+            l += length
+
+        audio_emb = torch.cat(audio_emb_list)
+        audio_length = l
+
+        # remove tmp file
+        if len(audio_list) > 1:
+            for audio_path in audio_list:
+                os.remove(audio_path)
+
+        return audio_emb, audio_length
+
+    def add_silent_audio(self, audio_path: str, silent_audio_path: Optional[str] = None, add_duration: float = 1., linear_fusion=False, mode="post"):
+        # mode, pre, post, both
+        assert mode in ["pre", "post", "both"], f"Unkown mode: {mode}, only support pre, post, both"
+        if silent_audio_path is None:
+            return audio_path, 0
+        else:
+            audio_dir = osp.dirname(audio_path)
+            audio_name = osp.basename(audio_path)
+            temp_audio_path = osp.join(audio_dir, f"tmp_{audio_name}")
+            if osp.isfile(temp_audio_path):
+                os.remove(temp_audio_path)
+
+            audio, sr1 = librosa.load(audio_path, mono=True, sr=16000)
+            # denoise
+            audio = librosa.effects.preemphasis(audio) # enhance voice
+            # load silent audio
+            silent_audio, sr2 = librosa.load(silent_audio_path, mono=True, sr=16000)
+            silent_audio = silent_audio[:int(add_duration*sr2)]
+
+            if linear_fusion:
+                short_len = min(len(audio), len(silent_audio))
+                fusion_ratio = np.linspace(0, 1.0, num=short_len)
+                # get pre padding audio
+                pre_pad_audio = fusion_ratio * silent_audio[:short_len] + (1 - fusion_ratio) * audio[:short_len]
+                if short_len < len(silent_audio):
+                    pre_pad_audio = np.hstack((pre_pad_audio, silent_audio[short_len:]))
+                pre_pad_audio = np.flip(pre_pad_audio, axis=0)
+
+                # get post padding audio
+                post_pad_audio = (1 - fusion_ratio) * silent_audio[-short_len:] + fusion_ratio * audio[-short_len:]
+                if short_len < len(silent_audio):
+                    post_pad_audio = np.hstack((silent_audio[:-short_len], post_pad_audio))
+                post_pad_audio = np.flip(post_pad_audio, axis=0)
+            else:
+                pre_pad_audio = silent_audio
+                post_pad_audio = silent_audio
+
+            # padding audio
+            if mode == "both":
+                combined_audio = np.hstack((pre_pad_audio, audio, post_pad_audio))
+            elif mode == "pre":
+                combined_audio = np.hstack((pre_pad_audio, audio))
+            else:
+                combined_audio = np.hstack((audio, post_pad_audio))
+
+            add_nframes = math.floor(add_duration * sr2 / self.audio_unit)
+            #print(f"audio length: {len(audio)}, pre_pad_audio length: {len(pre_pad_audio)}, post_pad_audio length: {len(post_pad_audio)}, combined_length: {len(combined_audio)}, total add {add_nframes*2} frames")
+            #print(f"audio duration: {librosa.get_duration(audio, sr=sr1)}, silent duration: {librosa.get_duration(silent_audio, sr=sr2)}, combined duration: {librosa.get_duration(combined_audio, sr=sr2)}")
+            soundfile.write(temp_audio_path, combined_audio, sr2)
+
+            return temp_audio_path, add_nframes
+
+    def get_long_audio_emb(self, audio_path: str) -> torch.Tensor:
+        audio_emb, length = self.preprocess_long(audio_path)
+        log(f"Load audio from {osp.realpath(audio_path)} done, audio_emb shape: {audio_emb.shape}.")
+        return audio_emb
+
+    def __enter__(self):
+        return self
+
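For orientation, a minimal usage sketch of the class above (not part of the upload). It assumes the config shipped at configs/audio2motion/model/audio_processer_config.yaml and one of the example WAVs from src/examples; the printed shape reflects the (12, 768) per-frame feature layout noted in __init__ when only_last_features is False.

# Usage sketch only -- paths are the config and example audio included in this upload.
from src.datasets.preprocess.extract_features.audio_processer import AudioProcessor

processor = AudioProcessor(
    "configs/audio2motion/model/audio_processer_config.yaml",
    is_training=False,
    device_id=0,
)
# Long audio is cut into <= 60 s chunks (max_length) and embedded chunk by chunk.
audio_emb, n_frames = processor.preprocess_long("src/examples/driving_audios/5.wav")
print(audio_emb.shape, n_frames)   # expected roughly (n_frames, 12, 768)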
src/datasets/preprocess/extract_features/face_segmentation/__init__.py
ADDED
@@ -0,0 +1,88 @@
+import cv2
+import numpy as np
+
+import torch
+from torchvision import transforms
+
+from .bisenet import BiSeNet
+
+
+def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='parsing_map_on_im2.jpg'):
+    # Colors for all 20 parts
+    part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
+                   [255, 0, 85], [255, 0, 170],
+                   [0, 255, 0], [85, 255, 0], [170, 255, 0],
+                   [0, 255, 85], [0, 255, 170],
+                   [0, 0, 255], [85, 0, 255], [170, 0, 255],
+                   [0, 85, 255], [0, 170, 255],
+                   [255, 255, 0], [255, 255, 85], [255, 255, 170],
+                   [255, 0, 255], [255, 85, 255], [255, 170, 255],
+                   [0, 255, 255], [85, 255, 255], [170, 255, 255]]
+
+    im = np.array(im)
+    vis_im = im.copy().astype(np.uint8)
+    vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
+    vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
+    vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
+
+    num_of_class = np.max(vis_parsing_anno)
+
+    for pi in range(1, num_of_class + 1):
+        index = np.where(vis_parsing_anno == pi)
+        vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
+
+    vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
+    # print(vis_parsing_anno_color.shape, vis_im.shape)
+    vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
+
+    # Save result or not
+    if save_im:
+        cv2.imwrite(save_path[:-4] +'.png', vis_parsing_anno)
+        cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
+
+    # return vis_im
+
+def get_face_mask(face_parser, images, batch_size=128):
+    # images: Bx3xHxW
+    kernel = np.ones((13, 13), np.float32)
+    face_masks = []
+    for i in range(0, images.shape[0], batch_size):
+        images_batch = images[i:i+batch_size]
+        with torch.no_grad():
+            out = face_parser(images_batch)[0]
+            parsing = out.cpu().numpy().argmax(1)
+        masks = np.zeros_like(parsing, np.float32)
+        for idx in range(1, 14):
+            masks[parsing == idx] = 1
+
+        for mask in masks:
+            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=2)
+            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)
+            mask = cv2.dilate(mask, kernel, iterations=3)
+            face_masks.append(mask)
+
+    return face_masks
+
+
+def build_face_parser(weight_path, resnet_weight_path, n_classes=19, device_id=0):
+    model_state_dict = torch.load(weight_path, weights_only=False)
+    bisenet = BiSeNet(n_classes, resnet_weight_path=resnet_weight_path)
+    # load model
+    #bisenet.load_state_dict(model_state_dict, strict=True)
+    bisenet_state_dict = bisenet.state_dict()
+    for k, v in model_state_dict.items():
+        if 'fc' in k: continue
+        bisenet_state_dict.update({k: v})
+    bisenet.load_state_dict(bisenet_state_dict)
+    bisenet.to(f"cuda:{device_id}")
+
+    to_tensor = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Resize((512, 512)),
+        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+    ])
+
+    return bisenet.eval(), to_tensor
+
+
+
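A usage sketch for the two helpers above (not part of the upload): the weight paths are the ones referenced by liveportrait_config.yaml, and the example image comes from src/examples; the exact preprocessing pipeline around these helpers lives elsewhere in the repo.

# Usage sketch only -- assumes the face-parsing weights referenced in liveportrait_config.yaml.
import cv2
from src.datasets.preprocess.extract_features.face_segmentation import build_face_parser, get_face_mask

parser, to_tensor = build_face_parser(
    "pretrain_weights/face/face-parsing/79999_iter.pth",
    "pretrain_weights/face/face-parsing/resnet18-5c106cde.pth",
    n_classes=19,
    device_id=0,
)
img = cv2.cvtColor(cv2.imread("src/examples/reference_images/1.jpg"), cv2.COLOR_BGR2RGB)
batch = to_tensor(img).unsqueeze(0).to("cuda:0")   # 1x3x512x512 after resize/normalize
masks = get_face_mask(parser, batch)               # list of 512x512 float masks over face classes 1-13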
src/datasets/preprocess/extract_features/face_segmentation/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (3.07 kB)
src/datasets/preprocess/extract_features/face_segmentation/__pycache__/bisenet.cpython-310.pyc
ADDED
Binary file (8.38 kB)
src/datasets/preprocess/extract_features/face_segmentation/__pycache__/resnet.cpython-310.pyc
ADDED
Binary file (3.77 kB)
src/datasets/preprocess/extract_features/face_segmentation/bisenet.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchvision
|
9 |
+
|
10 |
+
from .resnet import Resnet18
|
11 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
12 |
+
|
13 |
+
|
14 |
+
class ConvBNReLU(nn.Module):
|
15 |
+
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
|
16 |
+
super(ConvBNReLU, self).__init__()
|
17 |
+
self.conv = nn.Conv2d(in_chan,
|
18 |
+
out_chan,
|
19 |
+
kernel_size = ks,
|
20 |
+
stride = stride,
|
21 |
+
padding = padding,
|
22 |
+
bias = False)
|
23 |
+
self.bn = nn.BatchNorm2d(out_chan)
|
24 |
+
self.init_weight()
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = self.conv(x)
|
28 |
+
x = F.relu(self.bn(x))
|
29 |
+
return x
|
30 |
+
|
31 |
+
def init_weight(self):
|
32 |
+
for ly in self.children():
|
33 |
+
if isinstance(ly, nn.Conv2d):
|
34 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
35 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
36 |
+
|
37 |
+
class BiSeNetOutput(nn.Module):
|
38 |
+
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
|
39 |
+
super(BiSeNetOutput, self).__init__()
|
40 |
+
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
|
41 |
+
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
|
42 |
+
self.init_weight()
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
x = self.conv(x)
|
46 |
+
x = self.conv_out(x)
|
47 |
+
return x
|
48 |
+
|
49 |
+
def init_weight(self):
|
50 |
+
for ly in self.children():
|
51 |
+
if isinstance(ly, nn.Conv2d):
|
52 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
53 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
54 |
+
|
55 |
+
def get_params(self):
|
56 |
+
wd_params, nowd_params = [], []
|
57 |
+
for name, module in self.named_modules():
|
58 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
59 |
+
wd_params.append(module.weight)
|
60 |
+
if not module.bias is None:
|
61 |
+
nowd_params.append(module.bias)
|
62 |
+
elif isinstance(module, nn.BatchNorm2d):
|
63 |
+
nowd_params += list(module.parameters())
|
64 |
+
return wd_params, nowd_params
|
65 |
+
|
66 |
+
|
67 |
+
class AttentionRefinementModule(nn.Module):
|
68 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
69 |
+
super(AttentionRefinementModule, self).__init__()
|
70 |
+
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
|
71 |
+
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
|
72 |
+
self.bn_atten = nn.BatchNorm2d(out_chan)
|
73 |
+
self.sigmoid_atten = nn.Sigmoid()
|
74 |
+
self.init_weight()
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
feat = self.conv(x)
|
78 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
79 |
+
atten = self.conv_atten(atten)
|
80 |
+
atten = self.bn_atten(atten)
|
81 |
+
atten = self.sigmoid_atten(atten)
|
82 |
+
out = torch.mul(feat, atten)
|
83 |
+
return out
|
84 |
+
|
85 |
+
def init_weight(self):
|
86 |
+
for ly in self.children():
|
87 |
+
if isinstance(ly, nn.Conv2d):
|
88 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
89 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
90 |
+
|
91 |
+
|
92 |
+
class ContextPath(nn.Module):
|
93 |
+
def __init__(self, *args, **kwargs):
|
94 |
+
super(ContextPath, self).__init__()
|
95 |
+
backbone_weight_path = kwargs.get("resnet_weight_path", None)
|
96 |
+
self.resnet = Resnet18(backbone_weight_path)
|
97 |
+
self.arm16 = AttentionRefinementModule(256, 128)
|
98 |
+
self.arm32 = AttentionRefinementModule(512, 128)
|
99 |
+
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
100 |
+
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
101 |
+
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
|
102 |
+
|
103 |
+
self.init_weight()
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
H0, W0 = x.size()[2:]
|
107 |
+
feat8, feat16, feat32 = self.resnet(x)
|
108 |
+
H8, W8 = feat8.size()[2:]
|
109 |
+
H16, W16 = feat16.size()[2:]
|
110 |
+
H32, W32 = feat32.size()[2:]
|
111 |
+
|
112 |
+
avg = F.avg_pool2d(feat32, feat32.size()[2:])
|
113 |
+
avg = self.conv_avg(avg)
|
114 |
+
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
|
115 |
+
|
116 |
+
feat32_arm = self.arm32(feat32)
|
117 |
+
feat32_sum = feat32_arm + avg_up
|
118 |
+
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
|
119 |
+
feat32_up = self.conv_head32(feat32_up)
|
120 |
+
|
121 |
+
feat16_arm = self.arm16(feat16)
|
122 |
+
feat16_sum = feat16_arm + feat32_up
|
123 |
+
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
|
124 |
+
feat16_up = self.conv_head16(feat16_up)
|
125 |
+
|
126 |
+
return feat8, feat16_up, feat32_up # x8, x8, x16
|
127 |
+
|
128 |
+
def init_weight(self):
|
129 |
+
for ly in self.children():
|
130 |
+
if isinstance(ly, nn.Conv2d):
|
131 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
132 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
133 |
+
|
134 |
+
def get_params(self):
|
135 |
+
wd_params, nowd_params = [], []
|
136 |
+
for name, module in self.named_modules():
|
137 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
138 |
+
wd_params.append(module.weight)
|
139 |
+
if not module.bias is None:
|
140 |
+
nowd_params.append(module.bias)
|
141 |
+
elif isinstance(module, nn.BatchNorm2d):
|
142 |
+
nowd_params += list(module.parameters())
|
143 |
+
return wd_params, nowd_params
|
144 |
+
|
145 |
+
|
146 |
+
### This is not used, since I replace this with the resnet feature with the same size
|
147 |
+
class SpatialPath(nn.Module):
|
148 |
+
def __init__(self, *args, **kwargs):
|
149 |
+
super(SpatialPath, self).__init__()
|
150 |
+
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
|
151 |
+
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
152 |
+
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
153 |
+
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
|
154 |
+
self.init_weight()
|
155 |
+
|
156 |
+
def forward(self, x):
|
157 |
+
feat = self.conv1(x)
|
158 |
+
feat = self.conv2(feat)
|
159 |
+
feat = self.conv3(feat)
|
160 |
+
feat = self.conv_out(feat)
|
161 |
+
return feat
|
162 |
+
|
163 |
+
def init_weight(self):
|
164 |
+
for ly in self.children():
|
165 |
+
if isinstance(ly, nn.Conv2d):
|
166 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
167 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
168 |
+
|
169 |
+
def get_params(self):
|
170 |
+
wd_params, nowd_params = [], []
|
171 |
+
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class BiSeNet(nn.Module):
    def __init__(self, n_classes, *args, **kwargs):
        super(BiSeNet, self).__init__()
        backbone_weight_path = kwargs.get("resnet_weight_path", None)
        self.cp = ContextPath(resnet_weight_path=backbone_weight_path)
        ## here self.sp is deleted
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
        self.init_weight()

    def forward(self, x):
        H, W = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # here return res3b1 feature
        feat_sp = feat_res8  # use res3b1 feature to replace spatial path feature
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        feat_out16 = self.conv_out16(feat_cp8)
        feat_out32 = self.conv_out32(feat_cp16)

        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
        feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
        feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
        return feat_out, feat_out16, feat_out32

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


if __name__ == "__main__":
    net = BiSeNet(19)
    net.cuda()
    net.eval()
    in_ten = torch.randn(16, 3, 640, 480).cuda()
    out, out16, out32 = net(in_ten)
    print(out.shape)

    net.get_params()
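For reference, a minimal inference sketch for this BiSeNet variant used as a face parser; the checkpoint name, the 512x512 input size, and the ImageNet normalization statistics are illustrative assumptions, not taken from this file:

import cv2
import numpy as np
import torch

from src.datasets.preprocess.extract_features.face_segmentation.bisenet import BiSeNet

net = BiSeNet(19)  # 19 parsing classes is an assumption; pass resnet_weight_path=... to skip the model-zoo download
net.load_state_dict(torch.load("face_parser.pth", map_location="cpu"))  # hypothetical checkpoint file
net.cuda().eval()

img = cv2.cvtColor(cv2.imread("face.jpg"), cv2.COLOR_BGR2RGB)  # hypothetical input image
img = cv2.resize(img, (512, 512)).astype(np.float32) / 255.0
img = (img - np.array([0.485, 0.456, 0.406], np.float32)) / np.array([0.229, 0.224, 0.225], np.float32)
x = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).cuda()

with torch.no_grad():
    out, _, _ = net(x)                            # only the main head is needed at inference
    parsing = out.argmax(dim=1)[0].cpu().numpy()  # HxW label map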
src/datasets/preprocess/extract_features/face_segmentation/resnet.py
ADDED
@@ -0,0 +1,113 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo

# from modules.bn import InPlaceABNSync as BatchNorm2d

resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    def __init__(self, in_chan, out_chan, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = nn.BatchNorm2d(out_chan)
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if in_chan != out_chan or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chan),
            )

    def forward(self, x):
        residual = self.conv1(x)
        residual = F.relu(self.bn1(residual))
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out = shortcut + residual
        out = self.relu(out)
        return out


def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    layers = [BasicBlock(in_chan, out_chan, stride=stride)]
    for _ in range(bnum - 1):
        layers.append(BasicBlock(out_chan, out_chan, stride=1))
    return nn.Sequential(*layers)


class Resnet18(nn.Module):
    def __init__(self, backbone_weight_path=None):
        super(Resnet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
        self.init_weight(backbone_weight_path)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        feat8 = self.layer2(x)        # 1/8
        feat16 = self.layer3(feat8)   # 1/16
        feat32 = self.layer4(feat16)  # 1/32
        return feat8, feat16, feat32

    def init_weight(self, backbone_weight_path=None):
        if backbone_weight_path is None:
            state_dict = modelzoo.load_url(resnet18_url)
        else:
            state_dict = torch.load(backbone_weight_path, weights_only=False)
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'fc' in k:
                continue
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


if __name__ == "__main__":
    net = Resnet18()
    x = torch.randn(16, 3, 224, 224)
    out = net(x)
    print(out[0].size())
    print(out[1].size())
    print(out[2].size())
    net.get_params()
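The wd_params / nowd_params split returned by get_params() is meant to feed two optimizer parameter groups so that biases and BatchNorm parameters are excluded from weight decay; a minimal sketch, with placeholder learning-rate and decay values:

import torch

from src.datasets.preprocess.extract_features.face_segmentation.resnet import Resnet18

net = Resnet18()
wd_params, nowd_params = net.get_params()
optimizer = torch.optim.SGD(
    [
        {"params": wd_params, "weight_decay": 5e-4},   # conv/linear weights: decayed
        {"params": nowd_params, "weight_decay": 0.0},  # biases and BatchNorm: no decay
    ],
    lr=1e-2,
    momentum=0.9,
)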
src/datasets/preprocess/extract_features/motion_processer.py
ADDED
@@ -0,0 +1,1420 @@
"""
Motion feature extractor
"""
import os
import os.path as osp
import sys
import pickle
from omegaconf import OmegaConf

import torch

from PIL import Image
import numpy as np
import cv2
import imageio
import time
from decord import VideoReader  # must after import torch

from rich.progress import track


sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))))))
from src.datasets.preprocess.extract_features.face_segmentation import build_face_parser, get_face_mask, vis_parsing_maps
from src.thirdparty.liveportrait.src.utils.helper import load_model, concat_feat
from src.thirdparty.liveportrait.src.utils.io import load_image_rgb, resize_to_limit, load_video
from src.thirdparty.liveportrait.src.utils.video import get_fps, images2video, add_audio_to_video
from src.thirdparty.liveportrait.src.utils.camera import headpose_pred_to_degree, get_rotation_matrix

from src.thirdparty.liveportrait.src.utils.cropper import Cropper
from src.thirdparty.liveportrait.src.utils.crop import prepare_paste_back, paste_back, paste_back_with_face_mask
from src.thirdparty.liveportrait.src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.thirdparty.liveportrait.src.utils.helper import mkdir, basename, dct2device, is_image, calc_motion_multiplier
from src.utils.filter import smooth as ksmooth
from src.utils.filter import smooth_

from skimage.metrics import peak_signal_noise_ratio
import warnings


def psnr(imgs1, imgs2):
    psnrs = []
    for img1, img2 in zip(imgs1, imgs2):
        psnr = peak_signal_noise_ratio(img1, img2, data_range=255)
        psnrs.append(psnr)
    return psnrs


def suffix(filename):
    """a.jpg -> jpg"""
    pos = filename.rfind(".")
    if pos == -1:
        return ""
    return filename[pos + 1:]


def dump(wfp, obj):
    wd = osp.split(wfp)[0]
    if wd != "" and not osp.exists(wd):
        mkdir(wd)

    _suffix = suffix(wfp)
    if _suffix == "npy":
        np.save(wfp, obj)
    elif _suffix == "pkl":
        pickle.dump(obj, open(wfp, "wb"))
    else:
        raise Exception("Unknown type: {}".format(_suffix))


def load(fp):
    suffix_ = suffix(fp)

    if suffix_ == "npy":
        return np.load(fp)
    elif suffix_ == "pkl":
        return pickle.load(open(fp, "rb"))
    else:
        raise Exception(f"Unknown type: {suffix_}")


def remove_suffix(filepath):
    """a/b/c.jpg -> a/b/c"""
    return osp.join(osp.dirname(filepath), basename(filepath))


class MotionProcesser(object):
    def __init__(self, cfg_path, device_id=0) -> None:
        device = f"cuda:{device_id}"
        cfg = OmegaConf.load(cfg_path)
        print(f"Load cfg from {osp.realpath(cfg_path)} done.")
        print("=============================== Driven CFG ===============================")
        print(OmegaConf.to_yaml(cfg))
        print("=============================== ========== ===============================")
        models_config = OmegaConf.load(cfg.models_config)

        # 1. init appearance feature extractor
        self.appearance_feature_extractor = load_model(
            cfg.appearance_feature_extractor_path,
            models_config,
            device,
            'appearance_feature_extractor'
        )
        print(f'1. Load appearance_feature_extractor from {osp.realpath(cfg.appearance_feature_extractor_path)} done.')

        # 2. init motion extractor
        self.motion_extractor = load_model(
            cfg.motion_extractor_path,
            models_config,
            device,
            'motion_extractor'
        )
        print(f'2. Load motion_extractor from {osp.realpath(cfg.motion_extractor_path)} done.')

        # 3. init S and R
        if cfg.stitching_retargeting_module_path is not None and osp.exists(cfg.stitching_retargeting_module_path):
            self.stitching_retargeting_module = load_model(
                cfg.stitching_retargeting_module_path,
                models_config,
                device,
                'stitching_retargeting_module'
            )
            print(f'3. Load stitching_retargeting_module from {osp.realpath(cfg.stitching_retargeting_module_path)} done.')
        else:
            self.stitching_retargeting_module = None

        # 4. init motion warper
        self.warping_module = load_model(
            cfg.warping_module_path,
            models_config,
            device,
            'warping_module'
        )
        print(f"4. Load warping_module from {osp.realpath(cfg.warping_module_path)} done.")

        # 5. init decoder
        self.spade_generator = load_model(
            cfg.spade_generator_path,
            models_config,
            device,
            'spade_generator'
        )
        print(f"Load generator from {osp.realpath(cfg.spade_generator_path)} done.")

        # Optimize for inference
        self.compile = cfg.flag_do_torch_compile
        if self.compile:
            torch._dynamo.config.suppress_errors = True  # Suppress errors and fall back to eager execution
            self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
            self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')

        # 6. init cropper
        crop_cfg = OmegaConf.load(cfg.crop_cfg)
        self.cropper = Cropper(crop_cfg=crop_cfg, image_type="human_face", device_id=device_id)

        self.cfg = cfg
        self.models_config = models_config
        self.device = device

        # 7. load crop mask
        self.mask_crop = cv2.imread(cfg.mask_crop, cv2.IMREAD_COLOR)
        # 8. load lip array
        with open(cfg.lip_array, 'rb') as f:
            self.lip_array = pickle.load(f)

        # 9. load face parser
        self.face_parser, self.to_tensor = build_face_parser(weight_path=cfg.face_parser_weight_path, resnet_weight_path=cfg.resnet_weight_path, device_id=device_id)

    def inference_ctx(self):
        ctx = torch.autocast(device_type=self.device[:4], dtype=torch.float16,
                             enabled=self.cfg.flag_use_half_precision)
        return ctx

    @torch.no_grad()
    def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
        """ get the appearance feature of the image by F
        x: Bx3xHxW, normalized to 0~1
        """
        with self.inference_ctx():
            feature_3d = self.appearance_feature_extractor(x)

        return feature_3d.float()

    @torch.no_grad()
    def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
        """ get the implicit keypoint information
        x: Bx3xHxW, normalized to 0~1
        flag_refine_info: whether to transform the pose to degrees and reshape the dimensions
        return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
        """
        with self.inference_ctx():
            kp_info = self.motion_extractor(x)

            if self.cfg.flag_use_half_precision:
                # float the dict
                for k, v in kp_info.items():
                    if isinstance(v, torch.Tensor):
                        kp_info[k] = v.float()

        return kp_info

    @torch.no_grad()
    def refine_kp(self, kp_info):
        bs = kp_info['exp'].shape[0]
        kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None]  # Bx1
        kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None]      # Bx1
        kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None]    # Bx1
        kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3)  # BxNx3
        if 'kp' in kp_info.keys():
            kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3)  # BxNx3

        return kp_info

    @torch.no_grad()
    def transform_keypoint(self, kp_info: dict):
        """
        transform the implicit keypoints with the pose, shift, and expression deformation
        kp: BxNx3
        """
        kp = kp_info['kp']  # (bs, k, 3)
        pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']

        t, exp = kp_info['t'], kp_info['exp']
        scale = kp_info['scale']

        pitch = headpose_pred_to_degree(pitch)
        yaw = headpose_pred_to_degree(yaw)
        roll = headpose_pred_to_degree(roll)

        bs = kp.shape[0]
        if kp.ndim == 2:
            num_kp = kp.shape[1] // 3  # Bx(num_kpx3)
        else:
            num_kp = kp.shape[1]  # Bxnum_kpx3

        rot_mat = get_rotation_matrix(pitch, yaw, roll)  # (bs, 3, 3)

        # Eqn.2: s * (R * x_c,s + exp) + t
        kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
        kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
        kp_transformed[:, :, 0:2] += t[:, None, 0:2]  # remove z, only apply tx ty

        return kp_transformed

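Written out, the Eqn. 2 that transform_keypoint implements is

    x_d = s · (x_c R + δ) + (t_x, t_y, 0)

where x_c are the canonical keypoints ('kp'), R is the rotation built from pitch/yaw/roll, δ are the expression offsets ('exp'), s is the scale, and only the x/y components of the translation are applied (t_z is dropped).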
246 |
+
@torch.no_grad()
|
247 |
+
def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
|
248 |
+
""" conduct the stitching
|
249 |
+
kp_source: Bxnum_kpx3
|
250 |
+
kp_driving: Bxnum_kpx3
|
251 |
+
"""
|
252 |
+
|
253 |
+
if self.stitching_retargeting_module is not None:
|
254 |
+
bs, num_kp = kp_source.shape[:2]
|
255 |
+
kp_driving_new = kp_driving.clone()
|
256 |
+
# stich
|
257 |
+
feat_stiching = concat_feat(kp_source, kp_driving_new)
|
258 |
+
delta = self.stitching_retargeting_module['stitching'](feat_stiching) # Bxnum_kpx3
|
259 |
+
|
260 |
+
delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3) # 1x20x3
|
261 |
+
delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2) # 1x1x2
|
262 |
+
|
263 |
+
kp_driving_new += delta_exp
|
264 |
+
kp_driving_new[..., :2] += delta_tx_ty
|
265 |
+
|
266 |
+
return kp_driving_new
|
267 |
+
|
268 |
+
return kp_driving
|
269 |
+
|
270 |
+
@torch.no_grad()
|
271 |
+
def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> dict[str, torch.Tensor]:
|
272 |
+
""" get the image after the warping of the implicit keypoints
|
273 |
+
feature_3d: Bx32x16x64x64, feature volume
|
274 |
+
kp_source: BxNx3
|
275 |
+
kp_driving: BxNx3
|
276 |
+
"""
|
277 |
+
# The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
|
278 |
+
with self.inference_ctx():
|
279 |
+
if self.compile:
|
280 |
+
# Mark the beginning of a new CUDA Graph step
|
281 |
+
torch.compiler.cudagraph_mark_step_begin()
|
282 |
+
# get decoder input
|
283 |
+
ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
|
284 |
+
|
285 |
+
# print(f"=============================================================================")
|
286 |
+
# for out_key, out_value in ret_dct.items():
|
287 |
+
# if isinstance(out_value, str) or isinstance(out_value, int) or isinstance(out_value, float):
|
288 |
+
# print(f"{out_key}: {out_value}")
|
289 |
+
# elif isinstance(out_value, torch.Tensor):
|
290 |
+
# print(f"{out_key}: tensor shape {out_value.shape}, min: {torch.min(out_value)}, max: {torch.max(out_value)}, mean: {torch.mean(out_value)}, std: {torch.std(out_value)}")
|
291 |
+
# else:
|
292 |
+
# print(f"{out_key}: data type {type(out_value)}")
|
293 |
+
# decode
|
294 |
+
ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
|
295 |
+
|
296 |
+
# float the dict
|
297 |
+
if self.cfg.flag_use_half_precision:
|
298 |
+
for k, v in ret_dct.items():
|
299 |
+
if isinstance(v, torch.Tensor):
|
300 |
+
ret_dct[k] = v.float()
|
301 |
+
|
302 |
+
return ret_dct
|
303 |
+
|
304 |
+
def parse_output(self, out: torch.Tensor) -> np.ndarray:
|
305 |
+
""" construct the output as standard
|
306 |
+
return: 1xHxWx3, uint8
|
307 |
+
"""
|
308 |
+
out = np.transpose(out.cpu().numpy(), [0, 2, 3, 1]) # 1x3xHxW -> 1xHxWx3
|
309 |
+
out = np.clip(out, 0, 1) # clip to 0~1
|
310 |
+
out = np.clip(out * 255, 0, 255).astype(np.uint8) # 0~1 -> 0~255
|
311 |
+
|
312 |
+
return out
|
313 |
+
|
314 |
+
@torch.no_grad()
|
315 |
+
def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
|
316 |
+
c_s_eyes = calc_eye_close_ratio(source_lmk[None])
|
317 |
+
c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device)
|
318 |
+
c_d_eyes_i_tensor = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device)
|
319 |
+
# [c_s,eyes, c_d,eyes,i]
|
320 |
+
combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1)
|
321 |
+
return combined_eye_ratio_tensor
|
322 |
+
|
323 |
+
@torch.no_grad()
|
324 |
+
def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
|
325 |
+
c_s_lip = calc_lip_close_ratio(source_lmk[None])
|
326 |
+
c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device)
|
327 |
+
c_d_lip_i_tensor = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1) # 1x1
|
328 |
+
# [c_s,lip, c_d,lip,i]
|
329 |
+
combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) # 1x2
|
330 |
+
return combined_lip_ratio_tensor
|
331 |
+
|
332 |
+
def calc_ratio(self, lmk_lst):
|
333 |
+
input_eye_ratio_lst = []
|
334 |
+
input_lip_ratio_lst = []
|
335 |
+
for lmk in lmk_lst:
|
336 |
+
# for eyes retargeting
|
337 |
+
input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
|
338 |
+
# for lip retargeting
|
339 |
+
input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
|
340 |
+
return input_eye_ratio_lst, input_lip_ratio_lst
|
341 |
+
|
342 |
+
@torch.no_grad()
|
343 |
+
def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
|
344 |
+
"""
|
345 |
+
kp_source: BxNx3
|
346 |
+
lip_close_ratio: Bx2
|
347 |
+
Return: Bx(3*num_kp)
|
348 |
+
"""
|
349 |
+
feat_lip = concat_feat(kp_source, lip_close_ratio)
|
350 |
+
|
351 |
+
delta = self.stitching_retargeting_module['lip'](feat_lip)
|
352 |
+
|
353 |
+
return delta.reshape(-1, kp_source.shape[1], 3)
|
354 |
+
|
355 |
+
@torch.no_grad()
|
356 |
+
def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor:
|
357 |
+
"""
|
358 |
+
kp_source: BxNx3
|
359 |
+
eye_close_ratio: Bx3
|
360 |
+
Return: Bx(3*num_kp)
|
361 |
+
"""
|
362 |
+
feat_eye = concat_feat(kp_source, eye_close_ratio)
|
363 |
+
|
364 |
+
delta = self.stitching_retargeting_module['eye'](feat_eye)
|
365 |
+
|
366 |
+
return delta.reshape(-1, kp_source.shape[1], 3)
|
367 |
+
|
368 |
+
def crop_image(self, img, do_crop=False):
|
369 |
+
######## process source info ########
|
370 |
+
if do_crop:
|
371 |
+
crop_info = self.cropper.crop_source_image(img, self.cropper.crop_cfg)
|
372 |
+
if crop_info is None:
|
373 |
+
raise Exception("No face detected in the source image!")
|
374 |
+
lmk = crop_info['lmk_crop']
|
375 |
+
img_crop_256x256 = crop_info['img_crop_256x256']
|
376 |
+
else:
|
377 |
+
crop_info = None
|
378 |
+
lmk = self.cropper.calc_lmk_from_cropped_image(img)
|
379 |
+
img_crop_256x256 = cv2.resize(img, (256, 256)) # force to resize to 256x256
|
380 |
+
return img_crop_256x256, lmk, crop_info
|
381 |
+
|
382 |
+
def crop_source_video(self, img_lst, do_crop=False):
|
383 |
+
if do_crop:
|
384 |
+
ret_s = self.cropper.crop_source_video(img_lst, self.cropper.crop_cfg)
|
385 |
+
print(f'Source video is cropped, {len(ret_s["frame_crop_lst"])} frames are processed.')
|
386 |
+
img_crop_256x256_lst, lmk_crop_lst, M_c2o_lst = ret_s['frame_crop_lst'], ret_s['lmk_crop_lst'], ret_s['M_c2o_lst']
|
387 |
+
else:
|
388 |
+
M_c2o_lst = None
|
389 |
+
lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(img_lst)
|
390 |
+
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in img_lst] # force to resize to 256x256
|
391 |
+
return img_crop_256x256_lst, lmk_crop_lst, M_c2o_lst
|
392 |
+
|
393 |
+
def crop_driving_videos(self, img_lst, do_crop=False):
|
394 |
+
if do_crop:
|
395 |
+
ret_d = self.cropper.crop_driving_video(img_lst)
|
396 |
+
print(f'Driving video is cropped, {len(ret_d["frame_crop_lst"])} frames are processed.')
|
397 |
+
img_crop_lst, lmk_crop_lst = ret_d['frame_crop_lst'], ret_d['lmk_crop_lst']
|
398 |
+
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in img_lst]
|
399 |
+
else:
|
400 |
+
lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(img_lst)
|
401 |
+
img_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in img_lst] # force to resize to 256x256
|
402 |
+
return img_crop_256x256_lst, lmk_crop_lst
|
403 |
+
|
404 |
+
def prepare_source(self, src_img):
|
405 |
+
""" construct the input as standard
|
406 |
+
img: HxWx3, uint8, 256x256
|
407 |
+
"""
|
408 |
+
# processing source image to tensor
|
409 |
+
h, w = src_img.shape[:2]
|
410 |
+
if h != self.cfg.input_height or w != self.cfg.input_width:
|
411 |
+
x = cv2.resize(src_img, (self.cfg.input_width, self.cfg.input_height))
|
412 |
+
else:
|
413 |
+
x = src_img.copy()
|
414 |
+
|
415 |
+
if x.ndim == 3:
|
416 |
+
x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1
|
417 |
+
elif x.ndim == 4:
|
418 |
+
x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1
|
419 |
+
else:
|
420 |
+
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
|
421 |
+
|
422 |
+
x = np.clip(x, 0, 1) # clip to 0~1
|
423 |
+
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
|
424 |
+
x = x.to(self.device)
|
425 |
+
|
426 |
+
# extract features
|
427 |
+
I_s = x
|
428 |
+
f_s = self.extract_feature_3d(I_s)
|
429 |
+
x_s_info = self.get_kp_info(I_s)
|
430 |
+
|
431 |
+
return f_s, x_s_info
|
432 |
+
|
433 |
+
def process_clips(self, clips):
|
434 |
+
""" construct the input as standard
|
435 |
+
clips: NxBxHxWx3, uint8
|
436 |
+
"""
|
437 |
+
# resize to 256 x 256
|
438 |
+
imgs = []
|
439 |
+
for img in clips:
|
440 |
+
h, w = img.shape[:2]
|
441 |
+
if h != self.cfg.input_height or w != self.cfg.input_width:
|
442 |
+
img = cv2.resize(img, (self.cfg.input_width, self.cfg.input_height))
|
443 |
+
else:
|
444 |
+
img = img.copy()
|
445 |
+
imgs.append(img)
|
446 |
+
|
447 |
+
# processing video frames to tensor
|
448 |
+
if isinstance(imgs, list):
|
449 |
+
_imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1
|
450 |
+
elif isinstance(imgs, np.ndarray):
|
451 |
+
_imgs = imgs
|
452 |
+
else:
|
453 |
+
raise ValueError(f'imgs type error: {type(imgs)}')
|
454 |
+
|
455 |
+
y = _imgs.astype(np.float32) / 255.
|
456 |
+
y = np.clip(y, 0, 1) # clip to 0~1
|
457 |
+
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
|
458 |
+
y = y.to(self.device)
|
459 |
+
|
460 |
+
return y
|
461 |
+
|
462 |
+
def prepare_driving_videos(self, vid_frames, feat_type="tensor"):
|
463 |
+
""" get driving kp infos
|
464 |
+
vid_frames: image list of HxWx3, uint8
|
465 |
+
"""
|
466 |
+
# extract features
|
467 |
+
total_len = len(vid_frames)
|
468 |
+
kp_infos = {"pitch": [], "yaw": [], "roll": [], "t": [], "exp": [], "scale": [], "kp": []}
|
469 |
+
for start_idx in range(0, total_len, self.cfg.batch_size):
|
470 |
+
frames = vid_frames[start_idx: min(start_idx + self.cfg.batch_size, total_len)]
|
471 |
+
frames = self.process_clips(frames).squeeze(1)
|
472 |
+
kp_info = self.get_kp_info(frames)
|
473 |
+
|
474 |
+
for k, v in kp_info.items():
|
475 |
+
kp_infos[k].append(v)
|
476 |
+
|
477 |
+
# combine the kp_infos
|
478 |
+
for k, v in kp_infos.items():
|
479 |
+
kp_infos[k] = torch.cat(v, dim=0)
|
480 |
+
|
481 |
+
if feat_type == "np":
|
482 |
+
for k, v in kp_infos.items():
|
483 |
+
kp_infos[k] = v.cpu().numpy()
|
484 |
+
|
485 |
+
return kp_infos
|
486 |
+
|
487 |
+
def get_driving_template(self, kp_infos, smooth=False, dtype="pt_tensor"):
|
488 |
+
kp_infos = self.refine_kp(kp_infos)
|
489 |
+
motion_list = []
|
490 |
+
n_frames = len(kp_infos["exp"])
|
491 |
+
for idx in range(n_frames):
|
492 |
+
exp = kp_infos["exp"][idx]
|
493 |
+
scale = kp_infos["scale"][idx]
|
494 |
+
t = kp_infos["t"][idx]
|
495 |
+
pitch = kp_infos["pitch"][idx]
|
496 |
+
yaw = kp_infos["yaw"][idx]
|
497 |
+
roll = kp_infos["roll"][idx]
|
498 |
+
|
499 |
+
R = get_rotation_matrix(pitch, yaw, roll)
|
500 |
+
R = R.reshape(1, 3, 3)
|
501 |
+
|
502 |
+
exp = exp.reshape(1, 21, 3)
|
503 |
+
scale = scale.reshape(1, 1)
|
504 |
+
t = t.reshape(1, 3)
|
505 |
+
pitch = pitch.reshape(1, 1)
|
506 |
+
yaw = yaw.reshape(1, 1)
|
507 |
+
roll = roll.reshape(1, 1)
|
508 |
+
|
509 |
+
if dtype == "np":
|
510 |
+
R = R.cpu().numpy().astype(np.float32)
|
511 |
+
exp = exp.cpu().numpy().astype(np.float32)
|
512 |
+
scale = scale.cpu().numpy().astype(np.float32)
|
513 |
+
t = t.cpu().numpy().astype(np.float32)
|
514 |
+
pitch = pitch.cpu().numpy().astype(np.float32)
|
515 |
+
yaw = yaw.cpu().numpy().astype(np.float32)
|
516 |
+
roll = roll.cpu().numpy().astype(np.float32)
|
517 |
+
|
518 |
+
motion_list.append(
|
519 |
+
{"exp": exp, "scale": scale, "R": R, "t": t, "pitch": pitch, "yaw": yaw, "roll": roll}
|
520 |
+
)
|
521 |
+
tgt_motion = {'n_frames': n_frames, 'output_fps': 25, 'motion': motion_list}
|
522 |
+
|
523 |
+
if smooth:
|
524 |
+
print("Smoothing motion sequence...")
|
525 |
+
tgt_motion = smooth_(tgt_motion, method="ema")
|
526 |
+
return tgt_motion
|
527 |
+
|
528 |
+
@torch.no_grad()
|
529 |
+
def update_delta_new_eyeball_direction(self, eyeball_direction_x, eyeball_direction_y, delta_new, **kwargs):
|
530 |
+
if eyeball_direction_x > 0:
|
531 |
+
delta_new[0, 11, 0] += eyeball_direction_x * 0.0007
|
532 |
+
delta_new[0, 15, 0] += eyeball_direction_x * 0.001
|
533 |
+
else:
|
534 |
+
delta_new[0, 11, 0] += eyeball_direction_x * 0.001
|
535 |
+
delta_new[0, 15, 0] += eyeball_direction_x * 0.0007
|
536 |
+
|
537 |
+
delta_new[0, 11, 1] += eyeball_direction_y * -0.001
|
538 |
+
delta_new[0, 15, 1] += eyeball_direction_y * -0.001
|
539 |
+
blink = -eyeball_direction_y / 2.
|
540 |
+
|
541 |
+
delta_new[0, 11, 1] += blink * -0.001
|
542 |
+
delta_new[0, 13, 1] += blink * 0.0003
|
543 |
+
delta_new[0, 15, 1] += blink * -0.001
|
544 |
+
delta_new[0, 16, 1] += blink * 0.0003
|
545 |
+
|
546 |
+
return delta_new
|
547 |
+
|
548 |
+
def driven(self, f_s, x_s_info, s_lmk, c_s_eyes_lst, kp_infos, c_d_eyes_lst=None, c_d_lip_lst=None, smooth=False):
|
549 |
+
# source kp info
|
550 |
+
x_d_i_news=[]
|
551 |
+
x_ss=[]
|
552 |
+
f_ss=[]
|
553 |
+
x_s_info = self.refine_kp(x_s_info)
|
554 |
+
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
555 |
+
x_s = self.transform_keypoint(x_s_info)
|
556 |
+
x_c_s = x_s_info["kp"]
|
557 |
+
|
558 |
+
# driving kp infos
|
559 |
+
driving_template_dct = self.get_driving_template(kp_infos, smooth)
|
560 |
+
n_frames = driving_template_dct['n_frames']
|
561 |
+
|
562 |
+
# driving params
|
563 |
+
flag_normalize_lip = self.cfg.flag_normalize_lip
|
564 |
+
flag_relative_motion = self.cfg.flag_relative_motion
|
565 |
+
flag_source_video_eye_retargeting = self.cfg.flag_source_video_eye_retargeting
|
566 |
+
lip_normalize_threshold = self.cfg.lip_normalize_threshold
|
567 |
+
source_video_eye_retargeting_threshold = self.cfg.source_video_eye_retargeting_threshold
|
568 |
+
animation_region = self.cfg.animation_region
|
569 |
+
driving_option = self.cfg.driving_option
|
570 |
+
flag_stitching = self.cfg.flag_stitching
|
571 |
+
flag_eye_retargeting = self.cfg.flag_eye_retargeting
|
572 |
+
flag_lip_retargeting = self.cfg.flag_lip_retargeting
|
573 |
+
driving_multiplier = self.cfg.driving_multiplier
|
574 |
+
lib_multiplier = self.cfg.lib_multiplier
|
575 |
+
|
576 |
+
# let lip-open scalar to be 0 at first
|
577 |
+
lip_delta_before_animation, eye_delta_before_animation = None, None
|
578 |
+
if flag_normalize_lip and flag_relative_motion and s_lmk is not None:
|
579 |
+
c_d_lip_before_animation = [0.]
|
580 |
+
combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(c_d_lip_before_animation, s_lmk)
|
581 |
+
if combined_lip_ratio_tensor_before_animation[0][0] >= lip_normalize_threshold:
|
582 |
+
lip_delta_before_animation = self.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
583 |
+
|
584 |
+
# let eye-open scalar to be the same as the first frame if the latter is eye-open state
|
585 |
+
if flag_source_video_eye_retargeting and s_lmk is not None:
|
586 |
+
combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
|
587 |
+
c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
|
588 |
+
if c_d_eye_before_animation_frame_zero[0][0] < source_video_eye_retargeting_threshold:
|
589 |
+
c_d_eye_before_animation_frame_zero = [[0.39]]
|
590 |
+
combined_eye_ratio_tensor_before_animation = self.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, s_lmk)
|
591 |
+
eye_delta_before_animation = self.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation)
|
592 |
+
|
593 |
+
# animate
|
594 |
+
I_p_lst = []
|
595 |
+
for i in range(n_frames):
|
596 |
+
x_d_i_info = driving_template_dct['motion'][i]
|
597 |
+
x_d_i_info = dct2device(x_d_i_info, self.device)
|
598 |
+
# R
|
599 |
+
R_d_i = x_d_i_info['R']
|
600 |
+
if i == 0: # cache the first frame
|
601 |
+
R_d_0 = R_d_i
|
602 |
+
x_d_0_info = x_d_i_info.copy()
|
603 |
+
|
604 |
+
# enhance lip
|
605 |
+
# if i > 0:
|
606 |
+
# for lip_idx in [6, 12, 14, 17, 19, 20]:
|
607 |
+
# x_d_i_info['exp'][:, lip_idx, :] = x_d_0_info['exp'][:, lip_idx, :] + (x_d_i_info['exp'][:, lip_idx, :] - x_d_0_info['exp'][:, lip_idx, :]) * lib_multiplier
|
608 |
+
|
609 |
+
# normalize eye_ball, TODO
|
610 |
+
x_d_i_info['exp'] = self.update_delta_new_eyeball_direction(0, -5, x_d_i_info['exp'])
|
611 |
+
|
612 |
+
# debug
|
613 |
+
#print(f"frame {i:03d}, src scale {x_s_info['scale']}, 0 scale {x_d_0_info['scale']}, i scale {x_d_i_info['scale']}")
|
614 |
+
# delta
|
615 |
+
delta_new = x_s_info['exp'].clone()
|
616 |
+
if flag_relative_motion:
|
617 |
+
# R
|
618 |
+
if animation_region == "all" or animation_region == "pose":
|
619 |
+
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
620 |
+
else:
|
621 |
+
R_new = R_s
|
622 |
+
|
623 |
+
# exp
|
624 |
+
if animation_region == "all" or animation_region == "exp":
|
625 |
+
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
626 |
+
elif animation_region == "lip":
|
627 |
+
for lip_idx in [6, 12, 14, 17, 19, 20]:
|
628 |
+
delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
|
629 |
+
elif animation_region == "eyes":
|
630 |
+
for eyes_idx in [11, 13, 15, 16, 18]:
|
631 |
+
delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
|
632 |
+
|
633 |
+
# scale
|
634 |
+
if animation_region == "all":
|
635 |
+
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
636 |
+
else:
|
637 |
+
scale_new = x_s_info['scale']
|
638 |
+
|
639 |
+
# translation
|
640 |
+
if animation_region == "all" or animation_region == "pose":
|
641 |
+
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
|
642 |
+
else:
|
643 |
+
t_new = x_s_info['t']
|
644 |
+
else:
|
645 |
+
# R
|
646 |
+
if animation_region == "all" or animation_region == "pose":
|
647 |
+
R_new = R_d_i
|
648 |
+
else:
|
649 |
+
R_new = R_s
|
650 |
+
|
651 |
+
# exp
|
652 |
+
if animation_region == "all" or animation_region == "exp":
|
653 |
+
EYE_IDX=[1,2,6,11,12,13,14,15,16,17,18,19,20]
|
654 |
+
delta_new[:, EYE_IDX, :] = x_d_i_info['exp'][:, EYE_IDX, :]
|
655 |
+
# for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
|
656 |
+
# delta_new[:, idx, :] = x_d_i_info['exp'][:, idx, :]
|
657 |
+
delta_new[:, 3:5, 1] = x_d_i_info['exp'][:, 3:5, 1]
|
658 |
+
delta_new[:, 5, 2] = x_d_i_info['exp'][:, 5, 2]
|
659 |
+
delta_new[:, 8, 2] = x_d_i_info['exp'][:, 8, 2]
|
660 |
+
delta_new[:, 9, 1:] = x_d_i_info['exp'][:, 9, 1:]
|
661 |
+
elif animation_region == "lip":
|
662 |
+
for lip_idx in [6, 12, 14, 17, 19, 20]:
|
663 |
+
delta_new[:, lip_idx, :] = x_d_i_info['exp'][:, lip_idx, :]
|
664 |
+
elif animation_region == "eyes":
|
665 |
+
for eyes_idx in [11, 13, 15, 16, 18]:
|
666 |
+
delta_new[:, eyes_idx, :] = x_d_i_info['exp'][:, eyes_idx, :]
|
667 |
+
|
668 |
+
# scale
|
669 |
+
scale_new = x_s_info['scale']
|
670 |
+
|
671 |
+
# translation
|
672 |
+
if animation_region == "all" or animation_region == "pose":
|
673 |
+
t_new = x_d_i_info['t']
|
674 |
+
else:
|
675 |
+
t_new = x_s_info['t']
|
676 |
+
|
677 |
+
t_new[..., 2].fill_(0) # zero tz
|
678 |
+
|
679 |
+
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
|
680 |
+
|
681 |
+
if flag_relative_motion and driving_option == "expression-friendly":
|
682 |
+
if i == 0:
|
683 |
+
x_d_0_new = x_d_i_new
|
684 |
+
motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
|
685 |
+
x_d_diff = (x_d_i_new - x_d_0_new) * motion_multiplier
|
686 |
+
x_d_i_new = x_d_diff + x_s
|
687 |
+
|
688 |
+
# Algorithm 1 in Liveportrait:
|
689 |
+
if not flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
|
690 |
+
# without stitching or retargeting
|
691 |
+
if flag_normalize_lip and lip_delta_before_animation is not None:
|
692 |
+
x_d_i_new += lip_delta_before_animation
|
693 |
+
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
|
694 |
+
x_d_i_new += eye_delta_before_animation
|
695 |
+
else:
|
696 |
+
pass
|
697 |
+
elif flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
|
698 |
+
# with stitching and without retargeting
|
699 |
+
if flag_normalize_lip and lip_delta_before_animation is not None:
|
700 |
+
x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation
|
701 |
+
else:
|
702 |
+
x_d_i_new = self.stitching(x_s, x_d_i_new)
|
703 |
+
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
|
704 |
+
x_d_i_new += eye_delta_before_animation
|
705 |
+
else:
|
706 |
+
eyes_delta, lip_delta = None, None
|
707 |
+
if flag_eye_retargeting and s_lmk is not None and c_d_eyes_lst is not None:
|
708 |
+
c_d_eyes_i = c_d_eyes_lst[i]
|
709 |
+
combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, s_lmk)
|
710 |
+
eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor)
|
711 |
+
|
712 |
+
if flag_lip_retargeting and s_lmk is not None and c_d_lip_lst is not None:
|
713 |
+
c_d_lip_i = c_d_lip_lst[i]
|
714 |
+
combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, s_lmk)
|
715 |
+
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
716 |
+
lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor)
|
717 |
+
|
718 |
+
if flag_relative_motion: # use x_s
|
719 |
+
x_d_i_new = x_s + \
|
720 |
+
(eyes_delta if eyes_delta is not None else 0) + \
|
721 |
+
(lip_delta if lip_delta is not None else 0)
|
722 |
+
else: # use x_d,i
|
723 |
+
x_d_i_new = x_d_i_new + \
|
724 |
+
(eyes_delta if eyes_delta is not None else 0) + \
|
725 |
+
(lip_delta if lip_delta is not None else 0)
|
726 |
+
|
727 |
+
if flag_stitching:
|
728 |
+
x_d_i_new = self.stitching(x_s, x_d_i_new)
|
729 |
+
|
730 |
+
x_d_i_new = x_s + (x_d_i_new - x_s) * driving_multiplier
|
731 |
+
x_d_i_news.append(x_d_i_new)
|
732 |
+
f_s_s= f_s.expand(n_frames, *f_s.shape[1:])
|
733 |
+
x_s_s = x_s.expand(n_frames, *x_s.shape[1:])
|
734 |
+
x_d_i_new = torch.cat(x_d_i_news, dim=0)
|
735 |
+
for start in range(0, n_frames, 100):
|
736 |
+
end = min(start + 100,n_frames)
|
737 |
+
with torch.no_grad(), torch.autocast('cuda'):
|
738 |
+
out = self.warp_decode(f_s_s[start:end], x_s_s[start:end], x_d_i_new[start:end])
|
739 |
+
I_p_lst.append(out['out'])
|
740 |
+
I_p=torch.cat(I_p_lst, dim=0)
|
741 |
+
I_p_i = self.parse_output(I_p)
|
742 |
+
return I_p_i
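In the relative-motion branch above (flag_relative_motion with animation_region == "all"), each driving frame i is composed against the source as

    R_new = (R_d,i · R_d,0ᵀ) · R_s
    δ_new = δ_s + (δ_d,i − δ_d,0)
    s_new = s_s · (s_d,i / s_d,0)
    t_new = t_s + (t_d,i − t_d,0), with t_z zeroed
    x_d,i = s_new · (x_c,s R_new + δ_new) + t_new

after which stitching is optionally applied and the result is blended back toward x_s by driving_multiplier before warping and decoding in chunks of 100 frames.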
|
743 |
+
|
744 |
+
def driven_debug(self, f_s, x_s_info, s_lmk, c_s_eyes_lst, driving_template_dct, c_d_eyes_lst=None, c_d_lip_lst=None):
|
745 |
+
# source kp info
|
746 |
+
x_s_info = self.refine_kp(x_s_info)
|
747 |
+
x_c_s = x_s_info["kp"]
|
748 |
+
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
|
749 |
+
x_s = self.transform_keypoint(x_s_info)
|
750 |
+
|
751 |
+
n_frames = driving_template_dct['n_frames']
|
752 |
+
|
753 |
+
# driving params
|
754 |
+
flag_normalize_lip = self.cfg.flag_normalize_lip
|
755 |
+
flag_relative_motion = self.cfg.flag_relative_motion
|
756 |
+
flag_source_video_eye_retargeting = self.cfg.flag_source_video_eye_retargeting
|
757 |
+
lip_normalize_threshold = self.cfg.lip_normalize_threshold
|
758 |
+
source_video_eye_retargeting_threshold = self.cfg.source_video_eye_retargeting_threshold
|
759 |
+
animation_region = self.cfg.animation_region
|
760 |
+
driving_option = self.cfg.driving_option
|
761 |
+
flag_stitching = self.cfg.flag_stitching
|
762 |
+
flag_eye_retargeting = self.cfg.flag_eye_retargeting
|
763 |
+
flag_lip_retargeting = self.cfg.flag_lip_retargeting
|
764 |
+
driving_multiplier = self.cfg.driving_multiplier
|
765 |
+
|
766 |
+
# let lip-open scalar to be 0 at first
|
767 |
+
lip_delta_before_animation, eye_delta_before_animation = None, None
|
768 |
+
if flag_normalize_lip and flag_relative_motion and s_lmk is not None:
|
769 |
+
c_d_lip_before_animation = [0.]
|
770 |
+
combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(c_d_lip_before_animation, s_lmk)
|
771 |
+
if combined_lip_ratio_tensor_before_animation[0][0] >= lip_normalize_threshold:
|
772 |
+
lip_delta_before_animation = self.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
773 |
+
|
774 |
+
# let eye-open scalar to be the same as the first frame if the latter is eye-open state
|
775 |
+
if flag_source_video_eye_retargeting and s_lmk is not None:
|
776 |
+
combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
|
777 |
+
c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
|
778 |
+
if c_d_eye_before_animation_frame_zero[0][0] < source_video_eye_retargeting_threshold:
|
779 |
+
c_d_eye_before_animation_frame_zero = [[0.39]]
|
780 |
+
combined_eye_ratio_tensor_before_animation = self.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, s_lmk)
|
781 |
+
eye_delta_before_animation = self.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation)
|
782 |
+
|
783 |
+
# animate
|
784 |
+
I_p_lst = []
|
785 |
+
for i in range(n_frames):
|
786 |
+
x_d_i_info = driving_template_dct['motion'][i]
|
787 |
+
x_d_i_info = dct2device(x_d_i_info, self.device)
|
788 |
+
# R
|
789 |
+
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
|
790 |
+
if i == 0: # cache the first frame
|
791 |
+
R_d_0 = R_d_i
|
792 |
+
x_d_0_info = x_d_i_info.copy()
|
793 |
+
|
794 |
+
# debug
|
795 |
+
#print(f"frame {i:03d}, src scale {x_s_info['scale']}, 0 scale {x_d_0_info['scale']}, i scale {x_d_i_info['scale']}")
|
796 |
+
# delta
|
797 |
+
delta_new = x_s_info['exp'].clone()
|
798 |
+
if flag_relative_motion:
|
799 |
+
# R
|
800 |
+
if animation_region == "all" or animation_region == "pose":
|
801 |
+
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
|
802 |
+
else:
|
803 |
+
R_new = R_s
|
804 |
+
|
805 |
+
# exp
|
806 |
+
if animation_region == "all" or animation_region == "exp":
|
807 |
+
delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
|
808 |
+
elif animation_region == "lip":
|
809 |
+
for lip_idx in [6, 12, 14, 17, 19, 20]:
|
810 |
+
delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, lip_idx, :]
|
811 |
+
elif animation_region == "eyes":
|
812 |
+
for eyes_idx in [11, 13, 15, 16, 18]:
|
813 |
+
delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, eyes_idx, :]
|
814 |
+
|
815 |
+
# scale
|
816 |
+
if animation_region == "all":
|
817 |
+
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
|
818 |
+
else:
|
819 |
+
scale_new = x_s_info['scale']
|
820 |
+
|
821 |
+
# translation
|
822 |
+
if animation_region == "all" or animation_region == "pose":
|
823 |
+
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
|
824 |
+
else:
|
825 |
+
t_new = x_s_info['t']
|
826 |
+
else:
|
827 |
+
# R
|
828 |
+
if animation_region == "all" or animation_region == "pose":
|
829 |
+
R_new = R_d_i
|
830 |
+
else:
|
831 |
+
R_new = R_s
|
832 |
+
|
833 |
+
# exp
|
834 |
+
if animation_region == "all" or animation_region == "exp":
|
835 |
+
for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
|
836 |
+
delta_new[:, idx, :] = x_d_i_info['exp'][:, idx, :]
|
837 |
+
delta_new[:, 3:5, 1] = x_d_i_info['exp'][:, 3:5, 1]
|
838 |
+
delta_new[:, 5, 2] = x_d_i_info['exp'][:, 5, 2]
|
839 |
+
delta_new[:, 8, 2] = x_d_i_info['exp'][:, 8, 2]
|
840 |
+
delta_new[:, 9, 1:] = x_d_i_info['exp'][:, 9, 1:]
|
841 |
+
elif animation_region == "lip":
|
842 |
+
for lip_idx in [6, 12, 14, 17, 19, 20]:
|
843 |
+
delta_new[:, lip_idx, :] = x_d_i_info['exp'][:, lip_idx, :]
|
844 |
+
elif animation_region == "eyes":
|
845 |
+
for eyes_idx in [11, 13, 15, 16, 18]:
|
846 |
+
delta_new[:, eyes_idx, :] = x_d_i_info['exp'][:, eyes_idx, :]
|
847 |
+
|
848 |
+
# scale
|
849 |
+
scale_new = x_s_info['scale']
|
850 |
+
|
851 |
+
# translation
|
852 |
+
if animation_region == "all" or animation_region == "pose":
|
853 |
+
t_new = x_d_i_info['t']
|
854 |
+
else:
|
855 |
+
t_new = x_s_info['t']
|
856 |
+
|
857 |
+
t_new[..., 2].fill_(0) # zero tz
|
858 |
+
|
859 |
+
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
|
860 |
+
|
861 |
+
if flag_relative_motion and driving_option == "expression-friendly":
|
862 |
+
if i == 0:
|
863 |
+
x_d_0_new = x_d_i_new
|
864 |
+
motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
|
865 |
+
x_d_diff = (x_d_i_new - x_d_0_new) * motion_multiplier
|
866 |
+
x_d_i_new = x_d_diff + x_s
|
867 |
+
|
868 |
+
# Algorithm 1 in Liveportrait:
|
869 |
+
if not flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
|
870 |
+
# without stitching or retargeting
|
871 |
+
if flag_normalize_lip and lip_delta_before_animation is not None:
|
872 |
+
x_d_i_new += lip_delta_before_animation
|
873 |
+
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
|
874 |
+
x_d_i_new += eye_delta_before_animation
|
875 |
+
else:
|
876 |
+
pass
|
877 |
+
elif flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
|
878 |
+
# with stitching and without retargeting
|
879 |
+
if flag_normalize_lip and lip_delta_before_animation is not None:
|
880 |
+
x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation
|
881 |
+
else:
|
882 |
+
x_d_i_new = self.stitching(x_s, x_d_i_new)
|
883 |
+
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
|
884 |
+
x_d_i_new += eye_delta_before_animation
|
885 |
+
else:
|
886 |
+
eyes_delta, lip_delta = None, None
|
887 |
+
if flag_eye_retargeting and s_lmk is not None and c_d_eyes_lst is not None:
|
888 |
+
c_d_eyes_i = c_d_eyes_lst[i]
|
889 |
+
combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, s_lmk)
|
890 |
+
eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor)
|
891 |
+
|
892 |
+
if flag_lip_retargeting and s_lmk is not None and c_d_lip_lst is not None:
|
893 |
+
c_d_lip_i = c_d_lip_lst[i]
|
894 |
+
combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, s_lmk)
|
895 |
+
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
896 |
+
lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor)
|
897 |
+
|
898 |
+
if flag_relative_motion: # use x_s
|
899 |
+
x_d_i_new = x_s + \
|
900 |
+
(eyes_delta if eyes_delta is not None else 0) + \
|
901 |
+
(lip_delta if lip_delta is not None else 0)
|
902 |
+
else: # use x_d,i
|
903 |
+
x_d_i_new = x_d_i_new + \
|
904 |
+
(eyes_delta if eyes_delta is not None else 0) + \
|
905 |
+
(lip_delta if lip_delta is not None else 0)
|
906 |
+
|
907 |
+
if flag_stitching:
|
908 |
+
x_d_i_new = self.stitching(x_s, x_d_i_new)
|
909 |
+
|
910 |
+
x_d_i_new = x_s + (x_d_i_new - x_s) * driving_multiplier
|
911 |
+
out = self.warp_decode(f_s, x_s, x_d_i_new)
|
912 |
+
I_p_i = self.parse_output(out['out'])[0]
|
913 |
+
I_p_lst.append(I_p_i)
|
914 |
+
|
915 |
+
return I_p_lst
|
916 |
+
|
917 |
+
def read_image(self, image_path: str) -> list:
|
918 |
+
img_rgb = load_image_rgb(image_path)
|
919 |
+
img_rgb = resize_to_limit(img_rgb, self.cfg.source_max_dim, self.cfg.source_division)
|
920 |
+
source_rgb_list = [img_rgb]
|
921 |
+
print(f"Load image from {osp.realpath(image_path)} done.")
|
922 |
+
return source_rgb_list
|
923 |
+
|
924 |
+
def read_video(self, video_path: str, interval=None) -> list:
|
925 |
+
vr = VideoReader(video_path)
|
926 |
+
if interval is not None:
|
927 |
+
video_frames = vr.get_batch(np.arange(0, len(vr), interval)).numpy()
|
928 |
+
else:
|
929 |
+
video_frames = [vr[0].numpy(), vr[len(vr) // 2].numpy(), vr[-1].numpy()]
|
930 |
+
vr.seek(0)
|
931 |
+
driving_rgb_list = []
|
932 |
+
for video_frame in video_frames:
|
933 |
+
# h, w = video_frame.shape[:2]
|
934 |
+
# if h != self.cfg.output_height or w != self.cfg.output_width:
|
935 |
+
# video_frame = cv2.resize(video_frame, (self.cfg.output_height, self.cfg.output_width))
|
936 |
+
driving_rgb_list.append(video_frame)
|
937 |
+
|
938 |
+
return driving_rgb_list
|
939 |
+
|
940 |
+
def prepare_videos(self, imgs) -> torch.Tensor:
|
941 |
+
""" construct the input as standard
|
942 |
+
imgs: NxBxHxWx3, uint8
|
943 |
+
"""
|
944 |
+
if isinstance(imgs, list):
|
945 |
+
_imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1
|
946 |
+
elif isinstance(imgs, np.ndarray):
|
947 |
+
_imgs = imgs
|
948 |
+
else:
|
949 |
+
raise ValueError(f'imgs type error: {type(imgs)}')
|
950 |
+
|
951 |
+
y = _imgs.astype(np.float32) / 255.
|
952 |
+
y = np.clip(y, 0, 1) # clip to 0~1
|
953 |
+
y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
|
954 |
+
y = y.to(self.device)
|
955 |
+
|
956 |
+
return y
|
957 |
+
|
958 |
+
def make_motion_template(self, I_lst, c_eyes_lst, c_lip_lst, **kwargs):
|
959 |
+
n_frames = I_lst.shape[0]
|
960 |
+
template_dct = {
|
961 |
+
'n_frames': n_frames,
|
962 |
+
'output_fps': kwargs.get('output_fps', 25),
|
963 |
+
'motion': [],
|
964 |
+
'c_eyes_lst': [],
|
965 |
+
'c_lip_lst': [],
|
966 |
+
}
|
967 |
+
|
968 |
+
for i in track(range(n_frames), description='Making motion templates...', total=n_frames):
|
969 |
+
# collect s, R, δ and t for inference
|
970 |
+
I_i = I_lst[i]
|
971 |
+
x_i_info = self.refine_kp(self.get_kp_info(I_i))
|
972 |
+
x_s = self.transform_keypoint(x_i_info)
|
973 |
+
R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])
|
974 |
+
|
975 |
+
item_dct = {
|
976 |
+
'scale': x_i_info['scale'].cpu().numpy().astype(np.float32),
|
977 |
+
'R': R_i.cpu().numpy().astype(np.float32),
|
978 |
+
'exp': x_i_info['exp'].cpu().numpy().astype(np.float32),
|
979 |
+
't': x_i_info['t'].cpu().numpy().astype(np.float32),
|
980 |
+
'kp': x_i_info['kp'].cpu().numpy().astype(np.float32),
|
981 |
+
'x_s': x_s.cpu().numpy().astype(np.float32),
|
982 |
+
}
|
983 |
+
|
984 |
+
template_dct['motion'].append(item_dct)
|
985 |
+
|
986 |
+
c_eyes = c_eyes_lst[i].astype(np.float32)
|
987 |
+
template_dct['c_eyes_lst'].append(c_eyes)
|
988 |
+
|
989 |
+
c_lip = c_lip_lst[i].astype(np.float32)
|
990 |
+
template_dct['c_lip_lst'].append(c_lip)
|
991 |
+
|
992 |
+
return template_dct
|
993 |
+
|
994 |
+
def load_template(self, wfp_template):
|
995 |
+
print(f"Load from template: {wfp_template}, NOT the video, so the cropping video and audio are both NULL.")
|
996 |
+
driving_template_dct = load(wfp_template)
|
997 |
+
c_d_eyes_lst = driving_template_dct['c_eyes_lst'] if 'c_eyes_lst' in driving_template_dct.keys() else driving_template_dct['c_d_eyes_lst'] # compatible with previous keys
|
998 |
+
c_d_lip_lst = driving_template_dct['c_lip_lst'] if 'c_lip_lst' in driving_template_dct.keys() else driving_template_dct['c_d_lip_lst']
|
999 |
+
driving_n_frames = driving_template_dct['n_frames']
|
1000 |
+
flag_is_driving_video = True if driving_n_frames > 1 else False
|
1001 |
+
n_frames = driving_n_frames
|
1002 |
+
|
1003 |
+
# set output_fps
|
1004 |
+
output_fps = driving_template_dct.get('output_fps', 25)
|
1005 |
+
print(f'The FPS of template: {output_fps}')
|
1006 |
+
return driving_template_dct
|
1007 |
+
|
1008 |
+
def reconstruction(self, src_img, dst_imgs, video_path="template"):
|
1009 |
+
# prepare source
|
1010 |
+
src_img_256x256, s_lmk, _ = self.crop_image(src_img, do_crop=False)
|
1011 |
+
#c_s_eyes_lst, c_s_lip_lst = self.calc_ratio([s_lmk])
|
1012 |
+
c_s_eyes_lst = None
|
1013 |
+
f_s, x_s_info = self.prepare_source(src_img_256x256)
|
1014 |
+
|
1015 |
+
# prepare driving videos
|
1016 |
+
dst_imgs_256x256, d_lmk_lst = self.crop_driving_videos(dst_imgs, do_crop=False)
|
1017 |
+
c_d_eyes_lst, c_d_lip_lst = self.calc_ratio(d_lmk_lst)
|
1018 |
+
kp_infos = self.prepare_driving_videos(dst_imgs_256x256)
|
1019 |
+
|
1020 |
+
|
1021 |
+
recs = self.driven(f_s, x_s_info, s_lmk, c_s_eyes_lst, kp_infos, c_d_eyes_lst, c_d_lip_lst)
|
1022 |
+
return recs
|
1023 |
+
|
1024 |
+
def save_results(self, results, save_path, audio_path=None):
|
1025 |
+
save_dir = osp.dirname(save_path)
|
1026 |
+
save_name = osp.basename(save_path)
|
1027 |
+
final_video = osp.join(save_dir, f'final_{save_name}')
|
1028 |
+
|
1029 |
+
images2video(results, wfp=save_path, fps=self.cfg.output_fps)
|
1030 |
+
|
1031 |
+
if audio_path is not None:
|
1032 |
+
add_audio_to_video(save_path, audio_path, final_video)
|
1033 |
+
os.remove(save_path)
|
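images2video and add_audio_to_video are project utilities not shown here; a rough stand-alone equivalent of the audio-muxing step, using the ffmpeg CLI directly with hypothetical paths, might look like:

import subprocess

def mux_audio(silent_video: str, audio_path: str, out_path: str) -> None:
    # copy the video stream, encode the audio to AAC, and stop at the shorter stream
    cmd = [
        "ffmpeg", "-y",
        "-i", silent_video,
        "-i", audio_path,
        "-c:v", "copy", "-c:a", "aac", "-shortest",
        out_path,
    ]
    subprocess.run(cmd, check=True)

# mux_audio("results.mp4", "driving.wav", "final_results.mp4")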
1034 |
+
|
1035 |
+
def rec_score(self, video_path: str, interval=None, save_path=None):
|
1036 |
+
video_frames = self.read_video(video_path, interval=interval)
|
1037 |
+
#print(f"len frames: {len(video_frames)}, shape: {video_frames[0].shape}")
|
1038 |
+
recs = self.reconstruction(video_frames[0], video_frames[1:], video_path)
|
1039 |
+
if save_path is not None:
|
1040 |
+
self.save_results(recs, save_path)
|
1041 |
+
#print(f"len rec: {len(recs)}, shape: {recs[0].shape}")
|
1042 |
+
psnrs = psnr(video_frames[1:], recs)
|
1043 |
+
psnrs_np = np.array(psnrs)
|
1044 |
+
psnr_mean, psnr_std = np.mean(psnrs_np), np.std(psnrs_np)
|
1045 |
+
rec_score = {"mean": psnr_mean, "std": psnr_std}
|
1046 |
+
return rec_score
|
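psnr here is an external helper; a common per-frame definition for 8-bit frames (an assumption, not necessarily the exact implementation used above) is:

import numpy as np

def psnr_uint8(ref: np.ndarray, rec: np.ndarray) -> float:
    # PSNR = 10 * log10(MAX^2 / MSE), with MAX = 255 for uint8 frames
    mse = np.mean((ref.astype(np.float64) - rec.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10((255.0 ** 2) / mse)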
1047 |
+
|
1048 |
+
@torch.no_grad()
|
1049 |
+
def paste_back_by_face_mask(self, result, crop_info, src_img, crop_src_image, use_laplacian=False):
|
1050 |
+
"""
|
1051 |
+
paste back the result to the original image with face mask
|
1052 |
+
"""
|
1053 |
+
# detect src mask
|
1054 |
+
crop_src_tensor = self.to_tensor(crop_src_image).unsqueeze(0).to(self.device)
|
1055 |
+
src_msks = get_face_mask(self.face_parser, crop_src_tensor)
|
1056 |
+
result_tensor = self.to_tensor(result).unsqueeze(0).to(self.device)
|
1057 |
+
result_msks = get_face_mask(self.face_parser, result_tensor)
|
1058 |
+
# combine masks
|
1059 |
+
masks = []
|
1060 |
+
for src_msk, result_msk in zip(src_msks, result_msks):
|
1061 |
+
mask = np.clip(src_msk + result_msk, 0, 1)
|
1062 |
+
masks.append(mask)
|
1063 |
+
result = paste_back_with_face_mask(result, crop_info, src_img, masks[0], use_laplacian=use_laplacian)
|
1064 |
+
return result
|
1065 |
+
|
1066 |
+
def driven_by_audio(self, src_img, kp_infos, save_path, audio_path=None, smooth=False):
|
1067 |
+
# prepare source
|
1068 |
+
# prepare source
|
1069 |
+
src_img_256x256, s_lmk, crop_info = self.crop_image(src_img, do_crop=True)
|
1070 |
+
#c_s_eyes_lst, c_s_lip_lst = self.calc_ratio([s_lmk])
|
1071 |
+
c_s_eyes_lst = None
|
1072 |
+
f_s, x_s_info = self.prepare_source(src_img_256x256)
|
1073 |
+
|
1074 |
+
mask_ori_float = prepare_paste_back(self.mask_crop, crop_info['M_c2o'], dsize=(src_img.shape[1], src_img.shape[0]))
|
1075 |
+
|
1076 |
+
# prepare driving videos
|
1077 |
+
results = self.driven(f_s, x_s_info, s_lmk, c_s_eyes_lst, kp_infos, smooth=smooth)
|
1078 |
+
frames = results.shape[0]
|
1079 |
+
results = [paste_back(results[i], crop_info['M_c2o'], src_img, mask_ori_float) for i in range(frames)]
|
1080 |
+
self.save_results(results, save_path, audio_path)
|
1081 |
+
def mix_kp_infos(self, emo_kp_infos, lip_kp_infos, smooth=False, dtype="pt_tensor"):
|
1082 |
+
driving_emo_template_dct = self.get_driving_template(emo_kp_infos, smooth=False, dtype=dtype)
|
1083 |
+
if lip_kp_infos is not None:
|
1084 |
+
driving_lip_template_dct = self.get_driving_template(lip_kp_infos, smooth=smooth, dtype=dtype)
|
1085 |
+
driving_template_dct = {**driving_emo_template_dct}
|
1086 |
+
n_frames = min(driving_emo_template_dct['n_frames'], driving_lip_template_dct['n_frames'])
|
1087 |
+
driving_template_dct['n_frames'] = n_frames
|
1088 |
+
for i in range(n_frames):
|
1089 |
+
emo_motion = driving_emo_template_dct['motion'][i]['exp']
|
1090 |
+
lib_motion = driving_lip_template_dct['motion'][i]['exp']
|
1091 |
+
for lip_idx in [6, 12, 14, 17, 19, 20]:
|
1092 |
+
emo_motion[:, lip_idx, :] = lib_motion[:, lip_idx, :]
|
1093 |
+
driving_template_dct['motion'][i]['exp'] = emo_motion
|
1094 |
+
else:
|
1095 |
+
driving_template_dct = driving_emo_template_dct
|
1096 |
+
|
1097 |
+
return driving_template_dct
|
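The mixing step above only overwrites the lip-related rows of each 1x21x3 expression code; the same indexing on dummy arrays (the index set is copied from the loop above):

import numpy as np

LIP_IDX = [6, 12, 14, 17, 19, 20]

emo_exp = np.zeros((1, 21, 3), dtype=np.float32)   # expression from the emotion source
lip_exp = np.ones((1, 21, 3), dtype=np.float32)    # expression from the lip (audio-driven) source

mixed = emo_exp.copy()
mixed[:, LIP_IDX, :] = lip_exp[:, LIP_IDX, :]      # lips follow the audio-driven motion
assert mixed[:, 6].sum() == 3 and mixed[:, 0].sum() == 0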
1098 |
+
|
1099 |
+
def driven_by_mix(self, src_img, driving_video_path, kp_infos, save_path, audio_path=None, smooth=False):
|
1100 |
+
# prepare source
|
1101 |
+
src_img_256x256, s_lmk, crop_info = self.crop_image(src_img, do_crop=True)
|
1102 |
+
c_s_eyes_lst, c_s_lip_lst = self.calc_ratio([s_lmk])
|
1103 |
+
f_s, x_s_info = self.prepare_source(src_img_256x256)
|
1104 |
+
mask_ori_float = prepare_paste_back(self.mask_crop, crop_info['M_c2o'], dsize=(src_img.shape[1], src_img.shape[0]))
|
1105 |
+
# prepare driving videos
|
1106 |
+
driving_imgs = self.read_video(driving_video_path, interval=1)
|
1107 |
+
dst_imgs_256x256, d_lmk_lst = self.crop_driving_videos(driving_imgs, do_crop=True)
|
1108 |
+
c_d_eyes_lst, c_d_lip_lst = self.calc_ratio(d_lmk_lst)
|
1109 |
+
emo_kp_infos = self.prepare_driving_videos(dst_imgs_256x256)
|
1110 |
+
# mix kp_infos
|
1111 |
+
driving_template_dct = self.mix_kp_infos(emo_kp_infos, kp_infos, smooth=smooth)
|
1112 |
+
# driven
|
1113 |
+
results = self.driven_debug(f_s, x_s_info, s_lmk, c_s_eyes_lst, driving_template_dct, c_d_eyes_lst=c_d_eyes_lst, c_d_lip_lst=c_d_lip_lst)
|
1114 |
+
results = [paste_back(result, crop_info['M_c2o'], src_img, mask_ori_float) for result in results]
|
1115 |
+
print(len(results))
|
1116 |
+
self.save_results(results, save_path, audio_path)
|
1117 |
+
|
1118 |
+
def drive_video_by_mix(self, video_path, driving_video_path, kp_infos, save_path, audio_path):
|
1119 |
+
# prepare driving videos
|
1120 |
+
driving_imgs = self.read_video(driving_video_path, interval=1)
|
1121 |
+
dst_imgs_256x256, d_lmk_lst = self.crop_driving_videos(driving_imgs, do_crop=True)
|
1122 |
+
emo_kp_infos = self.prepare_driving_videos(dst_imgs_256x256)
|
1123 |
+
# mix kp_infos
|
1124 |
+
#driving_template_dct = self.get_driving_template(emo_kp_infos, smooth=True, dtype="np")
|
1125 |
+
driving_template_dct = self.mix_kp_infos(emo_kp_infos, kp_infos, smooth=True, dtype="np")
|
1126 |
+
# driven
|
1127 |
+
self.video_lip_retargeting(
|
1128 |
+
video_path, None,
|
1129 |
+
save_path, audio_path,
|
1130 |
+
driving_template_dct=driving_template_dct, retargeting_ragion="exp"
|
1131 |
+
)
|
1132 |
+
|
1133 |
+
def load_source_video(self, video_info, n_frames=-1):
|
1134 |
+
reader = imageio.get_reader(video_info, "ffmpeg")
|
1135 |
+
|
1136 |
+
ret = []
|
1137 |
+
for idx, frame_rgb in enumerate(reader):
|
1138 |
+
if n_frames > 0 and idx >= n_frames:
|
1139 |
+
break
|
1140 |
+
ret.append(frame_rgb)
|
1141 |
+
|
1142 |
+
reader.close()
|
1143 |
+
|
1144 |
+
return ret
|
1145 |
+
|
1146 |
+
def video_lip_retargeting(self, video_path, kp_infos, save_path, audio_path, c_d_eyes_lst=None, c_d_lip_lst=None, smooth=False, driving_template_dct=None, retargeting_ragion="exp"):
|
1147 |
+
# 0. process source motion template
|
1148 |
+
source_rgb_lst = load_video(video_path)
|
1149 |
+
source_rgb_lst = [resize_to_limit(img, self.cfg.source_max_dim, self.cfg.source_division) for img in source_rgb_lst]
|
1150 |
+
img_crop_256x256_lst, source_lmk_crop_lst, source_M_c2o_lst = self.crop_source_video(source_rgb_lst, do_crop=True)
|
1151 |
+
c_s_eyes_lst, c_s_lip_lst = self.calc_ratio(source_lmk_crop_lst)
|
1152 |
+
I_s_lst = self.prepare_videos(img_crop_256x256_lst)
|
1153 |
+
source_template_dct = self.make_motion_template(I_s_lst, c_s_eyes_lst, c_s_lip_lst, output_fps=25)
|
1154 |
+
# 1. prepare driving template
|
1155 |
+
if driving_template_dct is None:
|
1156 |
+
driving_template_dct = self.get_driving_template(kp_infos, smooth=smooth, dtype="np")
|
1157 |
+
# 2. driving
|
1158 |
+
n_frames = min(source_template_dct['n_frames'], driving_template_dct['n_frames'])
|
1159 |
+
# driving params
|
1160 |
+
I_p_lst = []
|
1161 |
+
I_p_pstbk_lst = []
|
1162 |
+
R_d_0, x_d_0_info = None, None
|
1163 |
+
flag_normalize_lip = self.cfg.flag_normalize_lip
|
1164 |
+
flag_relative_motion = True #self.cfg.flag_relative_motion
|
1165 |
+
flag_source_video_eye_retargeting = self.cfg.flag_source_video_eye_retargeting
|
1166 |
+
lip_normalize_threshold = self.cfg.lip_normalize_threshold
|
1167 |
+
source_video_eye_retargeting_threshold = self.cfg.source_video_eye_retargeting_threshold
|
1168 |
+
animation_region = 'lip' #self.cfg.animation_region
|
1169 |
+
driving_option = self.cfg.driving_option
|
1170 |
+
flag_stitching = self.cfg.flag_stitching
|
1171 |
+
flag_eye_retargeting = self.cfg.flag_eye_retargeting
|
1172 |
+
flag_lip_retargeting = self.cfg.flag_lip_retargeting
|
1173 |
+
driving_multiplier = self.cfg.driving_multiplier
|
1174 |
+
driving_smooth_observation_variance = self.cfg.driving_smooth_observation_variance
|
1175 |
+
|
1176 |
+
key_r = 'R' if 'R' in driving_template_dct['motion'][0].keys() else 'R_d'
|
1177 |
+
if flag_relative_motion:
|
1178 |
+
x_d_exp_lst = [source_template_dct['motion'][i]['exp'] + driving_template_dct['motion'][i]['exp'] - driving_template_dct['motion'][0]['exp'] for i in range(n_frames)]
|
1179 |
+
for i in range(n_frames):
|
1180 |
+
for idx in [6, 12, 14, 17, 19, 20]:
|
1181 |
+
# lip motion use abs motion
|
1182 |
+
x_d_exp_lst[i][:, idx, :] = driving_template_dct['motion'][i]['exp'][:, idx, :]
|
1183 |
+
x_d_exp_lst_smooth = ksmooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, self.device, driving_smooth_observation_variance)
|
1184 |
+
|
1185 |
+
if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
|
1186 |
+
x_d_r_lst = [(np.dot(driving_template_dct['motion'][i][key_r], driving_template_dct['motion'][0][key_r].transpose(0, 2, 1))) @ source_template_dct['motion'][i]['R'] for i in range(n_frames)]
|
1187 |
+
x_d_r_lst_smooth = ksmooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, self.device, driving_smooth_observation_variance)
|
1188 |
+
else:
|
1189 |
+
x_d_exp_lst = [driving_template_dct['motion'][i]['exp'] for i in range(n_frames)]
|
1190 |
+
x_d_exp_lst_smooth = ksmooth(x_d_exp_lst, source_template_dct['motion'][0]['exp'].shape, self.device, driving_smooth_observation_variance)
|
1191 |
+
|
1192 |
+
if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
|
1193 |
+
x_d_r_lst = [driving_template_dct['motion'][i][key_r] for i in range(n_frames)]
|
1194 |
+
x_d_r_lst_smooth = ksmooth(x_d_r_lst, source_template_dct['motion'][0]['R'].shape, self.device, driving_smooth_observation_variance)
|
1195 |
+
|
1196 |
+
# driving all
|
1197 |
+
for i in track(range(n_frames), description='🚀Retargeting...', total=n_frames):
|
1198 |
+
x_s_info = source_template_dct['motion'][i]
|
1199 |
+
x_s_info = dct2device(x_s_info, self.device)
|
1200 |
+
|
1201 |
+
source_lmk = source_lmk_crop_lst[i]
|
1202 |
+
img_crop_256x256 = img_crop_256x256_lst[i]
|
1203 |
+
I_s = I_s_lst[i]
|
1204 |
+
f_s = self.extract_feature_3d(I_s)
|
1205 |
+
|
1206 |
+
x_c_s = x_s_info['kp']
|
1207 |
+
R_s = x_s_info['R']
|
1208 |
+
x_s = x_s_info['x_s']
|
1209 |
+
|
1210 |
+
# let the lip-open scalar be 0 at first if the input is a video
|
1211 |
+
lip_delta_before_animation = None
|
1212 |
+
if flag_normalize_lip and flag_relative_motion and source_lmk is not None:
|
1213 |
+
c_d_lip_before_animation = [0.]
|
1214 |
+
combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
|
1215 |
+
if combined_lip_ratio_tensor_before_animation[0][0] >= lip_normalize_threshold:
|
1216 |
+
lip_delta_before_animation = self.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
|
1217 |
+
else:
|
1218 |
+
lip_delta_before_animation = None
|
1219 |
+
|
1220 |
+
# keep the eye-open scalar the same as the first frame if that frame is in an eye-open state
|
1221 |
+
eye_delta_before_animation = None
|
1222 |
+
if flag_source_video_eye_retargeting and source_lmk is not None:
|
1223 |
+
if i == 0:
|
1224 |
+
combined_eye_ratio_tensor_frame_zero = c_s_eyes_lst[0]
|
1225 |
+
c_d_eye_before_animation_frame_zero = [[combined_eye_ratio_tensor_frame_zero[0][:2].mean()]]
|
1226 |
+
if c_d_eye_before_animation_frame_zero[0][0] < source_video_eye_retargeting_threshold:
|
1227 |
+
c_d_eye_before_animation_frame_zero = [[0.39]]
|
1228 |
+
combined_eye_ratio_tensor_before_animation = self.calc_combined_eye_ratio(c_d_eye_before_animation_frame_zero, source_lmk)
|
1229 |
+
eye_delta_before_animation = self.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation)
|
1230 |
+
|
1231 |
+
if flag_stitching: # prepare for paste back
|
1232 |
+
mask_ori_float = prepare_paste_back(self.mask_crop, source_M_c2o_lst[i], dsize=(source_rgb_lst[i].shape[1], source_rgb_lst[i].shape[0]))
|
1233 |
+
|
1234 |
+
x_d_i_info = driving_template_dct['motion'][i]
|
1235 |
+
x_d_i_info = dct2device(x_d_i_info, self.device)
|
1236 |
+
R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d'] # compatible with previous keys
|
1237 |
+
|
1238 |
+
if i == 0: # cache the first frame
|
1239 |
+
R_d_0 = R_d_i
|
1240 |
+
x_d_0_info = x_d_i_info.copy()
|
1241 |
+
|
1242 |
+
delta_new = x_s_info['exp'].clone()
|
1243 |
+
if flag_relative_motion:
|
1244 |
+
if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
|
1245 |
+
R_new = x_d_r_lst_smooth[i]
|
1246 |
+
else:
|
1247 |
+
R_new = R_s
|
1248 |
+
if animation_region == "all" or animation_region == "exp":
|
1249 |
+
for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
|
1250 |
+
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
|
1251 |
+
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
|
1252 |
+
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
|
1253 |
+
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
|
1254 |
+
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
|
1255 |
+
elif animation_region == "all_wo_lip" or animation_region == "exp_wo_lip":
|
1256 |
+
for idx in [1, 2, 11, 13, 15, 16, 18]:
|
1257 |
+
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
|
1258 |
+
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
|
1259 |
+
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
|
1260 |
+
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
|
1261 |
+
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
|
1262 |
+
elif animation_region == "lip":
|
1263 |
+
for lip_idx in [6, 12, 14, 17, 19, 20]:
|
1264 |
+
delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :]
|
1265 |
+
elif animation_region == "eyes":
|
1266 |
+
for eyes_idx in [11, 13, 15, 16, 18]:
|
1267 |
+
delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :]
|
1268 |
+
|
1269 |
+
scale_new = x_s_info['scale']
|
1270 |
+
t_new = x_s_info['t']
|
1271 |
+
else:
|
1272 |
+
if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
|
1273 |
+
R_new = x_d_r_lst_smooth[i]
|
1274 |
+
else:
|
1275 |
+
R_new = R_s
|
1276 |
+
if animation_region == "all" or animation_region == "exp":
|
1277 |
+
for idx in [1,2,6,11,12,13,14,15,16,17,18,19,20]:
|
1278 |
+
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
|
1279 |
+
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
|
1280 |
+
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
|
1281 |
+
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
|
1282 |
+
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
|
1283 |
+
elif animation_region == "all_wo_lip" or animation_region == "exp_wo_lip":
|
1284 |
+
for idx in [1, 2, 11, 13, 15, 16, 18]:
|
1285 |
+
delta_new[:, idx, :] = x_d_exp_lst_smooth[i][idx, :]
|
1286 |
+
delta_new[:, 3:5, 1] = x_d_exp_lst_smooth[i][3:5, 1]
|
1287 |
+
delta_new[:, 5, 2] = x_d_exp_lst_smooth[i][5, 2]
|
1288 |
+
delta_new[:, 8, 2] = x_d_exp_lst_smooth[i][8, 2]
|
1289 |
+
delta_new[:, 9, 1:] = x_d_exp_lst_smooth[i][9, 1:]
|
1290 |
+
elif animation_region == "lip":
|
1291 |
+
for lip_idx in [6, 12, 14, 17, 19, 20]:
|
1292 |
+
delta_new[:, lip_idx, :] = x_d_exp_lst_smooth[i][lip_idx, :]
|
1293 |
+
elif animation_region == "eyes":
|
1294 |
+
for eyes_idx in [11, 13, 15, 16, 18]:
|
1295 |
+
delta_new[:, eyes_idx, :] = x_d_exp_lst_smooth[i][eyes_idx, :]
|
1296 |
+
scale_new = x_s_info['scale']
|
1297 |
+
if animation_region == "all" or animation_region == "pose" or "all" in animation_region:
|
1298 |
+
t_new = x_d_i_info['t']
|
1299 |
+
else:
|
1300 |
+
t_new = x_s_info['t']
|
1301 |
+
|
1302 |
+
t_new[..., 2].fill_(0) # zero tz
|
1303 |
+
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
|
1304 |
+
|
1305 |
+
# Algorithm 1:
|
1306 |
+
if not flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
|
1307 |
+
# without stitching or retargeting
|
1308 |
+
if flag_normalize_lip and lip_delta_before_animation is not None:
|
1309 |
+
x_d_i_new += lip_delta_before_animation
|
1310 |
+
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
|
1311 |
+
x_d_i_new += eye_delta_before_animation
|
1312 |
+
else:
|
1313 |
+
pass
|
1314 |
+
elif flag_stitching and not flag_eye_retargeting and not flag_lip_retargeting:
|
1315 |
+
# with stitching and without retargeting
|
1316 |
+
if flag_normalize_lip and lip_delta_before_animation is not None:
|
1317 |
+
x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation
|
1318 |
+
else:
|
1319 |
+
x_d_i_new = self.stitching(x_s, x_d_i_new)
|
1320 |
+
if flag_source_video_eye_retargeting and eye_delta_before_animation is not None:
|
1321 |
+
x_d_i_new += eye_delta_before_animation
|
1322 |
+
else:
|
1323 |
+
eyes_delta, lip_delta = None, None
|
1324 |
+
if flag_eye_retargeting and source_lmk is not None and c_d_eyes_lst is not None:
|
1325 |
+
c_d_eyes_i = c_d_eyes_lst[i]
|
1326 |
+
combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
|
1327 |
+
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
|
1328 |
+
eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor)
|
1329 |
+
if flag_lip_retargeting and source_lmk is not None and c_d_lip_lst is not None:
|
1330 |
+
c_d_lip_i = c_d_lip_lst[i]
|
1331 |
+
combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
|
1332 |
+
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
|
1333 |
+
lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor)
|
1334 |
+
|
1335 |
+
if flag_relative_motion: # use x_s
|
1336 |
+
x_d_i_new = x_s + \
|
1337 |
+
(eyes_delta if eyes_delta is not None else 0) + \
|
1338 |
+
(lip_delta if lip_delta is not None else 0)
|
1339 |
+
else: # use x_d,i
|
1340 |
+
x_d_i_new = x_d_i_new + \
|
1341 |
+
(eyes_delta if eyes_delta is not None else 0) + \
|
1342 |
+
(lip_delta if lip_delta is not None else 0)
|
1343 |
+
|
1344 |
+
if flag_stitching:
|
1345 |
+
x_d_i_new = self.stitching(x_s, x_d_i_new)
|
1346 |
+
|
1347 |
+
x_d_i_new = x_s + (x_d_i_new - x_s) * driving_multiplier
|
1348 |
+
out = self.warp_decode(f_s, x_s, x_d_i_new)
|
1349 |
+
I_p_i = self.parse_output(out['out'])[0]
|
1350 |
+
I_p_lst.append(I_p_i)
|
1351 |
+
|
1352 |
+
if flag_stitching:
|
1353 |
+
# TODO: the paste-back procedure is slow; consider optimizing it with multi-threading or on the GPU
|
1354 |
+
#I_p_pstbk = self.paste_back_by_face_mask(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], img_crop_256x256, use_laplacian=True)
|
1355 |
+
I_p_pstbk = paste_back(I_p_i, source_M_c2o_lst[i], source_rgb_lst[i], mask_ori_float, use_laplacian=True)
|
1356 |
+
I_p_pstbk_lst.append(I_p_pstbk)
|
1357 |
+
|
1358 |
+
if len(I_p_pstbk_lst) > 0:
|
1359 |
+
self.save_results(I_p_pstbk_lst, save_path, audio_path)
|
1360 |
+
else:
|
1361 |
+
self.save_results(I_p_lst, save_path, audio_path)
|
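In the relative-motion branch above, the driving expression is applied as an offset from its own first frame on top of the source video's per-frame expression, while the lip rows are copied absolutely from the driving template. A small NumPy sketch of that composition (1x21x3 shapes assumed):

import numpy as np

LIP_IDX = [6, 12, 14, 17, 19, 20]

def compose_exp(src_exp_i, drv_exp_i, drv_exp_0):
    # relative offset for most of the face, absolute lip rows from the driving motion
    out = src_exp_i + drv_exp_i - drv_exp_0
    out[:, LIP_IDX, :] = drv_exp_i[:, LIP_IDX, :]
    return out

src = np.zeros((1, 21, 3), np.float32)
drv0 = np.full((1, 21, 3), 0.1, np.float32)
drv1 = np.full((1, 21, 3), 0.3, np.float32)
mixed = compose_exp(src, drv1, drv0)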
1362 |
+
|
1363 |
+
@torch.no_grad()
|
1364 |
+
def video_reconstruction_test(self, video_tensor, xs, save_path):
|
1365 |
+
# video_tensor, (1, F, C, H, W), [-1, 1]
|
1366 |
+
# xs, (1, F, 63)
|
1367 |
+
result_lst = []
|
1368 |
+
#ori_videos = []
|
1369 |
+
video_tensor = video_tensor[0:1] * 0.5 + 0.5 # [-1, 1] -> [0, 1], 1xTx3xHxW
|
1370 |
+
video_tensor = torch.clip(video_tensor, 0, 1)
|
1371 |
+
video_tensor = video_tensor.permute(1, 0, 2, 3, 4) # 1xTx3xHxW -> Tx1x3xHxW
|
1372 |
+
video = video_tensor.to(self.device)
|
1373 |
+
xs = xs[0:1].permute(1, 0, 2) # 1xTx63 -> Tx1x63
|
1374 |
+
xs = xs.reshape(-1, 1, 21, 3)
|
1375 |
+
xs = xs.to(self.device)
|
1376 |
+
|
1377 |
+
x_s_0 = xs[0]
|
1378 |
+
I_s_0 = torch.nn.functional.interpolate(video[0], size=(256, 256), mode='bilinear')
|
1379 |
+
f_s_0 = self.extract_feature_3d(I_s_0)
|
1380 |
+
|
1381 |
+
for i in range(video_tensor.shape[0]):
|
1382 |
+
#I_s = video[i] # 1x3xHxW
|
1383 |
+
#ori_videos.append((I_s.squeeze(0).squeeze(0).permute(1, 2, 0).cpu().numpy()*255).astype(np.uint8))
|
1384 |
+
x_s = self.stitching(x_s_0, xs[i])
|
1385 |
+
out = self.warp_decode(f_s_0, x_s_0, x_s)
|
1386 |
+
I_p_i = self.parse_output(out['out'])[0]
|
1387 |
+
result_lst.append(I_p_i)
|
1388 |
+
|
1389 |
+
#save_dir = osp.dirname(save_path)
|
1390 |
+
#ori_path = osp.join(save_dir, "ori.mp4")
|
1391 |
+
#save_path = osp.join(save_dir, "rec.mp4")
|
1392 |
+
self.save_results(result_lst, save_path, audio_path=None)
|
1393 |
+
#self.save_results(ori_videos, ori_path, audio_path=None)
|
1394 |
+
|
1395 |
+
@torch.no_grad()
|
1396 |
+
def self_driven(self, image_tensor, xs, save_path, length):
|
1397 |
+
result_lst = []
|
1398 |
+
image_tensor = image_tensor[0:1] * 0.5 + 0.5 # [-1, 1] -> [0, 1], 1x3xHxW
|
1399 |
+
image_tensor = torch.clip(image_tensor, 0, 1)
|
1400 |
+
image = image_tensor.to(self.device)
|
1401 |
+
I_s_0 = torch.nn.functional.interpolate(image, size=(256, 256), mode='bilinear')
|
1402 |
+
|
1403 |
+
xs = xs[0:1].permute(1, 0, 2) # 1xTx63 -> Tx1x63
|
1404 |
+
xs = xs.reshape(-1, 1, 21, 3)
|
1405 |
+
xs = xs.to(self.device)
|
1406 |
+
|
1407 |
+
x_s_0 = xs[0]
|
1408 |
+
f_s_0 = self.extract_feature_3d(I_s_0)
|
1409 |
+
|
1410 |
+
for i in range(xs.shape[0]):
|
1411 |
+
x_d = self.stitching(x_s_0, xs[i])
|
1412 |
+
out = self.warp_decode(f_s_0, x_s_0, x_d)
|
1413 |
+
I_p_i = self.parse_output(out['out'])[0]
|
1414 |
+
result_lst.append(I_p_i)
|
1415 |
+
|
1416 |
+
assert len(result_lst) == length, f"length of result_lst is {len(result_lst)}, but length is {length}"
|
1417 |
+
|
1418 |
+
self.save_results(result_lst, save_path, audio_path=None)
|
1419 |
+
|
1420 |
+
|
src/examples/driving_audios/10.wav
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:79b53cbd91ebd7756b51f4d388a769b461a247f26acae5c362ca326e27c23626
|
3 |
+
size 2880078
|
src/examples/driving_audios/5.wav
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:79b53cbd91ebd7756b51f4d388a769b461a247f26acae5c362ca326e27c23626
|
3 |
+
size 2880078
|
src/examples/driving_audios/6.wav
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:90be6ae092eaa9be4e74e0bed56ef343a825bc2c899d2868e0e3aee494c86a04
|
3 |
+
size 1323078
|
src/examples/driving_audios/tmp_5.wav
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b2f615328211bb938ab7f6b603631695106d2e23ceaa4dfcd4f491bc5dc2faca
|
3 |
+
size 544044
|
src/examples/reference_images/1.jpg
ADDED
src/examples/reference_images/2.jpg
ADDED
src/examples/reference_images/3.jpg
ADDED
src/examples/reference_images/4.jpg
ADDED
src/examples/reference_images/5.jpg
ADDED
src/examples/reference_images/6.jpg
ADDED
src/examples/reference_images/7.jpg
ADDED
src/examples/silent-audio.wav
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:231cedffe295d0f5c8ea8569af9edc2471c262410689190bb705fb0adb62f63f
|
3 |
+
size 352878
|
src/models/audio/__pycache__/audio_processer.cpython-310.pyc
ADDED
Binary file (12.1 kB).
|
src/models/audio/__pycache__/audio_proj.cpython-310.pyc
ADDED
Binary file (4.66 kB).
|
src/models/audio/__pycache__/hubert.cpython-310.pyc
ADDED
Binary file (3.45 kB).
|
src/models/audio/__pycache__/wav2vec.cpython-310.pyc
ADDED
Binary file (6 kB).
|
src/models/audio/__pycache__/wav2vec2.cpython-310.pyc
ADDED
Binary file (4.32 kB).
|
src/models/audio/__pycache__/wav2vec_modified.cpython-310.pyc
ADDED
Binary file (6.78 kB).
|
src/models/audio/audio_processer.py
ADDED
@@ -0,0 +1,407 @@
|
1 |
+
"""Audio processer for talking data.
|
2 |
+
Author: linzhihui.lzh
|
3 |
+
Date: 2024-12-12
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
from re import A
|
7 |
+
import sys
|
8 |
+
import os.path as osp
|
9 |
+
|
10 |
+
from typing import List, Dict, Tuple, Optional, Union, Any
|
11 |
+
|
12 |
+
import yaml
|
13 |
+
from omegaconf import OmegaConf
|
14 |
+
|
15 |
+
import math
|
16 |
+
import librosa
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
from einops import rearrange
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from pydub import AudioSegment
|
25 |
+
# from audio_separator.separator import Separator
|
26 |
+
|
27 |
+
sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__))))))
|
28 |
+
from src.utils.rprint import rlog as log
|
29 |
+
from src.utils.util import resample_audio
|
30 |
+
|
31 |
+
from src.models.audio.wav2vec_modified import Wav2VecModel
|
32 |
+
from src.models.audio.hubert import HubertModel  # NOTE: hubert.py names its customized subclass HubertModel_, so this resolves to the vanilla transformers HubertModel
|
33 |
+
|
34 |
+
|
35 |
+
def pad_audio(audio, audio_unit=320, pad_threshold=80):
|
36 |
+
batch_size, audio_len = audio.shape
|
37 |
+
n_units = audio_len // audio_unit
|
38 |
+
side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2)
|
39 |
+
if side_len >= 0:
|
40 |
+
reflect_len = side_len // 2
|
41 |
+
replicate_len = side_len % 2
|
42 |
+
if reflect_len > 0:
|
43 |
+
audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
|
44 |
+
audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
|
45 |
+
if replicate_len > 0:
|
46 |
+
audio = F.pad(audio, (1, 1), mode='replicate')
|
47 |
+
|
48 |
+
return audio
|
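A worked example of the padding arithmetic: for a 16 kHz clip of 16050 samples with audio_unit=320 and pad_threshold=80, the target length is 320*50 + 80 = 16080, so 15 samples are added per side (two reflect pads of 7 plus one replicate pad of 1). This assumes pad_audio above is importable from src.models.audio.audio_processer:

import math
import torch
from src.models.audio.audio_processer import pad_audio

audio = torch.randn(1, 16050)
audio_unit, pad_threshold = 320, 80
n_units = audio.shape[1] // audio_unit                                               # 50
side_len = math.ceil((audio_unit * n_units + pad_threshold - audio.shape[1]) / 2)    # 15
padded = pad_audio(audio, audio_unit, pad_threshold)
assert padded.shape[1] == audio_unit * n_units + pad_threshold                       # 16080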
49 |
+
|
50 |
+
|
51 |
+
def cut_audio(audio_path: str, save_dir: str, length=60) -> List[str]:
|
52 |
+
"""Cut audio into sub-divisions and return subfile paths. Supports wav format.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
audio_path (str): the source audio file path
|
56 |
+
save_dir (str): the save directory of sub-divisions
|
57 |
+
length (int, optional): The max length of each sub-division. Defaults to 60 secs.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
List[str]: the subfile paths
|
61 |
+
"""
|
62 |
+
audio_name = osp.basename(audio_path).split('.')[0]
|
63 |
+
audio = AudioSegment.from_wav(audio_path)
|
64 |
+
segment_length = length * 1000. # pydub uses milliseconds
|
65 |
+
num_segments = math.ceil(len(audio) / segment_length)
|
66 |
+
|
67 |
+
os.makedirs(save_dir, exist_ok=True)
|
68 |
+
audio_list = []
|
69 |
+
|
70 |
+
for i in range(num_segments):
|
71 |
+
start_time = i * segment_length
|
72 |
+
end_time = min((i + 1) * segment_length, len(audio))
|
73 |
+
segment = audio[start_time:end_time]
|
74 |
+
|
75 |
+
path = osp.join(save_dir, f"{audio_name}_segment_{i+1}.wav")
|
76 |
+
audio_list.append(path)
|
77 |
+
segment.export(path, format="wav")
|
78 |
+
return audio_list
|
79 |
+
|
80 |
+
|
81 |
+
class AudioProcessor(object):
|
82 |
+
def __init__(self, cfg_path: str, is_training: bool = False) -> None:
|
83 |
+
cfg = OmegaConf.load(cfg_path)
|
84 |
+
self.cfg = cfg
|
85 |
+
self.is_training = is_training
|
86 |
+
log("========================================= Audio Processer =========================================")
|
87 |
+
log(OmegaConf.to_yaml(cfg))
|
88 |
+
|
89 |
+
# setting device
|
90 |
+
self.device_id = cfg.device_params.device_id
|
91 |
+
self.use_half = cfg.device_params.flag_use_half_precision
|
92 |
+
if cfg.device_params.flag_force_cpu:
|
93 |
+
self.device = 'cpu'
|
94 |
+
else:
|
95 |
+
try:
|
96 |
+
if torch.backends.mps.is_available():
|
97 |
+
self.device = 'mps'
|
98 |
+
else:
|
99 |
+
self.device = 'cuda:' + str(self.device_id)
|
100 |
+
except:
|
101 |
+
self.device = 'cuda:' + str(self.device_id)
|
102 |
+
|
103 |
+
# init audio separator
|
104 |
+
self.audio_separator = None
|
105 |
+
self.cache_dir = cfg.cache_dir
|
106 |
+
self.tmp_dir = cfg.tmp_dir
|
107 |
+
self.use_audio_separator = cfg.model_params.use_audio_separator
|
108 |
+
self.audio_separator_name = cfg.model_params.audio_separator_name
|
109 |
+
self.audio_separator_path = cfg.model_weights.audio_separator_path
|
110 |
+
self.set_audio_separator(cfg.cache_dir)
|
111 |
+
|
112 |
+
# load audio encoder, wav2vec or hubert
|
113 |
+
self.model_name = cfg.model_params.model_name
|
114 |
+
self.is_chinese = cfg.model_params.is_chinese
|
115 |
+
self.audio_encoder = self.load_model(
|
116 |
+
model_name = cfg.model_params.model_name,
|
117 |
+
model_type = cfg.model_params.model_type,
|
118 |
+
is_chinese = cfg.model_params.is_chinese,
|
119 |
+
)
|
120 |
+
self.only_last_features = cfg.model_params.only_last_features
|
121 |
+
if cfg.model_params.only_last_features:
|
122 |
+
self.feature_shape = (1, 768)
|
123 |
+
else:
|
124 |
+
self.feature_shape = (12, 768) # features of 12 blocks
|
125 |
+
|
126 |
+
# init data params
|
127 |
+
self.sample_strategy = cfg.data_params.sample_strategy
|
128 |
+
self.sample_rate = cfg.data_params.sample_rate
|
129 |
+
self.fps = cfg.data_params.fps
|
130 |
+
self.audio_unit = cfg.data_params.sample_rate / cfg.data_params.fps # num of audio samples per frame
|
131 |
+
self.max_length = cfg.data_params.max_length
|
132 |
+
self.subclip_len = cfg.data_params.sub_clip_length
|
133 |
+
self.save_to_cpu = cfg.data_params.save_to_cpu
|
134 |
+
self.pad_mode = cfg.data_params.audio_pad_mode
|
135 |
+
|
136 |
+
log("========================================= Audio Processer: Done =========================================")
|
137 |
+
|
138 |
+
def load_model(self, model_name: str="wav2vec", model_type: str="base", is_chinese: bool = False):
|
139 |
+
assert model_name in ["wav2vec", "hubert"], f"Unknown audio model {model_name}, only support wav2vec or hubert"
|
140 |
+
assert model_type in ["base", "large"], f"Unknown audio model type {model_type}, only support base or large"
|
141 |
+
|
142 |
+
if model_name == "wav2vec":
|
143 |
+
# load wav2vec model weights
|
144 |
+
if is_chinese:
|
145 |
+
if model_type == "base":
|
146 |
+
model_weight_path = self.cfg.model_weights.wav2vec_path.chinese.base
|
147 |
+
else:
|
148 |
+
model_weight_path = self.cfg.model_weights.wav2vec_path.chinese.large
|
149 |
+
else:
|
150 |
+
if model_type == "base":
|
151 |
+
model_weight_path = self.cfg.model_weights.wav2vec_path.default.base
|
152 |
+
else:
|
153 |
+
model_weight_path = self.cfg.model_weights.wav2vec_path.default.large
|
154 |
+
if model_weight_path is None:
|
155 |
+
raise ValueError(f"model_weight_path is None")
|
156 |
+
audio_encoder = Wav2VecModel.from_pretrained(model_weight_path, local_files_only=True).to(device=self.device)
|
157 |
+
else:
|
158 |
+
if is_chinese:
|
159 |
+
if model_type == "base":
|
160 |
+
model_weight_path = self.cfg.model_weights.hubert_path.chinese.base
|
161 |
+
else:
|
162 |
+
model_weight_path = self.cfg.model_weights.hubert_path.chinese.large
|
163 |
+
else:
|
164 |
+
if model_type == "base":
|
165 |
+
model_weight_path = self.cfg.model_weights.hubert_path.default.base
|
166 |
+
else:
|
167 |
+
model_weight_path = self.cfg.model_weights.hubert_path.default.large
|
168 |
+
if model_weight_path is None:
|
169 |
+
raise ValueError(f"model_weight_path is None")
|
170 |
+
audio_encoder = HubertModel.from_pretrained(model_weight_path, local_files_only=True).to(device=self.device)
|
171 |
+
|
172 |
+
log(f"{model_name}-{model_type}-chinese-{is_chinese} model has beed loaded from {model_weight_path}")
|
173 |
+
total_params = sum(p.numel() for p in audio_encoder.parameters())
|
174 |
+
print('Number of parameters: %.4fM' % (total_params / 1e6))
|
175 |
+
|
176 |
+
# weights initialization
|
177 |
+
audio_encoder.feature_extractor._freeze_parameters()
|
178 |
+
if not self.cfg.model_params.is_original:
|
179 |
+
frozen_layers = [0, 1]
|
180 |
+
for name, param in audio_encoder.named_parameters():
|
181 |
+
if name.startswith("feature_projection"):
|
182 |
+
param.requires_grad = False
|
183 |
+
if name.startswith("encoder.layers"):
|
184 |
+
layer = int(name.split(".")[2])
|
185 |
+
if layer in frozen_layers:
|
186 |
+
param.requires_grad = False
|
187 |
+
|
188 |
+
audio_encoder = audio_encoder.to(self.device)
|
189 |
+
if self.use_half:
|
190 |
+
audio_encoder = audio_encoder.half()
|
191 |
+
audio_encoder.eval()
|
192 |
+
return audio_encoder
|
193 |
+
|
194 |
+
def set_audio_separator(self, output_dir: str) -> None:
|
195 |
+
del self.audio_separator
|
196 |
+
|
197 |
+
if self.audio_separator_name is not None and self.use_audio_separator:
|
198 |
+
try:
|
199 |
+
os.makedirs(output_dir, exist_ok=True)
|
200 |
+
except OSError as _:
|
201 |
+
print("Fail to create the output cache dir.")
|
202 |
+
self.audio_separator = Separator(  # NOTE: requires the audio_separator import that is commented out at the top of this file
|
203 |
+
output_dir=output_dir,
|
204 |
+
output_single_stem="vocals",
|
205 |
+
model_file_dir=self.audio_separator_path,
|
206 |
+
)
|
207 |
+
self.audio_separator.load_model(self.audio_separator_name)
|
208 |
+
assert self.audio_separator.model_instance is not None, "Fail to load audio separate model."
|
209 |
+
else:
|
210 |
+
self.audio_separator=None
|
211 |
+
log("Use audio directly without vocals seperator.")
|
212 |
+
|
213 |
+
def seperate_audio(self, audio_path: str, output_dir: Union[str, None] = None) -> str:
|
214 |
+
if output_dir is not None:
|
215 |
+
if output_dir != self.cache_dir:
|
216 |
+
# reload audio separator
|
217 |
+
self.set_audio_separator(output_dir)
|
218 |
+
|
219 |
+
if self.audio_separator is not None:
|
220 |
+
# 1. separate vocals
|
221 |
+
# TODO: process in memory
|
222 |
+
try:
|
223 |
+
outputs = self.audio_separator.separate(audio_path)
|
224 |
+
if len(outputs) <= 0:
|
225 |
+
raise RuntimeError("Audio separate failed.")
|
226 |
+
|
227 |
+
vocal_audio_file = outputs[0]
|
228 |
+
vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
|
229 |
+
vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
|
230 |
+
vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate)
|
231 |
+
except Exception as e:
|
232 |
+
log(f"Fail to separate vocals from {audio_path}, error info [{e}]")
|
233 |
+
vocal_audio_file=audio_path
|
234 |
+
else:
|
235 |
+
vocal_audio_file=audio_path
|
236 |
+
|
237 |
+
return vocal_audio_file
|
238 |
+
|
239 |
+
def load_audio(self, audio_path: str, mono: bool = True, duration: Optional[float] = None) -> Any:
|
240 |
+
try:
|
241 |
+
audio_data, sampling_rate = librosa.load(audio_path, sr=self.sample_rate, mono=mono, duration=duration)
|
242 |
+
except Exception as e:
|
243 |
+
raise RuntimeError(f"Fail to load audio from {audio_path}, error info [{e}]")
|
244 |
+
return audio_data, sampling_rate
|
245 |
+
|
246 |
+
def prepare_audio_data(self, audio_data: Union[np.ndarray, torch.Tensor], n_frames: Optional[int]=None) -> Tuple[List[Any], int]:
|
247 |
+
"""Prepare audio data for processing.
|
248 |
+
"""
|
249 |
+
clip_len = int(len(audio_data) / self.audio_unit)
|
250 |
+
if n_frames is not None:
|
251 |
+
if abs(n_frames - clip_len) > 2:
|
252 |
+
log(f"The number of frames must be close to the clip length (in 80ms), got {n_frames} and {clip_len}")
|
253 |
+
return [], n_frames
|
254 |
+
clip_len = n_frames
|
255 |
+
else:
|
256 |
+
n_frames = clip_len
|
257 |
+
|
258 |
+
# normalize audio, replace Wav2Vec2FeatureExtractor
|
259 |
+
if isinstance(audio_data, np.ndarray):
|
260 |
+
audio_data = torch.from_numpy(audio_data).to(self.device)
|
261 |
+
assert audio_data.ndim == 1, 'Audio must be 1D tensor.'
|
262 |
+
audio_data = (audio_data - torch.mean(audio_data)) / (torch.std(audio_data) + 1e-7)
|
263 |
+
#log(f"audio loaded! {audio_data.shape}")
|
264 |
+
|
265 |
+
# padding
|
266 |
+
# padding audio to fit the clip length
|
267 |
+
n_audio_samples = round(self.audio_unit * clip_len)
|
268 |
+
n_padding_audio_samples = n_audio_samples - len(audio_data)
|
269 |
+
n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
|
270 |
+
if n_padding_audio_samples > 0:
|
271 |
+
if self.pad_mode == 'zero':
|
272 |
+
padding_value = 0
|
273 |
+
elif self.pad_mode == 'replicate':
|
274 |
+
padding_value = float(audio_data[-1])
|
275 |
+
else:
|
276 |
+
raise ValueError(f'Unknown pad mode: {self.pad_mode}')
|
277 |
+
audio_data = F.pad(audio_data, (0, n_padding_audio_samples), value=padding_value)
|
278 |
+
|
279 |
+
# divide audio into sub-divisions to save GPU memory
|
280 |
+
audio_segments = []
|
281 |
+
if clip_len <= self.subclip_len:
|
282 |
+
n_subdivision = 1
|
283 |
+
subclip_len = clip_len
|
284 |
+
else:
|
285 |
+
n_subdivision = math.ceil(clip_len / self.subclip_len)
|
286 |
+
subclip_len = self.subclip_len
|
287 |
+
|
288 |
+
for i in range(0, n_subdivision):
|
289 |
+
start_idx = i * subclip_len
|
290 |
+
end_idx = min(start_idx + subclip_len, clip_len)
|
291 |
+
# debug
|
292 |
+
#log(f"[{i+1}/{n_subdivision}] data index [{round(start_idx * self.audio_unit)}, {round(end_idx * self.audio_unit)})")
|
293 |
+
audio_segments.append(
|
294 |
+
{
|
295 |
+
"data": audio_data[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0),
|
296 |
+
"start_idx": start_idx,
|
297 |
+
"end_idx": end_idx,
|
298 |
+
"length": end_idx - start_idx
|
299 |
+
}
|
300 |
+
)
|
301 |
+
return audio_segments, n_frames
|
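A worked example of the framing arithmetic used here: at a 16 kHz sample rate and 25 fps, audio_unit = 16000 / 25 = 640 samples per frame, so a 10-second clip maps to 250 frames; with a sub-clip length of 75 frames (an illustrative config value) it is split into 4 sub-divisions:

import math

sample_rate, fps = 16000, 25
audio_unit = sample_rate / fps            # 640 samples per frame
n_samples = 10 * sample_rate              # a 10-second clip
clip_len = int(n_samples / audio_unit)    # 250 frames
subclip_len = 75                          # e.g. sub_clip_length from the config
n_subdivision = math.ceil(clip_len / subclip_len)
assert (clip_len, n_subdivision) == (250, 4)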
302 |
+
|
303 |
+
def get_audio_embedding(self, audio, clip_len: int) -> torch.Tensor:
|
304 |
+
if audio.ndim == 2:
|
305 |
+
# Extract audio features
|
306 |
+
assert audio.shape[1] == 16000 * clip_len / self.fps, \
|
307 |
+
f'Incorrect audio length {audio.shape[1]}'
|
308 |
+
|
309 |
+
# Extract audio features
|
310 |
+
if self.use_half:
|
311 |
+
audio = audio.half()
|
312 |
+
embeddings = self.audio_encoder(
|
313 |
+
pad_audio(audio), seq_len=clip_len, sample_strategy=self.sample_strategy, output_hidden_states=True
|
314 |
+
) # (N, L, 768)
|
315 |
+
assert len(embeddings) > 0, "Fail to extract audio embedding"
|
316 |
+
|
317 |
+
if self.only_last_features:
|
318 |
+
audio_emb = embeddings.last_hidden_state.squeeze(0)
|
319 |
+
else:
|
320 |
+
audio_emb = torch.stack(
|
321 |
+
embeddings.hidden_states[1:], dim=1
|
322 |
+
).squeeze(0)
|
323 |
+
audio_emb = rearrange(audio_emb, "b s d -> s b d")
|
324 |
+
|
325 |
+
elif audio.ndim == 3:
|
326 |
+
assert audio.shape[1] == clip_len, f'Incorrect audio feature length {audio.shape[1]}'
|
327 |
+
audio_emb = audio
|
328 |
+
else:
|
329 |
+
raise ValueError(f'Incorrect audio input shape {audio.shape}')
|
330 |
+
|
331 |
+
return audio_emb
|
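When only_last_features is off, the hidden states of the 12 transformer blocks are stacked into a (12, 768) feature per frame; the shape handling is equivalent to this sketch with dummy tensors in place of the encoder outputs (hidden_states has 13 entries: the projected input plus 12 blocks):

import torch
from einops import rearrange

L = 25                                                         # clip length in frames
hidden_states = [torch.randn(1, L, 768) for _ in range(13)]    # dummy embeddings.hidden_states
audio_emb = torch.stack(hidden_states[1:], dim=1).squeeze(0)   # (12, L, 768)
audio_emb = rearrange(audio_emb, "b s d -> s b d")             # (L, 12, 768): one (12, 768) feature per frame
assert audio_emb.shape == (L, 12, 768)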
332 |
+
|
333 |
+
def get_audio_embeddings(self, audio_segments: List[Any]) -> Optional[torch.Tensor]:
|
334 |
+
audio_embs = []
|
335 |
+
for audio_segment in audio_segments:
|
336 |
+
if self.is_training:
|
337 |
+
audio_emb = self.get_audio_embedding(audio_segment["data"], audio_segment["length"])
|
338 |
+
else:
|
339 |
+
with torch.no_grad():
|
340 |
+
audio_emb = self.get_audio_embedding(audio_segment["data"], audio_segment["length"])
|
341 |
+
|
342 |
+
audio_emb = audio_emb.cpu() if self.save_to_cpu else audio_emb
|
343 |
+
audio_embs.append(audio_emb)
|
344 |
+
#log(f"audio segment [{audio_segment['start_idx']}, {audio_segment['end_idx']}) has been processed.")
|
345 |
+
|
346 |
+
if len(audio_embs) == 0:
|
347 |
+
return None
|
348 |
+
|
349 |
+
audio_emb = torch.cat(audio_embs, dim=0)
|
350 |
+
|
351 |
+
return audio_emb
|
352 |
+
|
353 |
+
def preprocess(
|
354 |
+
self,
|
355 |
+
audio_path: str,
|
356 |
+
n_frames: Optional[int] = None,
|
357 |
+
duration: Optional[float] = None,
|
358 |
+
need_seperate: bool = False
|
359 |
+
):
|
360 |
+
""" Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
|
361 |
+
The separated vocal track is then encoded into wav2vec2/HuBERT embeddings for further processing or analysis.
|
362 |
+
"""
|
363 |
+
if need_seperate:
|
364 |
+
vocal_audio_file = self.seperate_audio(audio_path)
|
365 |
+
else:
|
366 |
+
vocal_audio_file = audio_path
|
367 |
+
|
368 |
+
audio_data, sampling_rate = self.load_audio(vocal_audio_file, duration=duration)
|
369 |
+
|
370 |
+
assert sampling_rate == 16000, "The sample rate of audio must be 16000"
|
371 |
+
audio_segments, n_frames = self.prepare_audio_data(audio_data, n_frames)
|
372 |
+
audio_emb = self.get_audio_embeddings(audio_segments)
|
373 |
+
if audio_emb is None:
|
374 |
+
log(f"{audio_path} has been processed, but no audio embedding, set as 'None'.")
|
375 |
+
#else:
|
376 |
+
#log(f"{audio_path} has been processed, audio embedding shape {audio_emb.shape}.")
|
377 |
+
return audio_emb, n_frames
|
378 |
+
|
379 |
+
def preprocess_long(
|
380 |
+
self,
|
381 |
+
audio_path: str,
|
382 |
+
need_seperate: bool = False
|
383 |
+
):
|
384 |
+
audio_list = cut_audio(audio_path, self.tmp_dir, length=self.max_length)
|
385 |
+
audio_emb_list = []
|
386 |
+
l = 0
|
387 |
+
|
388 |
+
for idx, audio_path in enumerate(audio_list):
|
389 |
+
padding = (idx+1) == len(audio_list)
|
390 |
+
emb, length = self.preprocess(audio_path, need_seperate=need_seperate)
|
391 |
+
audio_emb_list.append(emb)
|
392 |
+
log(f"Processing audio {idx+1}/{len(audio_list)}, path: {audio_path} length: {length}")
|
393 |
+
l += length
|
394 |
+
|
395 |
+
audio_emb = torch.cat(audio_emb_list)
|
396 |
+
audio_length = l
|
397 |
+
|
398 |
+
# remove tmp file
|
399 |
+
for audio_path in audio_list:
|
400 |
+
os.remove(audio_path)
|
401 |
+
|
402 |
+
return audio_emb, audio_length
|
403 |
+
|
404 |
+
def __enter__(self):
|
405 |
+
return self
|
406 |
+
|
407 |
+
|
src/models/audio/audio_proj.py
ADDED
@@ -0,0 +1,124 @@
|
1 |
+
"""
|
2 |
+
This module provides the implementation of an Audio Projection Model, which is designed for
|
3 |
+
audio processing tasks. The model takes audio embeddings as input and outputs context tokens
|
4 |
+
that can be used for various downstream applications, such as audio analysis or synthesis.
|
5 |
+
|
6 |
+
The AudioProjModel class is based on the ModelMixin class from the diffusers library, which
|
7 |
+
provides a foundation for building custom models. This implementation includes multiple linear
|
8 |
+
layers with ReLU activation functions and a LayerNorm for normalization.
|
9 |
+
|
10 |
+
Key Features:
|
11 |
+
- Audio embedding input with flexible sequence length and block structure.
|
12 |
+
- Multiple linear layers for feature transformation.
|
13 |
+
- ReLU activation for non-linear transformation.
|
14 |
+
- LayerNorm for stabilizing and speeding up training.
|
15 |
+
- Rearrangement of input embeddings to match the model's expected input shape.
|
16 |
+
- Customizable number of blocks, channels, and context tokens for adaptability.
|
17 |
+
|
18 |
+
The module is structured to be easily integrated into larger systems or used as a standalone
|
19 |
+
component for audio feature extraction and processing.
|
20 |
+
|
21 |
+
Classes:
|
22 |
+
- AudioProjModel: A class representing the audio projection model with configurable parameters.
|
23 |
+
|
24 |
+
Functions:
|
25 |
+
- (none)
|
26 |
+
|
27 |
+
Dependencies:
|
28 |
+
- torch: For tensor operations and neural network components.
|
29 |
+
- diffusers: For the ModelMixin base class.
|
30 |
+
- einops: For tensor rearrangement operations.
|
31 |
+
|
32 |
+
"""
|
33 |
+
|
34 |
+
import torch
|
35 |
+
from diffusers import ModelMixin
|
36 |
+
from einops import rearrange
|
37 |
+
from torch import nn
|
38 |
+
|
39 |
+
|
40 |
+
class AudioProjModel(ModelMixin):
|
41 |
+
"""Audio Projection Model
|
42 |
+
|
43 |
+
This class defines an audio projection model that takes audio embeddings as input
|
44 |
+
and produces context tokens as output. The model is based on the ModelMixin class
|
45 |
+
and consists of multiple linear layers and activation functions. It can be used
|
46 |
+
for various audio processing tasks.
|
47 |
+
|
48 |
+
Attributes:
|
49 |
+
seq_len (int): The length of the audio sequence.
|
50 |
+
blocks (int): The number of blocks in the audio projection model.
|
51 |
+
channels (int): The number of channels in the audio projection model.
|
52 |
+
intermediate_dim (int): The intermediate dimension of the model.
|
53 |
+
context_tokens (int): The number of context tokens in the output.
|
54 |
+
output_dim (int): The output dimension of the context tokens.
|
55 |
+
|
56 |
+
Methods:
|
57 |
+
__init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768):
|
58 |
+
Initializes the AudioProjModel with the given parameters.
|
59 |
+
forward(self, audio_embeds):
|
60 |
+
Defines the forward pass for the AudioProjModel.
|
61 |
+
Parameters:
|
62 |
+
audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, window_size, blocks, channels).
|
63 |
+
Returns:
|
64 |
+
context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
|
65 |
+
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
seq_len=5,
|
71 |
+
blocks=12, # add a new parameter blocks
|
72 |
+
channels=768, # add a new parameter channels
|
73 |
+
intermediate_dim=512,
|
74 |
+
output_dim=768,
|
75 |
+
context_tokens=32,
|
76 |
+
):
|
77 |
+
super().__init__()
|
78 |
+
|
79 |
+
self.seq_len = seq_len
|
80 |
+
self.blocks = blocks
|
81 |
+
self.channels = channels
|
82 |
+
self.input_dim = (
|
83 |
+
seq_len * blocks * channels
|
84 |
+
) # update input_dim to be the product of blocks and channels.
|
85 |
+
self.intermediate_dim = intermediate_dim
|
86 |
+
self.context_tokens = context_tokens
|
87 |
+
self.output_dim = output_dim
|
88 |
+
|
89 |
+
# define multiple linear layers
|
90 |
+
self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
|
91 |
+
self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
|
92 |
+
self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
|
93 |
+
|
94 |
+
self.norm = nn.LayerNorm(output_dim)
|
95 |
+
|
96 |
+
def forward(self, audio_embeds):
|
97 |
+
"""
|
98 |
+
Defines the forward pass for the AudioProjModel.
|
99 |
+
|
100 |
+
Parameters:
|
101 |
+
audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, window_size, blocks, channels).
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
|
105 |
+
"""
|
106 |
+
# merge
|
107 |
+
video_length = audio_embeds.shape[1]
|
108 |
+
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
109 |
+
batch_size, window_size, blocks, channels = audio_embeds.shape
|
110 |
+
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
111 |
+
|
112 |
+
audio_embeds = torch.relu(self.proj1(audio_embeds))
|
113 |
+
audio_embeds = torch.relu(self.proj2(audio_embeds))
|
114 |
+
|
115 |
+
context_tokens = self.proj3(audio_embeds).reshape(
|
116 |
+
batch_size, self.context_tokens, self.output_dim
|
117 |
+
)
|
118 |
+
|
119 |
+
context_tokens = self.norm(context_tokens)
|
120 |
+
context_tokens = rearrange(
|
121 |
+
context_tokens, "(bz f) m c -> bz f m c", f=video_length
|
122 |
+
)
|
123 |
+
|
124 |
+
return context_tokens
|
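A short usage sketch of the projection model (the import path follows the file added above and requires the diffusers dependency; the dummy input uses the (batch, frames, window, blocks, channels) layout that forward expects):

import torch
from src.models.audio.audio_proj import AudioProjModel

model = AudioProjModel(seq_len=5, blocks=12, channels=768,
                       intermediate_dim=512, context_tokens=32, output_dim=768)
audio_embeds = torch.randn(1, 25, 5, 12, 768)      # (bz, f, window, blocks, channels)
context_tokens = model(audio_embeds)
assert context_tokens.shape == (1, 25, 32, 768)    # (bz, f, context_tokens, output_dim)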
src/models/audio/hubert.py
ADDED
@@ -0,0 +1,120 @@
|
1 |
+
from typing import Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from transformers import HubertModel
|
7 |
+
from transformers.modeling_outputs import BaseModelOutput
|
8 |
+
|
9 |
+
|
10 |
+
_CONFIG_FOR_DOC = 'HubertConfig'
|
11 |
+
|
12 |
+
|
13 |
+
def linear_interpolation(features, seq_len):
|
14 |
+
"""
|
15 |
+
Linearly interpolate the features along the time axis to the target sequence length.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
features (torch.Tensor): The extracted features to be interpolated.
|
19 |
+
seq_len (int): The target output sequence length.
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
torch.Tensor: The interpolated features.
|
23 |
+
"""
|
24 |
+
features = features.transpose(1, 2)
|
25 |
+
output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
|
26 |
+
return output_features.transpose(1, 2)
|
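For example, resampling features from the encoder's roughly 50 Hz frame rate down to a 25 fps frame count (assuming linear_interpolation above is importable from src.models.audio.hubert):

import torch
from src.models.audio.hubert import linear_interpolation

features = torch.randn(1, 100, 768)                 # (N, L, C): two seconds of 50 Hz features
resampled = linear_interpolation(features, seq_len=50)
assert resampled.shape == (1, 50, 768)              # one feature vector per 25 fps video frame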
27 |
+
|
28 |
+
|
29 |
+
class HubertModel_(HubertModel):
|
30 |
+
def __init__(self, config):
|
31 |
+
super().__init__(config)
|
32 |
+
|
33 |
+
def forward(
|
34 |
+
self,
|
35 |
+
input_values: Optional[torch.Tensor],
|
36 |
+
seq_len: Optional[int],
|
37 |
+
sample_strategy: Optional[str] = "presample",
|
38 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
39 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
40 |
+
output_attentions: Optional[bool] = None,
|
41 |
+
output_hidden_states: Optional[bool] = None,
|
42 |
+
return_dict: Optional[bool] = None,
|
43 |
+
):
|
44 |
+
"""
|
45 |
+
Forward pass of the HuBERT model.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
self: The instance of the model.
|
49 |
+
input_values: The input values (waveform) to the model.
|
50 |
+
seq_len: The sequence length of the input values.
|
51 |
+
sample_strategy: The sample strategy to align features and seq_len, supports ['presample', 'postsample'].
|
52 |
+
attention_mask: Attention mask to be used for the model.
|
53 |
+
mask_time_indices: Mask indices to be used for the model.
|
54 |
+
output_attentions: If set to True, returns attentions.
|
55 |
+
output_hidden_states: If set to True, returns hidden states.
|
56 |
+
return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
The output of the HuBERT model.
|
60 |
+
"""
|
61 |
+
# output_fps=25,
|
62 |
+
# attention_mask=None,
|
63 |
+
# output_attentions=None,
|
64 |
+
# output_hidden_states=None,
|
65 |
+
# return_dict=None,
|
66 |
+
# frame_num=None
|
67 |
+
assert sample_strategy in ["presample", "postsample"], "sample_strategy must be in ['presample', 'postsample']"
|
68 |
+
self.config.output_attentions = True
|
69 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
70 |
+
|
71 |
+
output_hidden_states = (
|
72 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
|
73 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
74 |
+
|
75 |
+
extract_features = self.feature_extractor(input_values) # (N, C, L)
|
76 |
+
extract_features = extract_features.transpose(1, 2)
|
77 |
+
if sample_strategy == "presample":
|
78 |
+
extract_features = linear_interpolation(extract_features, seq_len=seq_len)
|
79 |
+
|
80 |
+
# # Resample the audio feature @ 50 fps to `output_fps`.
|
81 |
+
# if frame_num is not None:
|
82 |
+
# extract_features_len = round(frame_num * 50 / output_fps)
|
83 |
+
# extract_features = extract_features[:, :, :extract_features_len]
|
84 |
+
# extract_features = linear_interpolation(extract_features, 50, output_fps, output_len=frame_num)
|
85 |
+
# extract_features = extract_features.transpose(1, 2) # (N, L, C)
|
86 |
+
|
87 |
+
if attention_mask is not None:
|
88 |
+
# compute reduced attention_mask corresponding to feature vectors
|
89 |
+
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
|
90 |
+
|
91 |
+
hidden_states = self.feature_projection(extract_features)
|
92 |
+
hidden_states = self._mask_hidden_states(
|
93 |
+
hidden_states,
|
94 |
+
mask_time_indices=mask_time_indices,
|
95 |
+
attention_mask=attention_mask
|
96 |
+
)
|
97 |
+
|
98 |
+
encoder_outputs = self.encoder(
|
99 |
+
hidden_states,
|
100 |
+
attention_mask=attention_mask,
|
101 |
+
output_attentions=output_attentions,
|
102 |
+
output_hidden_states=output_hidden_states,
|
103 |
+
return_dict=return_dict,
|
104 |
+
)
|
105 |
+
|
106 |
+
hidden_states = encoder_outputs[0]
|
107 |
+
|
108 |
+
if sample_strategy == "postsample":
|
109 |
+
hidden_states = linear_interpolation(hidden_states, seq_len=seq_len)
|
110 |
+
# hidden_states on the encoder output is a tuple, so rebuild it rather than assigning in place
|
111 |
+
encoder_outputs.hidden_states = tuple(linear_interpolation(h, seq_len=seq_len) for h in encoder_outputs.hidden_states)
|
112 |
+
|
113 |
+
if not return_dict:
|
114 |
+
return (hidden_states,) + encoder_outputs[1:]
|
115 |
+
|
116 |
+
return BaseModelOutput(
|
117 |
+
last_hidden_state=hidden_states,
|
118 |
+
hidden_states=encoder_outputs.hidden_states,
|
119 |
+
attentions=encoder_outputs.attentions,
|
120 |
+
)
|
src/models/audio/hubert2.py
ADDED
@@ -0,0 +1,120 @@
|
1 |
+
from typing import Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from transformers import HubertModel
|
7 |
+
from transformers.modeling_outputs import BaseModelOutput
|
8 |
+
|
9 |
+
|
10 |
+
_CONFIG_FOR_DOC = 'HubertConfig'
|
11 |
+
|
12 |
+
|
13 |
+
def linear_interpolation(features, seq_len):
|
14 |
+
"""
|
15 |
+
Linearly interpolate the features along the time axis to the target sequence length.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
features (torch.Tensor): The extracted features to be interpolated.
|
19 |
+
seq_len (int): The target output sequence length.
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
torch.Tensor: The interpolated features.
|
23 |
+
"""
|
24 |
+
features = features.transpose(1, 2)
|
25 |
+
output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
|
26 |
+
return output_features.transpose(1, 2)
|
27 |
+
|
28 |
+
|
29 |
+
class HubertModel(HubertModel):
|
30 |
+
def __init__(self, config):
|
31 |
+
super().__init__(config)
|
32 |
+
|
33 |
+
def forward(
|
34 |
+
self,
|
35 |
+
input_values: Optional[torch.Tensor],
|
36 |
+
seq_len: Optional[int],
|
37 |
+
sample_strategy: Optional[str] = "presample",
|
38 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
39 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
40 |
+
output_attentions: Optional[bool] = None,
|
41 |
+
output_hidden_states: Optional[bool] = None,
|
42 |
+
return_dict: Optional[bool] = None,
|
43 |
+
):
|
44 |
+
"""
|
45 |
+
Forward pass of the HuBERT model.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
self: The instance of the model.
|
49 |
+
input_values: The input values (waveform) to the model.
|
50 |
+
seq_len: The sequence length of the input values.
|
51 |
+
sample_strategy: The sample strategy to align features and seq_len, supports ['presample', 'postsample'].
|
52 |
+
attention_mask: Attention mask to be used for the model.
|
53 |
+
mask_time_indices: Mask indices to be used for the model.
|
54 |
+
output_attentions: If set to True, returns attentions.
|
55 |
+
output_hidden_states: If set to True, returns hidden states.
|
56 |
+
return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
The output of the HuBERT model.
|
60 |
+
"""
|
61 |
+
# output_fps=25,
|
62 |
+
# attention_mask=None,
|
63 |
+
# output_attentions=None,
|
64 |
+
# output_hidden_states=None,
|
65 |
+
# return_dict=None,
|
66 |
+
# frame_num=None
|
67 |
+
assert sample_strategy in ["presample", "postsample"], f"sample_strategy must be in ['presample', 'postsample]"
|
68 |
+
self.config.output_attentions = True
|
69 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
70 |
+
|
71 |
+
output_hidden_states = (
|
72 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
|
73 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
74 |
+
|
75 |
+
extract_features = self.feature_extractor(input_values) # (N, C, L)
|
76 |
+
extract_features = extract_features.transpose(1, 2)
|
77 |
+
if sample_strategy == "presample":
|
78 |
+
extract_features = linear_interpolation(extract_features, seq_len=seq_len)
|
79 |
+
|
80 |
+
# # Resample the audio feature @ 50 fps to `output_fps`.
|
81 |
+
# if frame_num is not None:
|
82 |
+
# extract_features_len = round(frame_num * 50 / output_fps)
|
83 |
+
# extract_features = extract_features[:, :, :extract_features_len]
|
84 |
+
# extract_features = linear_interpolation(extract_features, 50, output_fps, output_len=frame_num)
|
85 |
+
# extract_features = extract_features.transpose(1, 2) # (N, L, C)
|
86 |
+
|
87 |
+
if attention_mask is not None:
|
88 |
+
# compute reduced attention_mask corresponding to feature vectors
|
89 |
+
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
|
90 |
+
|
91 |
+
hidden_states = self.feature_projection(extract_features)
|
92 |
+
hidden_states = self._mask_hidden_states(
|
93 |
+
hidden_states,
|
94 |
+
mask_time_indices=mask_time_indices,
|
95 |
+
attention_mask=attention_mask
|
96 |
+
)
|
97 |
+
|
98 |
+
encoder_outputs = self.encoder(
|
99 |
+
hidden_states,
|
100 |
+
attention_mask=attention_mask,
|
101 |
+
output_attentions=output_attentions,
|
102 |
+
output_hidden_states=output_hidden_states,
|
103 |
+
return_dict=return_dict,
|
104 |
+
)
|
105 |
+
|
106 |
+
hidden_states = encoder_outputs[0]
|
107 |
+
|
108 |
+
if sample_strategy == "postsample":
|
109 |
+
hidden_states = linear_interpolation(hidden_states, seq_len=seq_len)
|
110 |
+
for i in range(len(encoder_outputs.hidden_states)):
|
111 |
+
encoder_outputs.hidden_states[i] = linear_interpolation(encoder_outputs.hidden_states[i], seq_len=seq_len)
|
112 |
+
|
113 |
+
if not return_dict:
|
114 |
+
return (hidden_states,) + encoder_outputs[1:]
|
115 |
+
|
116 |
+
return BaseModelOutput(
|
117 |
+
last_hidden_state=hidden_states,
|
118 |
+
hidden_states=encoder_outputs.hidden_states,
|
119 |
+
attentions=encoder_outputs.attentions,
|
120 |
+
)
|
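A minimal usage sketch for the subclassed `HubertModel` above. The checkpoint name and import path are assumptions, not taken from this diff; any checkpoint loadable by transformers' `HubertModel` is assumed to behave the same way.

import torch
from src.models.audio.hubert2 import HubertModel  # assumes the repo root is on sys.path

# Hypothetical checkpoint, chosen only for illustration.
model = HubertModel.from_pretrained("facebook/hubert-base-ls960").eval()

waveform = torch.randn(1, 16000 * 2)   # placeholder for 2 s of 16 kHz audio
num_frames = 50                        # target length: 2 s of video at 25 fps

with torch.no_grad():
    out = model(waveform, seq_len=num_frames, sample_strategy="presample",
                output_hidden_states=True)

print(out.last_hidden_state.shape)     # (1, 50, hidden_size)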
src/models/audio/wav2vec.py
ADDED
@@ -0,0 +1,210 @@
+
+
+"""
+This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
+It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
+such as feature extraction and encoding.
+
+Classes:
+    Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
+
+Functions:
+    linear_interpolation: Interpolates the features based on the sequence length.
+"""
+
+from typing import Optional, Tuple, Union
+import torch
+import torch.nn.functional as F
+from transformers import Wav2Vec2Model
+from transformers.modeling_outputs import BaseModelOutput
+
+
+class Wav2VecModel(Wav2Vec2Model):
+    """
+    Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
+    It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
+    ...
+
+    Attributes:
+        base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
+
+    Methods:
+        forward(input_values, seq_len, attention_mask=None, mask_time_indices=None,
+                output_attentions=None, output_hidden_states=None, return_dict=None):
+            Forward pass of the Wav2VecModel.
+            It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
+
+        feature_extract(input_values, seq_len):
+            Extracts features from the input_values using the base model.
+
+        encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
+            Encodes the extracted features using the base model and returns the encoded features.
+    """
+    def forward(
+        self,
+        input_values,
+        seq_len,
+        attention_mask=None,
+        mask_time_indices=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        """
+        Forward pass of the Wav2Vec model.
+
+        Args:
+            self: The instance of the model.
+            input_values: The input values (waveform) to the model.
+            seq_len: The sequence length of the input values.
+            attention_mask: Attention mask to be used for the model.
+            mask_time_indices: Mask indices to be used for the model.
+            output_attentions: If set to True, returns attentions.
+            output_hidden_states: If set to True, returns hidden states.
+            return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
+
+        Returns:
+            The output of the Wav2Vec model.
+        """
+        self.config.output_attentions = True
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        extract_features = self.feature_extractor(input_values)
+        extract_features = extract_features.transpose(1, 2)
+        extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
+
+        hidden_states, extract_features = self.feature_projection(extract_features)
+        hidden_states = self._mask_hidden_states(
+            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+        )
+
+        encoder_outputs = self.encoder(
+            hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = encoder_outputs[0]
+
+        if self.adapter is not None:
+            hidden_states = self.adapter(hidden_states)
+
+        if not return_dict:
+            return (hidden_states, ) + encoder_outputs[1:]
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+    def feature_extract(
+        self,
+        input_values,
+        seq_len,
+    ):
+        """
+        Extracts features from the input values and returns the extracted features.
+
+        Parameters:
+            input_values (torch.Tensor): The input values to be processed.
+            seq_len (torch.Tensor): The sequence lengths of the input values.
+
+        Returns:
+            extracted_features (torch.Tensor): The extracted features from the input values.
+        """
+        extract_features = self.feature_extractor(input_values)
+        extract_features = extract_features.transpose(1, 2)
+        extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+        return extract_features
+
+    def encode(
+        self,
+        extract_features,
+        attention_mask=None,
+        mask_time_indices=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        """
+        Encodes the input features into the output space.
+
+        Args:
+            extract_features (torch.Tensor): The extracted features from the audio signal.
+            attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
+            mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
+            output_attentions (bool, optional): If set to True, returns the attention weights.
+            output_hidden_states (bool, optional): If set to True, returns all hidden states.
+            return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple.
+
+        Returns:
+            The encoded output features.
+        """
+        self.config.output_attentions = True
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
+
+        hidden_states, extract_features = self.feature_projection(extract_features)
+        hidden_states = self._mask_hidden_states(
+            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+        )
+
+        encoder_outputs = self.encoder(
+            hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = encoder_outputs[0]
+
+        if self.adapter is not None:
+            hidden_states = self.adapter(hidden_states)
+
+        if not return_dict:
+            return (hidden_states, ) + encoder_outputs[1:]
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+def linear_interpolation(features, seq_len):
+    """
+    Transpose the features to interpolate linearly.
+
+    Args:
+        features (torch.Tensor): The extracted features to be interpolated.
+        seq_len (torch.Tensor): The sequence lengths of the features.
+
+    Returns:
+        torch.Tensor: The interpolated features.
+    """
+    features = features.transpose(1, 2)
+    output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
+    return output_features.transpose(1, 2)
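Besides the one-shot `forward`, this file also exposes a two-step path (`feature_extract` followed by `encode`). A minimal sketch, assuming a standard checkpoint such as `facebook/wav2vec2-base-960h` (the checkpoint name and import path are assumptions, not referenced in the diff):

import torch
from src.models.audio.wav2vec import Wav2VecModel  # assumes the repo root is on sys.path

model = Wav2VecModel.from_pretrained("facebook/wav2vec2-base-960h").eval()  # hypothetical checkpoint

waveform = torch.randn(1, 16000)        # placeholder for 1 s of 16 kHz audio
num_frames = 25                         # 1 s of video at 25 fps

with torch.no_grad():
    # Step 1: CNN features, already interpolated to `num_frames` time steps.
    feats = model.feature_extract(waveform, seq_len=num_frames)   # (1, 25, conv feature dim)
    # Step 2: project and run the transformer encoder on the pre-extracted features.
    out = model.encode(feats, output_hidden_states=True)

print(out.last_hidden_state.shape)      # (1, 25, hidden_size)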
src/models/audio/wav2vec2.py
ADDED
@@ -0,0 +1,123 @@
+from packaging import version
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import transformers
+from transformers import Wav2Vec2Model
+from transformers.modeling_outputs import BaseModelOutput
+
+_CONFIG_FOR_DOC = 'Wav2Vec2Config'
+
+
+# the implementation of Wav2Vec2Model is borrowed from
+# https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
+# initialize our encoder with the pre-trained wav2vec 2.0 weights.
+def _compute_mask_indices(shape: Tuple[int, int], mask_prob: float, mask_length: int,
+                          attention_mask: Optional[torch.Tensor] = None, min_masks: int = 0, ) -> np.ndarray:
+    bsz, all_sz = shape
+    mask = np.full((bsz, all_sz), False)
+
+    all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand())
+    all_num_mask = max(min_masks, all_num_mask)
+    mask_idcs = []
+    padding_mask = attention_mask.ne(1) if attention_mask is not None else None
+    for i in range(bsz):
+        if padding_mask is not None:
+            sz = all_sz - padding_mask[i].long().sum().item()
+            num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand())
+            num_mask = max(min_masks, num_mask)
+        else:
+            sz = all_sz
+            num_mask = all_num_mask
+
+        lengths = np.full(num_mask, mask_length)
+
+        if sum(lengths) == 0:
+            lengths[0] = min(mask_length, sz - 1)
+
+        min_len = min(lengths)
+        if sz - min_len <= num_mask:
+            min_len = sz - num_mask - 1
+
+        mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+        mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
+        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
+    min_len = min([len(m) for m in mask_idcs])
+    for i, mask_idc in enumerate(mask_idcs):
+        if len(mask_idc) > min_len:
+            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+        mask[i, mask_idc] = True
+    return mask
+
+
+# linear interpolation layer
+def linear_interpolation(features, input_fps, output_fps, output_len=None):
+    # features: (N, C, L)
+    seq_len = features.shape[2] / float(input_fps)
+    if output_len is None:
+        output_len = int(seq_len * output_fps)
+    output_features = F.interpolate(features, size=output_len, align_corners=False, mode='linear')
+    return output_features
+
+
+class Wav2Vec2Model(Wav2Vec2Model):
+    def __init__(self, config):
+        super().__init__(config)
+        self.is_old_version = version.parse(transformers.__version__) < version.parse('4.7.0')
+
+    def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
+                output_hidden_states=None, return_dict=None, frame_num=None):
+        self.config.output_attentions = True
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        print(f"data shape before feature extractor: {input_values.shape}")
+        hidden_states = self.feature_extractor(input_values)  # (N, C, L)
+        print(f"data shape after feature extractor: {hidden_states.shape}")
+        # Resample the audio feature @ 50 fps to `output_fps`.
+        if frame_num is not None:
+            hidden_states_len = round(frame_num * 50 / output_fps)
+            hidden_states = hidden_states[:, :, :hidden_states_len]
+        hidden_states = linear_interpolation(hidden_states, 50, output_fps, output_len=frame_num)
+        hidden_states = hidden_states.transpose(1, 2)  # (N, L, C)
+
+        if attention_mask is not None:
+            output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
+            attention_mask = torch.zeros(hidden_states.shape[:2], dtype=hidden_states.dtype,
+                                         device=hidden_states.device)
+            attention_mask[(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)] = 1
+            attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+
+        if self.is_old_version:
+            hidden_states = self.feature_projection(hidden_states)
+        else:
+            hidden_states = self.feature_projection(hidden_states)[0]
+
+        if self.config.apply_spec_augment and self.training:
+            batch_size, sequence_length, hidden_size = hidden_states.size()
+            if self.config.mask_time_prob > 0:
+                mask_time_indices = _compute_mask_indices((batch_size, sequence_length), self.config.mask_time_prob,
+                                                          self.config.mask_time_length, attention_mask=attention_mask,
+                                                          min_masks=2, )
+                hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
+            if self.config.mask_feature_prob > 0:
+                mask_feature_indices = _compute_mask_indices((batch_size, hidden_size), self.config.mask_feature_prob,
+                                                             self.config.mask_feature_length, )
+                mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
+                hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
+        encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask,
+                                       output_attentions=output_attentions, output_hidden_states=output_hidden_states,
+                                       return_dict=return_dict, )
+        hidden_states = encoder_outputs[0]
+        if not return_dict:
+            return (hidden_states,) + encoder_outputs[1:]
+
+        for i in range(len(encoder_outputs.hidden_states)):
+            print(f"hidden states {i} after encoder: {encoder_outputs.hidden_states[i].shape}")
+
+        return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
+                               attentions=encoder_outputs.attentions, )
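The 50 Hz-to-`output_fps` resampling in this variant can be checked by hand: `frame_num` video frames at `output_fps` cover `frame_num / output_fps` seconds, which the feature extractor represents at roughly 50 frames per second, so the slice length is `round(frame_num * 50 / output_fps)` before interpolating back down to `frame_num`. A worked example with illustrative numbers:

# 100 video frames at 25 fps = 4 s of audio/video.
frame_num, output_fps = 100, 25
hidden_states_len = round(frame_num * 50 / output_fps)
print(hidden_states_len)   # 200 feature frames at ~50 Hz
# linear_interpolation(hidden_states, 50, 25, output_len=100) then maps 200 -> 100 time steps.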
src/models/audio/wav2vec_modified.py
ADDED
@@ -0,0 +1,223 @@
+
+
+"""
+This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
+It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
+such as feature extraction and encoding.
+
+Classes:
+    Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
+
+Functions:
+    linear_interpolation: Interpolates the features based on the sequence length.
+"""
+
+from typing import Optional, Tuple, Union
+import torch
+
+import torch.nn.functional as F
+from transformers import Wav2Vec2Model
+from transformers.modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput
+
+
+class Wav2VecModel(Wav2Vec2Model):
+    """
+    Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
+    It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
+    ...
+
+    Attributes:
+        base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
+
+    Methods:
+        forward(input_values, seq_len, attention_mask=None, mask_time_indices=None,
+                output_attentions=None, output_hidden_states=None, return_dict=None):
+            Forward pass of the Wav2VecModel.
+            It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
+
+        feature_extract(input_values, seq_len):
+            Extracts features from the input_values using the base model.
+
+        encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
+            Encodes the extracted features using the base model and returns the encoded features.
+    """
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        seq_len: Optional[int],
+        sample_strategy: Optional[str] = "presample",
+        attention_mask: Optional[torch.LongTensor] = None,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+        """
+        Forward pass of the Wav2Vec model.
+
+        Args:
+            self: The instance of the model.
+            input_values: The input values (waveform) to the model.
+            seq_len: The sequence length of the input values.
+            sample_strategy: The sample strategy to align features and seq_len, supports ['presample', 'postsample'].
+            attention_mask: Attention mask to be used for the model.
+            mask_time_indices: Mask indices to be used for the model.
+            output_attentions: If set to True, returns attentions.
+            output_hidden_states: If set to True, returns hidden states.
+            return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
+
+        Returns:
+            The output of the Wav2Vec model.
+        """
+        assert sample_strategy in ["presample", "postsample"], "sample_strategy must be in ['presample', 'postsample']"
+
+        self.config.output_attentions = True
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        extract_features = self.feature_extractor(input_values)
+        extract_features = extract_features.transpose(1, 2)
+        if sample_strategy == "presample":
+            extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
+
+        hidden_states, extract_features = self.feature_projection(extract_features)
+        hidden_states = self._mask_hidden_states(
+            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+        )
+
+        encoder_outputs = self.encoder(
+            hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = encoder_outputs[0]
+
+        if self.adapter is not None:
+            hidden_states = self.adapter(hidden_states)
+
+        if sample_strategy == "postsample":
+            hidden_states = linear_interpolation(hidden_states, seq_len=seq_len)
+            for i in range(len(encoder_outputs.hidden_states)):
+                encoder_outputs.hidden_states[i] = linear_interpolation(encoder_outputs.hidden_states[i], seq_len=seq_len)
+
+        if not return_dict:
+            return (hidden_states, extract_features) + encoder_outputs[1:]
+
+        return Wav2Vec2BaseModelOutput(
+            last_hidden_state=hidden_states,
+            extract_features=extract_features,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+    def feature_extract(
+        self,
+        input_values,
+        seq_len,
+    ):
+        """
+        Extracts features from the input values and returns the extracted features.
+
+        Parameters:
+            input_values (torch.Tensor): The input values to be processed.
+            seq_len (torch.Tensor): The sequence lengths of the input values.
+
+        Returns:
+            extracted_features (torch.Tensor): The extracted features from the input values.
+        """
+        extract_features = self.feature_extractor(input_values)
+        extract_features = extract_features.transpose(1, 2)
+        extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+        return extract_features
+
+    def encode(
+        self,
+        extract_features,
+        attention_mask=None,
+        mask_time_indices=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        """
+        Encodes the input features into the output space.
+
+        Args:
+            extract_features (torch.Tensor): The extracted features from the audio signal.
+            attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
+            mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
+            output_attentions (bool, optional): If set to True, returns the attention weights.
+            output_hidden_states (bool, optional): If set to True, returns all hidden states.
+            return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple.
+
+        Returns:
+            The encoded output features.
+        """
+        self.config.output_attentions = True
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
+
+        hidden_states, extract_features = self.feature_projection(extract_features)
+        hidden_states = self._mask_hidden_states(
+            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+        )
+
+        encoder_outputs = self.encoder(
+            hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = encoder_outputs[0]
+
+        if self.adapter is not None:
+            hidden_states = self.adapter(hidden_states)
+
+        if not return_dict:
+            return (hidden_states, ) + encoder_outputs[1:]
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+def linear_interpolation(features, seq_len):
+    """
+    Transpose the features to interpolate linearly.
+
+    Args:
+        features (torch.Tensor): The extracted features to be interpolated.
+        seq_len (torch.Tensor): The sequence lengths of the features.
+
+    Returns:
+        torch.Tensor: The interpolated features.
+    """
+    features = features.transpose(1, 2)
+    output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
+    return output_features.transpose(1, 2)
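A minimal sketch of the default "presample" path in this modified variant, which returns a `Wav2Vec2BaseModelOutput` so the projected CNN features are exposed alongside the encoder output. The checkpoint name and import path are assumptions, not taken from this diff.

import torch
from src.models.audio.wav2vec_modified import Wav2VecModel  # assumes the repo root is on sys.path

model = Wav2VecModel.from_pretrained("facebook/wav2vec2-base-960h").eval()  # hypothetical checkpoint

waveform = torch.randn(1, 16000 * 2)    # placeholder for 2 s of 16 kHz audio
num_frames = 50                         # 2 s of video at 25 fps

with torch.no_grad():
    out = model(waveform, seq_len=num_frames, sample_strategy="presample",
                output_hidden_states=True)

print(out.last_hidden_state.shape)      # (1, 50, hidden_size)
print(out.extract_features.shape)       # (1, 50, conv feature dim): the normalized CNN features
print(len(out.hidden_states))           # one entry per encoder layer, plus the embedding output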