Delete main
This view is limited to 50 files because it contains too many changes.
- main/app/app.py +0 -0
- main/app/parser.py +0 -340
- main/app/tensorboard.py +0 -30
- main/configs/config.json +0 -547
- main/configs/config.py +0 -90
- main/configs/decrypt.bin +0 -3
- main/configs/v1/32000.json +0 -46
- main/configs/v1/40000.json +0 -46
- main/configs/v1/48000.json +0 -46
- main/configs/v2/32000.json +0 -42
- main/configs/v2/40000.json +0 -42
- main/configs/v2/48000.json +0 -42
- main/inference/audio_effects.py +0 -180
- main/inference/audioldm2.py +0 -210
- main/inference/convert.py +0 -590
- main/inference/create_dataset.py +0 -230
- main/inference/create_index.py +0 -90
- main/inference/extract.py +0 -360
- main/inference/preprocess.py +0 -270
- main/inference/separator_music.py +0 -310
- main/inference/train.py +0 -990
- main/library/algorithm/commons.py +0 -60
- main/library/algorithm/modules.py +0 -60
- main/library/algorithm/mrf_hifigan.py +0 -150
- main/library/algorithm/onnx_export.py +0 -50
- main/library/algorithm/refinegan.py +0 -170
- main/library/algorithm/residuals.py +0 -140
- main/library/algorithm/separator.py +0 -320
- main/library/algorithm/stftpitchshift.py +0 -250
- main/library/algorithm/synthesizers.py +0 -490
- main/library/architectures/demucs_separator.py +0 -180
- main/library/architectures/fairseq.py +0 -1480
- main/library/architectures/mdx_separator.py +0 -320
- main/library/audioldm2/models.py +0 -330
- main/library/audioldm2/utils.py +0 -40
- main/library/predictors/CREPE.py +0 -210
- main/library/predictors/FCPE.py +0 -1097
- main/library/predictors/RMVPE.py +0 -260
- main/library/predictors/SWIPE.py +0 -140
- main/library/predictors/WORLD_WRAPPER.py +0 -90
- main/library/speaker_diarization/ECAPA_TDNN.py +0 -280
- main/library/speaker_diarization/audio.py +0 -170
- main/library/speaker_diarization/embedding.py +0 -90
- main/library/speaker_diarization/encoder.py +0 -250
- main/library/speaker_diarization/features.py +0 -520
- main/library/speaker_diarization/parameter_transfer.py +0 -120
- main/library/speaker_diarization/segment.py +0 -540
- main/library/speaker_diarization/speechbrain.py +0 -220
- main/library/speaker_diarization/whisper.py +0 -1290
- main/library/utils.py +0 -240
main/app/app.py
DELETED
The diff for this file is too large to render.
main/app/parser.py
DELETED
@@ -1,340 +0,0 @@
import os
import sys

sys.path.append(os.getcwd())

try:
    argv = sys.argv[1]
except IndexError:
    argv = None

argv_is_allows = ["--audio_effects", "--audioldm2", "--convert", "--create_dataset", "--create_index", "--extract", "--preprocess", "--separator_music", "--train", "--help_audio_effects", "--help_audioldm2", "--help_convert", "--help_create_dataset", "--help_create_index", "--help_extract", "--help_preprocess", "--help_separator_music", "--help_train", "--help"]

if argv not in argv_is_allows:
    print("Invalid syntax! Use --help for more information")
    quit()

if argv_is_allows[0] in argv: from main.inference.audio_effects import main
elif argv_is_allows[1] in argv: from main.inference.audioldm2 import main
elif argv_is_allows[2] in argv: from main.inference.convert import main
elif argv_is_allows[3] in argv: from main.inference.create_dataset import main
elif argv_is_allows[4] in argv: from main.inference.create_index import main
elif argv_is_allows[5] in argv: from main.inference.extract import main
elif argv_is_allows[6] in argv: from main.inference.preprocess import main
elif argv_is_allows[7] in argv: from main.inference.separator_music import main
elif argv_is_allows[8] in argv: from main.inference.train import main
elif argv_is_allows[9] in argv:
    print("""Parameters for `--audio_effects`:
        1. File paths:
            - `--input_path` (required): Path to the input audio file.
            - `--output_path` (default: `./audios/apply_effects.wav`): Path to save the output file.
            - `--export_format` (default: `wav`): Export format (`wav`, `mp3`, ...).

        2. Resampling:
            - `--resample` (default: `False`): Whether to resample.
            - `--resample_sr` (default: `0`): New sample rate (Hz).

        3. Chorus effect:
            - `--chorus`: Enable/disable chorus.
            - `--chorus_depth`, `--chorus_rate`, `--chorus_mix`, `--chorus_delay`, `--chorus_feedback`: Chorus tuning parameters.

        4. Distortion effect:
            - `--distortion`: Enable/disable distortion.
            - `--drive_db`: Amount of distortion.

        5. Reverb effect:
            - `--reverb`: Enable/disable reverb.
            - `--reverb_room_size`, `--reverb_damping`, `--reverb_wet_level`, `--reverb_dry_level`, `--reverb_width`, `--reverb_freeze_mode`: Reverb tuning parameters.

        6. Pitch shift effect:
            - `--pitchshift`: Enable/disable pitch shifting.
            - `--pitch_shift`: Pitch shift amount.

        7. Delay effect:
            - `--delay`: Enable/disable delay.
            - `--delay_seconds`, `--delay_feedback`, `--delay_mix`: Delay time, feedback and mix.

        8. Compressor:
            - `--compressor`: Enable/disable the compressor.
            - `--compressor_threshold`, `--compressor_ratio`, `--compressor_attack_ms`, `--compressor_release_ms`: Compression parameters.

        9. Limiter:
            - `--limiter`: Enable/disable the level limiter.
            - `--limiter_threshold`, `--limiter_release`: Limiter threshold and release time.

        10. Gain:
            - `--gain`: Enable/disable gain.
            - `--gain_db`: Gain amount (dB).

        11. Bitcrush:
            - `--bitcrush`: Enable/disable the bit-depth reduction effect.
            - `--bitcrush_bit_depth`: Bitcrush bit depth.

        12. Clipping:
            - `--clipping`: Enable/disable clipping.
            - `--clipping_threshold`: Clipping threshold.

        13. Phaser:
            - `--phaser`: Enable/disable the phaser effect.
            - `--phaser_rate_hz`, `--phaser_depth`, `--phaser_centre_frequency_hz`, `--phaser_feedback`, `--phaser_mix`: Phaser tuning parameters.

        14. Bass & treble boost:
            - `--treble_bass_boost`: Enable/disable bass and treble boosting.
            - `--bass_boost_db`, `--bass_boost_frequency`, `--treble_boost_db`, `--treble_boost_frequency`: Bass and treble boost parameters.

        15. Fade in & fade out:
            - `--fade_in_out`: Enable/disable the fade effect.
            - `--fade_in_duration`, `--fade_out_duration`: Fade-in/fade-out durations.

        16. Audio combination:
            - `--audio_combination`: Enable/disable combining multiple audio files.
            - `--audio_combination_input`: Path to the additional audio file.
    """)
    quit()
elif argv_is_allows[10] in argv:
    print("""Parameters for --audioldm2:
        1. File paths:
            - `--input_path` (required): Path to the input audio file.
            - `--output_path` (default: `./output.wav`): Path to save the output file.
            - `--export_format` (default: `wav`): Export format.

        2. Audio configuration:
            - `--sample_rate` (default: `44100`): Sample rate (Hz).

        3. AudioLDM model configuration:
            - `--audioldm_model` (default: `audioldm2-music`): AudioLDM model to use.

        4. Model guidance prompts:
            - `--source_prompt` (default: ``): Description of the source audio.
            - `--target_prompt` (default: ``): Description of the target audio.

        5. Processing algorithm configuration:
            - `--steps` (default: `200`): Number of steps in the audio synthesis process.
            - `--cfg_scale_src` (default: `3.5`): Guidance scale for the source audio.
            - `--cfg_scale_tar` (default: `12`): Guidance scale for the target audio.
            - `--t_start` (default: `45`): Editing strength.

        6. Compute optimization:
            - `--save_compute` (default: `False`): Whether to enable compute-saving mode.
    """)
    quit()
elif argv_is_allows[11] in argv:
    print("""Parameters for --convert:
        1. Voice processing configuration:
            - `--pitch` (default: `0`): Pitch adjustment.
            - `--filter_radius` (default: `3`): Smoothness of the F0 curve.
            - `--index_rate` (default: `0.5`): Ratio of the voice index to use.
            - `--volume_envelope` (default: `1`): Volume envelope scaling factor.
            - `--protect` (default: `0.33`): Consonant protection.

        2. Frame hop configuration:
            - `--hop_length` (default: `64`): Hop length used during processing.

        3. F0 configuration:
            - `--f0_method` (default: `rmvpe`): F0 prediction method (`pm`, `dio`, `mangio-crepe-tiny`, `mangio-crepe-small`, `mangio-crepe-medium`, `mangio-crepe-large`, `mangio-crepe-full`, `crepe-tiny`, `crepe-small`, `crepe-medium`, `crepe-large`, `crepe-full`, `fcpe`, `fcpe-legacy`, `rmvpe`, `rmvpe-legacy`, `harvest`, `yin`, `pyin`, `swipe`).
            - `--f0_autotune` (default: `False`): Whether to automatically tune F0.
            - `--f0_autotune_strength` (default: `1`): Strength of the automatic F0 correction.
            - `--f0_file` (default: ``): Path to an existing F0 file.
            - `--f0_onnx` (default: `False`): Whether to use the ONNX version of F0.

        4. Embedding model:
            - `--embedder_model` (default: `contentvec_base`): Embedding model to use.
            - `--embedders_mode` (default: `fairseq`): Embedding mode (`fairseq`, `transformers`, `onnx`).

        5. File paths:
            - `--input_path` (required): Path to the input audio file.
            - `--output_path` (default: `./audios/output.wav`): Path to save the output file.
            - `--export_format` (default: `wav`): Export format.
            - `--pth_path` (required): Path to the `.pth` model file.
            - `--index_path` (default: `None`): Path to the index file (if any).

        6. Audio cleaning:
            - `--clean_audio` (default: `False`): Whether to apply audio cleaning.
            - `--clean_strength` (default: `0.7`): Cleaning strength.

        7. Resampling & audio splitting:
            - `--resample_sr` (default: `0`): New sample rate (0 means keep the original).
            - `--split_audio` (default: `False`): Whether to split the audio before processing.

        8. Checkpointing & optimization:
            - `--checkpointing` (default: `False`): Enable/disable checkpointing to save RAM.

        9. Formant shifting:
            - `--formant_shifting` (default: `False`): Whether to enable the formant shifting effect.
            - `--formant_qfrency` (default: `0.8`): Formant frequency shift factor.
            - `--formant_timbre` (default: `0.8`): Voice timbre change factor.
    """)
    quit()
elif argv_is_allows[12] in argv:
    print("""Parameters for --create_dataset:
        1. Dataset paths & configuration:
            - `--input_audio` (required): Link(s) to the audio (YouTube links; use `,` to separate multiple links).
            - `--output_dataset` (default: `./dataset`): Output dataset directory.
            - `--sample_rate` (default: `44100`): Sample rate for the audio.

        2. Data cleaning:
            - `--clean_dataset` (default: `False`): Whether to apply data cleaning.
            - `--clean_strength` (default: `0.7`): Data cleaning strength.

        3. Vocal separation & effects:
            - `--separator_reverb` (default: `False`): Whether to separate vocal reverb.
            - `--kim_vocal_version` (default: `2`): Kim Vocal model version used for separation (`1`, `2`).

        4. Audio segmentation configuration:
            - `--overlap` (default: `0.25`): Overlap between segments during separation.
            - `--segments_size` (default: `256`): Size of each segment.

        5. MDX (Music Demixing) configuration:
            - `--mdx_hop_length` (default: `1024`): MDX hop length during processing.
            - `--mdx_batch_size` (default: `1`): Batch size for MDX processing.
            - `--denoise_mdx` (default: `False`): Whether to apply denoising when separating with MDX.

        6. Skipping audio:
            - `--skip` (default: `False`): Whether to skip any seconds of audio.
            - `--skip_start_audios` (default: `0`): Seconds to skip at the start of the audio.
            - `--skip_end_audios` (default: `0`): Seconds to skip at the end of the audio.
    """)
    quit()
elif argv_is_allows[13] in argv:
    print("""Parameters for --create_index:
        1. Model information:
            - `--model_name` (required): Model name.
            - `--rvc_version` (default: `v2`): Version (`v1`, `v2`).
            - `--index_algorithm` (default: `Auto`): Index algorithm to use (`Auto`, `Faiss`, `KMeans`).
    """)
    quit()
elif argv_is_allows[14] in argv:
    print("""Parameters for --extract:
        1. Model information:
            - `--model_name` (required): Model name.
            - `--rvc_version` (default: `v2`): RVC version (`v1`, `v2`).

        2. F0 configuration:
            - `--f0_method` (default: `rmvpe`): F0 prediction method (`pm`, `dio`, `mangio-crepe-tiny`, `mangio-crepe-small`, `mangio-crepe-medium`, `mangio-crepe-large`, `mangio-crepe-full`, `crepe-tiny`, `crepe-small`, `crepe-medium`, `crepe-large`, `crepe-full`, `fcpe`, `fcpe-legacy`, `rmvpe`, `rmvpe-legacy`, `harvest`, `yin`, `pyin`, `swipe`).
            - `--pitch_guidance` (default: `True`): Whether to use pitch guidance.

        3. Processing configuration:
            - `--hop_length` (default: `128`): Hop length during processing.
            - `--cpu_cores` (default: `2`): Number of CPU threads to use.
            - `--gpu` (default: `-`): GPU to use (e.g. `0` for the first GPU, `-` to disable GPU).
            - `--sample_rate` (required): Sample rate of the input audio.

        4. Embedding configuration:
            - `--embedder_model` (default: `contentvec_base`): Embedding model name.
            - `--f0_onnx` (default: `False`): Whether to use the ONNX version of F0.
            - `--embedders_mode` (default: `fairseq`): Embedding mode (`fairseq`, `transformers`, `onnx`).
    """)
    quit()
elif argv_is_allows[15] in argv:
    print("""Parameters for --preprocess:
        1. Model information:
            - `--model_name` (required): Model name.

        2. Data configuration:
            - `--dataset_path` (default: `./dataset`): Path to the directory containing the data files.
            - `--sample_rate` (required): Sample rate of the audio data.

        3. Processing configuration:
            - `--cpu_cores` (default: `2`): Number of CPU threads to use.
            - `--cut_preprocess` (default: `True`): Whether to cut the data files.
            - `--process_effects` (default: `False`): Whether to apply preprocessing.
            - `--clean_dataset` (default: `False`): Whether to clean the data files.
            - `--clean_strength` (default: `0.7`): Strength of the data cleaning.
    """)
    quit()
elif argv_is_allows[16] in argv:
    print("""Parameters for --separator_music:
        1. Data paths:
            - `--input_path` (required): Path to the input audio file.
            - `--output_path` (default: `./audios`): Directory to save the output files.
            - `--format` (default: `wav`): Export format (`wav`, `mp3`, ...).

        2. Audio processing configuration:
            - `--shifts` (default: `2`): Number of predictions.
            - `--segments_size` (default: `256`): Audio segment size.
            - `--overlap` (default: `0.25`): Overlap between segments.
            - `--mdx_hop_length` (default: `1024`): MDX hop length during processing.
            - `--mdx_batch_size` (default: `1`): Batch size.

        3. Cleaning:
            - `--clean_audio` (default: `False`): Whether to clean the audio.
            - `--clean_strength` (default: `0.7`): Strength of the cleaning filter.

        4. Model configuration:
            - `--model_name` (default: `HT-Normal`): Music separation model (`Main_340`, `Main_390`, `Main_406`, `Main_427`, `Main_438`, `Inst_full_292`, `Inst_HQ_1`, `Inst_HQ_2`, `Inst_HQ_3`, `Inst_HQ_4`, `Inst_HQ_5`, `Kim_Vocal_1`, `Kim_Vocal_2`, `Kim_Inst`, `Inst_187_beta`, `Inst_82_beta`, `Inst_90_beta`, `Voc_FT`, `Crowd_HQ`, `Inst_1`, `Inst_2`, `Inst_3`, `MDXNET_1_9703`, `MDXNET_2_9682`, `MDXNET_3_9662`, `Inst_Main`, `MDXNET_Main`, `MDXNET_9482`, `HT-Normal`, `HT-Tuned`, `HD_MMI`, `HT_6S`).
            - `--kara_model` (default: `Version-1`): Backing-vocal separation model version (`Version-1`, `Version-2`).

        5. Effects and post-processing:
            - `--backing` (default: `False`): Whether to separate backing vocals.
            - `--mdx_denoise` (default: `False`): Whether to use MDX denoising.
            - `--reverb` (default: `False`): Whether to separate reverb.
            - `--backing_reverb` (default: `False`): Whether to separate reverb for the backing vocals.

        6. Sample rate:
            - `--sample_rate` (default: `44100`): Sample rate of the output audio.
    """)
    quit()
elif argv_is_allows[17] in argv:
    print("""Parameters for --train:
        1. Model configuration:
            - `--model_name` (required): Model name.
            - `--rvc_version` (default: `v2`): RVC version (`v1`, `v2`).
            - `--model_author` (optional): Author of the model.

        2. Saving configuration:
            - `--save_every_epoch` (required): Number of epochs between saves.
            - `--save_only_latest` (default: `True`): Only keep the latest checkpoint.
            - `--save_every_weights` (default: `True`): Save all model weights.

        3. Training configuration:
            - `--total_epoch` (default: `300`): Total number of training epochs.
            - `--batch_size` (default: `8`): Batch size during training.
            - `--sample_rate` (required): Sample rate of the audio.

        4. Device configuration:
            - `--gpu` (default: `0`): GPU to use (a number, or `-` to disable GPU).
            - `--cache_data_in_gpu` (default: `False`): Cache data on the GPU to speed up training.

        5. Advanced training configuration:
            - `--pitch_guidance` (default: `True`): Use pitch guidance.
            - `--g_pretrained_path` (default: ``): Path to the pretrained G weights.
            - `--d_pretrained_path` (default: ``): Path to the pretrained D weights.
            - `--vocoder` (default: `Default`): Vocoder to use (`Default`, `MRF-HiFi-GAN`, `RefineGAN`).

        6. Overtraining detection:
            - `--overtraining_detector` (default: `False`): Enable/disable overtraining detection.
            - `--overtraining_threshold` (default: `50`): Threshold for deciding that the model is overtrained.

        7. Data handling:
            - `--cleanup` (default: `False`): Clean up old training files to retrain from scratch.

        8. Optimization:
            - `--checkpointing` (default: `False`): Enable/disable checkpointing to save RAM.
            - `--deterministic` (default: `False`): When enabled, uses deterministic algorithms so the same input always produces the same result.
            - `--benchmark` (default: `False`): When enabled, benchmarks and selects the most optimal algorithm for the specific hardware and sizes.
    """)
    quit()
elif argv_is_allows[18] in argv:
    print("""Usage:
        1. `--help_audio_effects`: Help for the audio effects section.
        2. `--help_audioldm2`: Help for the music editing section.
        3. `--help_convert`: Help for audio conversion.
        4. `--help_create_dataset`: Help for creating training data.
        5. `--help_create_index`: Help for creating an index.
        6. `--help_extract`: Help for extracting training data.
        7. `--help_preprocess`: Help for data preprocessing.
        8. `--help_separator_music`: Help for music separation.
        9. `--help_train`: Help for model training.
    """)
    quit()


if __name__ == "__main__":
    if "--train" in argv:
        import torch.multiprocessing as mp
        mp.set_start_method("spawn")

    try:
        main()
    except:
        pass
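For orientation, this deleted entry point dispatched on the first CLI flag and left all remaining flags to the selected pipeline's own argparse parser (the parameters described in the help text above). A minimal, hypothetical sketch of how it would have been driven, assuming the repository root as the working directory; the file paths below are placeholders, not values from the diff:

import subprocess
import sys

# Hypothetical invocation of the deleted dispatcher from the repository root.
# "--convert" selects main.inference.convert; the remaining flags are parsed there.
subprocess.run([
    sys.executable, "main/app/parser.py",
    "--convert",
    "--input_path", "./audios/input.wav",    # placeholder input file
    "--pth_path", "./model.pth",             # required by --convert per the help text
    "--output_path", "./audios/output.wav",
])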
main/app/tensorboard.py
DELETED
@@ -1,30 +0,0 @@
import os
import sys
import json
import logging
import webbrowser

from tensorboard import program

sys.path.append(os.getcwd())

from main.configs.config import Config
translations = Config().translations

with open(os.path.join("main", "configs", "config.json"), "r") as f:
    configs = json.load(f)

def launch_tensorboard():
    for l in ["root", "tensorboard"]:
        logging.getLogger(l).setLevel(logging.ERROR)

    tb = program.TensorBoard()
    tb.configure(argv=[None, "--logdir", "assets/logs", f"--port={configs['tensorboard_port']}"])
    url = tb.launch()

    print(f"{translations['tensorboard_url']}: {url}")
    if "--open" in sys.argv: webbrowser.open(url)

    return f"{translations['tensorboard_url']}: {url}"

if __name__ == "__main__": launch_tensorboard()
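A small usage sketch for the deleted launcher, assuming it is run from the repository root so that `main/configs/config.json`, `assets/languages/*.json` and `assets/logs` resolve; `--open` is the flag the script itself checks for:

# Hypothetical: equivalent to `python main/app/tensorboard.py --open`.
import sys

sys.argv.append("--open")                     # launch_tensorboard() checks sys.argv for this flag
from main.app.tensorboard import launch_tensorboard

print(launch_tensorboard())                   # serves assets/logs on the configured tensorboard_port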
main/configs/config.json
DELETED
@@ -1,547 +0,0 @@
{
    "language": "vi-VN",
    "support_language": ["en-US", "vi-VN"],
    "theme": "NoCrypt/miku",
    "themes": [
        "NoCrypt/miku", "gstaff/xkcd", "JohnSmith9982/small_and_pretty", "ParityError/Interstellar", "earneleh/paris",
        "shivi/calm_seafoam", "Hev832/Applio", "YTheme/Minecraft", "gstaff/sketch", "SebastianBravo/simci_css",
        "allenai/gradio-theme", "Nymbo/Nymbo_Theme_5", "lone17/kotaemon", "Zarkel/IBM_Carbon_Theme", "SherlockRamos/Feliz",
        "freddyaboulton/dracula_revamped", "freddyaboulton/bad-theme-space", "gradio/dracula_revamped", "abidlabs/dracula_revamped", "gradio/dracula_test",
        "gradio/seafoam", "gradio/glass", "gradio/monochrome", "gradio/soft", "gradio/default",
        "gradio/base", "abidlabs/pakistan", "dawood/microsoft_windows", "ysharma/steampunk", "ysharma/huggingface",
        "abidlabs/Lime", "freddyaboulton/this-theme-does-not-exist-2", "aliabid94/new-theme", "aliabid94/test2", "aliabid94/test3",
        "aliabid94/test4", "abidlabs/banana", "freddyaboulton/test-blue", "gstaff/whiteboard", "ysharma/llamas",
        "abidlabs/font-test", "YenLai/Superhuman", "bethecloud/storj_theme", "sudeepshouche/minimalist", "knotdgaf/gradiotest",
        "ParityError/Anime", "Ajaxon6255/Emerald_Isle", "ParityError/LimeFace", "finlaymacklon/smooth_slate", "finlaymacklon/boxy_violet",
        "derekzen/stardust", "EveryPizza/Cartoony-Gradio-Theme", "Ifeanyi/Cyanister", "Tshackelton/IBMPlex-DenseReadable", "snehilsanyal/scikit-learn",
        "Himhimhim/xkcd", "nota-ai/theme", "rawrsor1/Everforest", "rottenlittlecreature/Moon_Goblin", "abidlabs/test-yellow",
        "abidlabs/test-yellow3", "idspicQstitho/dracula_revamped", "kfahn/AnimalPose", "HaleyCH/HaleyCH_Theme", "simulKitke/dracula_test",
        "braintacles/CrimsonNight", "wentaohe/whiteboardv2", "reilnuud/polite", "remilia/Ghostly", "Franklisi/darkmode",
        "coding-alt/soft", "xiaobaiyuan/theme_land", "step-3-profit/Midnight-Deep", "xiaobaiyuan/theme_demo", "Taithrah/Minimal",
        "Insuz/SimpleIndigo", "zkunn/Alipay_Gradio_theme", "Insuz/Mocha", "xiaobaiyuan/theme_brief", "Ama434/434-base-Barlow",
        "Ama434/def_barlow", "Ama434/neutral-barlow", "dawood/dracula_test", "nuttea/Softblue", "BlueDancer/Alien_Diffusion",
        "naughtondale/monochrome", "Dagfinn1962/standard", "default"
    ],
    "mdx_model": [
        "Main_340", "Main_390", "Main_406", "Main_427", "Main_438", "Inst_full_292",
        "Inst_HQ_1", "Inst_HQ_2", "Inst_HQ_3", "Inst_HQ_4", "Inst_HQ_5", "Kim_Vocal_1",
        "Kim_Vocal_2", "Kim_Inst", "Inst_187_beta", "Inst_82_beta", "Inst_90_beta", "Voc_FT",
        "Crowd_HQ", "Inst_1", "Inst_2", "Inst_3", "MDXNET_1_9703", "MDXNET_2_9682",
        "MDXNET_3_9662", "Inst_Main", "MDXNET_Main", "MDXNET_9482"
    ],
    "demucs_model": ["HT-Normal", "HT-Tuned", "HD_MMI", "HT_6S"],
    "edge_tts": [
        "af-ZA-AdriNeural", "af-ZA-WillemNeural", "sq-AL-AnilaNeural", "sq-AL-IlirNeural", "am-ET-AmehaNeural", "am-ET-MekdesNeural",
        "ar-DZ-AminaNeural", "ar-DZ-IsmaelNeural", "ar-BH-AliNeural", "ar-BH-LailaNeural", "ar-EG-SalmaNeural", "ar-EG-ShakirNeural",
        "ar-IQ-BasselNeural", "ar-IQ-RanaNeural", "ar-JO-SanaNeural", "ar-JO-TaimNeural", "ar-KW-FahedNeural", "ar-KW-NouraNeural",
        "ar-LB-LaylaNeural", "ar-LB-RamiNeural", "ar-LY-ImanNeural", "ar-LY-OmarNeural", "ar-MA-JamalNeural", "ar-MA-MounaNeural",
        "ar-OM-AbdullahNeural", "ar-OM-AyshaNeural", "ar-QA-AmalNeural", "ar-QA-MoazNeural", "ar-SA-HamedNeural", "ar-SA-ZariyahNeural",
        "ar-SY-AmanyNeural", "ar-SY-LaithNeural", "ar-TN-HediNeural", "ar-TN-ReemNeural", "ar-AE-FatimaNeural", "ar-AE-HamdanNeural",
        "ar-YE-MaryamNeural", "ar-YE-SalehNeural", "az-AZ-BabekNeural", "az-AZ-BanuNeural", "bn-BD-NabanitaNeural", "bn-BD-PradeepNeural",
        "bn-IN-BashkarNeural", "bn-IN-TanishaaNeural", "bs-BA-GoranNeural", "bs-BA-VesnaNeural", "bg-BG-BorislavNeural", "bg-BG-KalinaNeural",
        "my-MM-NilarNeural", "my-MM-ThihaNeural", "ca-ES-EnricNeural", "ca-ES-JoanaNeural", "zh-HK-HiuGaaiNeural", "zh-HK-HiuMaanNeural",
        "zh-HK-WanLungNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural", "zh-CN-YunjianNeural", "zh-CN-YunxiNeural", "zh-CN-YunxiaNeural",
        "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-TW-HsiaoChenNeural", "zh-TW-YunJheNeural", "zh-TW-HsiaoYuNeural", "zh-CN-shaanxi-XiaoniNeural",
        "hr-HR-GabrijelaNeural", "hr-HR-SreckoNeural", "cs-CZ-AntoninNeural", "cs-CZ-VlastaNeural", "da-DK-ChristelNeural", "da-DK-JeppeNeural",
        "nl-BE-ArnaudNeural", "nl-BE-DenaNeural", "nl-NL-ColetteNeural", "nl-NL-FennaNeural", "nl-NL-MaartenNeural", "en-AU-NatashaNeural",
        "en-AU-WilliamNeural", "en-CA-ClaraNeural", "en-CA-LiamNeural", "en-HK-SamNeural", "en-HK-YanNeural", "en-IN-NeerjaExpressiveNeural",
        "en-IN-NeerjaNeural", "en-IN-PrabhatNeural", "en-IE-ConnorNeural", "en-IE-EmilyNeural", "en-KE-AsiliaNeural", "en-KE-ChilembaNeural",
        "en-NZ-MitchellNeural", "en-NZ-MollyNeural", "en-NG-AbeoNeural", "en-NG-EzinneNeural", "en-PH-JamesNeural", "en-PH-RosaNeural",
        "en-SG-LunaNeural", "en-SG-WayneNeural", "en-ZA-LeahNeural", "en-ZA-LukeNeural", "en-TZ-ElimuNeural", "en-TZ-ImaniNeural",
        "en-GB-LibbyNeural", "en-GB-MaisieNeural", "en-GB-RyanNeural", "en-GB-SoniaNeural", "en-GB-ThomasNeural", "en-US-AvaMultilingualNeural",
        "en-US-AndrewMultilingualNeural", "en-US-EmmaMultilingualNeural", "en-US-BrianMultilingualNeural", "en-US-AvaNeural", "en-US-AndrewNeural", "en-US-EmmaNeural",
        "en-US-BrianNeural", "en-US-AnaNeural", "en-US-AriaNeural", "en-US-ChristopherNeural", "en-US-EricNeural", "en-US-GuyNeural",
        "en-US-JennyNeural", "en-US-MichelleNeural", "en-US-RogerNeural", "en-US-SteffanNeural", "et-EE-AnuNeural", "et-EE-KertNeural",
        "fil-PH-AngeloNeural", "fil-PH-BlessicaNeural", "fi-FI-HarriNeural", "fi-FI-NooraNeural", "fr-BE-CharlineNeural", "fr-BE-GerardNeural",
        "fr-CA-ThierryNeural", "fr-CA-AntoineNeural", "fr-CA-JeanNeural", "fr-CA-SylvieNeural", "fr-FR-VivienneMultilingualNeural", "fr-FR-RemyMultilingualNeural",
        "fr-FR-DeniseNeural", "fr-FR-EloiseNeural", "fr-FR-HenriNeural", "fr-CH-ArianeNeural", "fr-CH-FabriceNeural", "gl-ES-RoiNeural",
        "gl-ES-SabelaNeural", "ka-GE-EkaNeural", "ka-GE-GiorgiNeural", "de-AT-IngridNeural", "de-AT-JonasNeural", "de-DE-SeraphinaMultilingualNeural",
        "de-DE-FlorianMultilingualNeural", "de-DE-AmalaNeural", "de-DE-ConradNeural", "de-DE-KatjaNeural", "de-DE-KillianNeural", "de-CH-JanNeural",
        "de-CH-LeniNeural", "el-GR-AthinaNeural", "el-GR-NestorasNeural", "gu-IN-DhwaniNeural", "gu-IN-NiranjanNeural", "he-IL-AvriNeural",
        "he-IL-HilaNeural", "hi-IN-MadhurNeural", "hi-IN-SwaraNeural", "hu-HU-NoemiNeural", "hu-HU-TamasNeural", "is-IS-GudrunNeural",
        "is-IS-GunnarNeural", "id-ID-ArdiNeural", "id-ID-GadisNeural", "ga-IE-ColmNeural", "ga-IE-OrlaNeural", "it-IT-GiuseppeNeural",
        "it-IT-DiegoNeural", "it-IT-ElsaNeural", "it-IT-IsabellaNeural", "ja-JP-KeitaNeural", "ja-JP-NanamiNeural", "jv-ID-DimasNeural",
        "jv-ID-SitiNeural", "kn-IN-GaganNeural", "kn-IN-SapnaNeural", "kk-KZ-AigulNeural", "kk-KZ-DauletNeural", "km-KH-PisethNeural",
        "km-KH-SreymomNeural", "ko-KR-HyunsuNeural", "ko-KR-InJoonNeural", "ko-KR-SunHiNeural", "lo-LA-ChanthavongNeural", "lo-LA-KeomanyNeural",
        "lv-LV-EveritaNeural", "lv-LV-NilsNeural", "lt-LT-LeonasNeural", "lt-LT-OnaNeural", "mk-MK-AleksandarNeural", "mk-MK-MarijaNeural",
        "ms-MY-OsmanNeural", "ms-MY-YasminNeural", "ml-IN-MidhunNeural", "ml-IN-SobhanaNeural", "mt-MT-GraceNeural", "mt-MT-JosephNeural",
        "mr-IN-AarohiNeural", "mr-IN-ManoharNeural", "mn-MN-BataaNeural", "mn-MN-YesuiNeural", "ne-NP-HemkalaNeural", "ne-NP-SagarNeural",
        "nb-NO-FinnNeural", "nb-NO-PernilleNeural", "ps-AF-GulNawazNeural", "ps-AF-LatifaNeural", "fa-IR-DilaraNeural", "fa-IR-FaridNeural",
        "pl-PL-MarekNeural", "pl-PL-ZofiaNeural", "pt-BR-ThalitaNeural", "pt-BR-AntonioNeural", "pt-BR-FranciscaNeural", "pt-PT-DuarteNeural",
        "pt-PT-RaquelNeural", "ro-RO-AlinaNeural", "ro-RO-EmilNeural", "ru-RU-DmitryNeural", "ru-RU-SvetlanaNeural", "sr-RS-NicholasNeural",
        "sr-RS-SophieNeural", "si-LK-SameeraNeural", "si-LK-ThiliniNeural", "sk-SK-LukasNeural", "sk-SK-ViktoriaNeural", "sl-SI-PetraNeural",
        "sl-SI-RokNeural", "so-SO-MuuseNeural", "so-SO-UbaxNeural", "es-AR-ElenaNeural", "es-AR-TomasNeural", "es-BO-MarceloNeural",
        "es-BO-SofiaNeural", "es-CL-CatalinaNeural", "es-CL-LorenzoNeural", "es-ES-XimenaNeural", "es-CO-GonzaloNeural", "es-CO-SalomeNeural",
        "es-CR-JuanNeural", "es-CR-MariaNeural", "es-CU-BelkysNeural", "es-CU-ManuelNeural", "es-DO-EmilioNeural", "es-DO-RamonaNeural",
        "es-EC-AndreaNeural", "es-EC-LuisNeural", "es-SV-LorenaNeural", "es-SV-RodrigoNeural", "es-GQ-JavierNeural", "es-GQ-TeresaNeural",
        "es-GT-AndresNeural", "es-GT-MartaNeural", "es-HN-CarlosNeural", "es-HN-KarlaNeural", "es-MX-DaliaNeural", "es-MX-JorgeNeural",
        "es-NI-FedericoNeural", "es-NI-YolandaNeural", "es-PA-MargaritaNeural", "es-PA-RobertoNeural", "es-PY-MarioNeural", "es-PY-TaniaNeural",
        "es-PE-AlexNeural", "es-PE-CamilaNeural", "es-PR-KarinaNeural", "es-PR-VictorNeural", "es-ES-AlvaroNeural", "es-ES-ElviraNeural",
        "es-US-AlonsoNeural", "es-US-PalomaNeural", "es-UY-MateoNeural", "es-UY-ValentinaNeural", "es-VE-PaolaNeural", "es-VE-SebastianNeural",
        "su-ID-JajangNeural", "su-ID-TutiNeural", "sw-KE-RafikiNeural", "sw-KE-ZuriNeural", "sw-TZ-DaudiNeural", "sw-TZ-RehemaNeural",
        "sv-SE-MattiasNeural", "sv-SE-SofieNeural", "ta-IN-PallaviNeural", "ta-IN-ValluvarNeural", "ta-MY-KaniNeural", "ta-MY-SuryaNeural",
        "ta-SG-AnbuNeural", "ta-SG-VenbaNeural", "ta-LK-KumarNeural", "ta-LK-SaranyaNeural", "te-IN-MohanNeural", "te-IN-ShrutiNeural",
        "th-TH-NiwatNeural", "th-TH-PremwadeeNeural", "tr-TR-AhmetNeural", "tr-TR-EmelNeural", "uk-UA-OstapNeural", "uk-UA-PolinaNeural",
        "ur-IN-GulNeural", "ur-IN-SalmanNeural", "ur-PK-AsadNeural", "ur-PK-UzmaNeural", "uz-UZ-MadinaNeural", "uz-UZ-SardorNeural",
        "vi-VN-HoaiMyNeural", "vi-VN-NamMinhNeural", "cy-GB-AledNeural", "cy-GB-NiaNeural", "zu-ZA-ThandoNeural", "zu-ZA-ThembaNeural"
    ],
    "google_tts_voice": [
        "af", "am", "ar", "bg", "bn", "bs", "ca", "cs", "cy", "da", "de", "el", "en", "es",
        "et", "eu", "fi", "fr", "fr-CA", "gl", "gu", "ha", "hi", "hr", "hu", "id", "is", "it",
        "iw", "ja", "jw", "km", "kn", "ko", "la", "lt", "lv", "ml", "mr", "ms", "my", "ne",
        "nl", "no", "pa", "pl", "pt", "pt-PT", "ro", "ru", "si", "sk", "sq", "sr", "su", "sv",
        "sw", "ta", "te", "th", "tl", "tr", "uk", "ur", "vi", "yue", "zh-CN", "zh-TW", "zh"
    ],
    "fp16": true,
    "separator_tab": true,
    "convert_tab": true,
    "convert_with_whisper": true,
    "tts_tab": true,
    "audioldm2": true,
    "effects_tab": true,
    "create_dataset_tab": true,
    "training_tab": true,
    "fushion_tab": true,
    "read_tab": true,
    "onnx_tab": true,
    "downloads_tab": true,
    "f0_extractor_tab": true,
    "settings_tab": true,
    "report_bug_tab": true,
    "font": "https://fonts.googleapis.com/css2?family=Shadows+Into+Light&display=swap",
    "app_port": 7860,
    "tensorboard_port": 6870,
    "num_of_restart": 5,
    "server_name": "0.0.0.0",
    "app_show_error": true
}
main/configs/config.py
DELETED
@@ -1,90 +0,0 @@
import os
import json
import torch


version_config_paths = [os.path.join(version, size) for version in ["v1", "v2"] for size in ["32000.json", "40000.json", "48000.json"]]

def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances: instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance

@singleton
class Config:
    def __init__(self):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.configs = json.load(open(os.path.join("main", "configs", "config.json"), "r"))
        self.translations = self.multi_language()
        self.json_config = self.load_config_json()
        self.gpu_mem = None
        self.per_preprocess = 3.7
        self.is_half = self.is_fp16()
        self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()

    def multi_language(self):
        try:
            lang = self.configs.get("language", "vi-VN")
            if len([l for l in os.listdir(os.path.join("assets", "languages")) if l.endswith(".json")]) < 1: raise FileNotFoundError("No language packages found")

            if not lang: lang = "vi-VN"
            if lang not in self.configs["support_language"]: raise ValueError("Language not supported")

            lang_path = os.path.join("assets", "languages", f"{lang}.json")
            if not os.path.exists(lang_path): lang_path = os.path.join("assets", "languages", "vi-VN.json")

            with open(lang_path, encoding="utf-8") as f:
                translations = json.load(f)
        except json.JSONDecodeError:
            print(self.translations["empty_json"].format(file=lang))
            pass

        return translations

    def is_fp16(self):
        fp16 = self.configs.get("fp16", False)

        if self.device in ["cpu", "mps"] and fp16:
            self.configs["fp16"] = False
            fp16 = False

            with open(os.path.join("main", "configs", "config.json"), "w") as f:
                json.dump(self.configs, f, indent=4)

        if not fp16: self.preprocess_per = 3.0
        return fp16

    def load_config_json(self):
        configs = {}

        for config_file in version_config_paths:
            try:
                with open(os.path.join("main", "configs", config_file), "r") as f:
                    configs[config_file] = json.load(f)
            except json.JSONDecodeError:
                print(self.translations["empty_json"].format(file=config_file))
                pass

        return configs

    def device_config(self):
        if self.device.startswith("cuda"): self.set_cuda_config()
        elif self.has_mps(): self.device = "mps"
        else: self.device = "cpu"

        if self.gpu_mem is not None and self.gpu_mem <= 4:
            self.preprocess_per = 3.0
            return 1, 5, 30, 32

        return (3, 10, 60, 65) if self.is_half else (1, 6, 38, 41)

    def set_cuda_config(self):
        i_device = int(self.device.split(":")[-1])
        self.gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)

    def has_mps(self):
        return torch.backends.mps.is_available()
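Because `Config` is wrapped by the `singleton` decorator above, every import site shares one instance. A small usage sketch under that assumption; it presumes the repo layout shown in this diff, i.e. `main/configs/config.json` and at least one `assets/languages/*.json` present:

# Hypothetical usage of the deleted Config singleton.
from main.configs.config import Config

first = Config()
second = Config()
assert first is second            # the decorator returns the same cached instance

print(first.device)               # "cuda:0", "mps" or "cpu" after device_config()
print(first.is_half)              # fp16 flag, forced off on cpu/mps
print(first.x_pad, first.x_query, first.x_center, first.x_max)   # per-device processing constants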
main/configs/decrypt.bin
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:330268cbf6b9317a76510b533e1640ef48ed074a07c013e5b1abc4d48cfd9dce
size 32
main/configs/v1/32000.json
DELETED
@@ -1,46 +0,0 @@
{
    "train": {
        "log_interval": 200,
        "seed": 1234,
        "epochs": 20000,
        "learning_rate": 0.0001,
        "betas": [0.8, 0.99],
        "eps": 1e-09,
        "batch_size": 4,
        "lr_decay": 0.999875,
        "segment_size": 12800,
        "init_lr_ratio": 1,
        "warmup_epochs": 0,
        "c_mel": 45,
        "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0,
        "sample_rate": 32000,
        "filter_length": 1024,
        "hop_length": 320,
        "win_length": 1024,
        "n_mel_channels": 80,
        "mel_fmin": 0.0,
        "mel_fmax": null
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "text_enc_hidden_dim": 256,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [10, 4, 2, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [16, 16, 4, 4, 4],
        "use_spectral_norm": false,
        "gin_channels": 256,
        "spk_embed_dim": 109
    }
}
main/configs/v1/40000.json
DELETED
@@ -1,46 +0,0 @@
{
    "train": {
        "log_interval": 200,
        "seed": 1234,
        "epochs": 20000,
        "learning_rate": 0.0001,
        "betas": [0.8, 0.99],
        "eps": 1e-09,
        "batch_size": 4,
        "lr_decay": 0.999875,
        "segment_size": 12800,
        "init_lr_ratio": 1,
        "warmup_epochs": 0,
        "c_mel": 45,
        "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0,
        "sample_rate": 40000,
        "filter_length": 2048,
        "hop_length": 400,
        "win_length": 2048,
        "n_mel_channels": 125,
        "mel_fmin": 0.0,
        "mel_fmax": null
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "text_enc_hidden_dim": 256,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [10, 10, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [16, 16, 4, 4],
        "use_spectral_norm": false,
        "gin_channels": 256,
        "spk_embed_dim": 109
    }
}
main/configs/v1/48000.json
DELETED
@@ -1,46 +0,0 @@
{
    "train": {
        "log_interval": 200,
        "seed": 1234,
        "epochs": 20000,
        "learning_rate": 0.0001,
        "betas": [0.8, 0.99],
        "eps": 1e-09,
        "batch_size": 4,
        "lr_decay": 0.999875,
        "segment_size": 11520,
        "init_lr_ratio": 1,
        "warmup_epochs": 0,
        "c_mel": 45,
        "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0,
        "sample_rate": 48000,
        "filter_length": 2048,
        "hop_length": 480,
        "win_length": 2048,
        "n_mel_channels": 128,
        "mel_fmin": 0.0,
        "mel_fmax": null
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "text_enc_hidden_dim": 256,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [10, 6, 2, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [16, 16, 4, 4, 4],
        "use_spectral_norm": false,
        "gin_channels": 256,
        "spk_embed_dim": 109
    }
}
main/configs/v2/32000.json
DELETED
@@ -1,42 +0,0 @@
{
    "train": {
        "log_interval": 200,
        "seed": 1234,
        "learning_rate": 0.0001,
        "betas": [0.8, 0.99],
        "eps": 1e-09,
        "lr_decay": 0.999875,
        "segment_size": 12800,
        "c_mel": 45,
        "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0,
        "sample_rate": 32000,
        "filter_length": 1024,
        "hop_length": 320,
        "win_length": 1024,
        "n_mel_channels": 80,
        "mel_fmin": 0.0,
        "mel_fmax": null
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "text_enc_hidden_dim": 768,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [10, 8, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [20, 16, 4, 4],
        "use_spectral_norm": false,
        "gin_channels": 256,
        "spk_embed_dim": 109
    }
}
main/configs/v2/40000.json
DELETED
@@ -1,42 +0,0 @@
{
    "train": {
        "log_interval": 200,
        "seed": 1234,
        "learning_rate": 0.0001,
        "betas": [0.8, 0.99],
        "eps": 1e-09,
        "lr_decay": 0.999875,
        "segment_size": 12800,
        "c_mel": 45,
        "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0,
        "sample_rate": 40000,
        "filter_length": 2048,
        "hop_length": 400,
        "win_length": 2048,
        "n_mel_channels": 125,
        "mel_fmin": 0.0,
        "mel_fmax": null
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "text_enc_hidden_dim": 768,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [10, 10, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [16, 16, 4, 4],
        "use_spectral_norm": false,
        "gin_channels": 256,
        "spk_embed_dim": 109
    }
}
main/configs/v2/48000.json
DELETED
@@ -1,42 +0,0 @@
-{
-  "train": {
-    "log_interval": 200,
-    "seed": 1234,
-    "learning_rate": 0.0001,
-    "betas": [0.8, 0.99],
-    "eps": 1e-09,
-    "lr_decay": 0.999875,
-    "segment_size": 17280,
-    "c_mel": 45,
-    "c_kl": 1.0
-  },
-  "data": {
-    "max_wav_value": 32768.0,
-    "sample_rate": 48000,
-    "filter_length": 2048,
-    "hop_length": 480,
-    "win_length": 2048,
-    "n_mel_channels": 128,
-    "mel_fmin": 0.0,
-    "mel_fmax": null
-  },
-  "model": {
-    "inter_channels": 192,
-    "hidden_channels": 192,
-    "filter_channels": 768,
-    "text_enc_hidden_dim": 768,
-    "n_heads": 2,
-    "n_layers": 6,
-    "kernel_size": 3,
-    "p_dropout": 0,
-    "resblock": "1",
-    "resblock_kernel_sizes": [3, 7, 11],
-    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-    "upsample_rates": [12, 10, 2, 2],
-    "upsample_initial_channel": 512,
-    "upsample_kernel_sizes": [24, 20, 4, 4],
-    "use_spectral_norm": false,
-    "gin_channels": 256,
-    "spk_embed_dim": 109
-  }
-}
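The three v2 configs above follow the same pattern: the upsample_rates multiply out to hop_length, and sample_rate divided by hop_length is 100 frames per second in every case. A quick check using only the numbers shown above:

    import math

    v2_configs = {32000: ([10, 8, 2, 2], 320), 40000: ([10, 10, 2, 2], 400), 48000: ([12, 10, 2, 2], 480)}
    for sample_rate, (upsample_rates, hop_length) in v2_configs.items():
        assert math.prod(upsample_rates) == hop_length
        assert sample_rate // hop_length == 100  # 100 frames per second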
main/inference/audio_effects.py
DELETED
@@ -1,180 +0,0 @@
-import os
-import sys
-import librosa
-import argparse
-
-import numpy as np
-import soundfile as sf
-
-from distutils.util import strtobool
-from scipy.signal import butter, filtfilt
-from pedalboard import Pedalboard, Chorus, Distortion, Reverb, PitchShift, Delay, Limiter, Gain, Bitcrush, Clipping, Compressor, Phaser, HighpassFilter
-
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-from main.library.utils import pydub_convert, pydub_load
-
-translations = Config().translations
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input_path", type=str, required=True)
-    parser.add_argument("--output_path", type=str, default="./audios/apply_effects.wav")
-    parser.add_argument("--export_format", type=str, default="wav")
-    parser.add_argument("--resample", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--resample_sr", type=int, default=0)
-    parser.add_argument("--chorus", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--chorus_depth", type=float, default=0.5)
-    parser.add_argument("--chorus_rate", type=float, default=1.5)
-    parser.add_argument("--chorus_mix", type=float, default=0.5)
-    parser.add_argument("--chorus_delay", type=int, default=10)
-    parser.add_argument("--chorus_feedback", type=float, default=0)
-    parser.add_argument("--distortion", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--drive_db", type=int, default=20)
-    parser.add_argument("--reverb", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--reverb_room_size", type=float, default=0.5)
-    parser.add_argument("--reverb_damping", type=float, default=0.5)
-    parser.add_argument("--reverb_wet_level", type=float, default=0.33)
-    parser.add_argument("--reverb_dry_level", type=float, default=0.67)
-    parser.add_argument("--reverb_width", type=float, default=1)
-    parser.add_argument("--reverb_freeze_mode", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--pitchshift", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--pitch_shift", type=int, default=0)
-    parser.add_argument("--delay", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--delay_seconds", type=float, default=0.5)
-    parser.add_argument("--delay_feedback", type=float, default=0.5)
-    parser.add_argument("--delay_mix", type=float, default=0.5)
-    parser.add_argument("--compressor", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--compressor_threshold", type=int, default=-20)
-    parser.add_argument("--compressor_ratio", type=float, default=4)
-    parser.add_argument("--compressor_attack_ms", type=float, default=10)
-    parser.add_argument("--compressor_release_ms", type=int, default=200)
-    parser.add_argument("--limiter", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--limiter_threshold", type=int, default=0)
-    parser.add_argument("--limiter_release", type=int, default=100)
-    parser.add_argument("--gain", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--gain_db", type=int, default=0)
-    parser.add_argument("--bitcrush", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--bitcrush_bit_depth", type=int, default=16)
-    parser.add_argument("--clipping", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--clipping_threshold", type=int, default=-10)
-    parser.add_argument("--phaser", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--phaser_rate_hz", type=float, default=0.5)
-    parser.add_argument("--phaser_depth", type=float, default=0.5)
-    parser.add_argument("--phaser_centre_frequency_hz", type=int, default=1000)
-    parser.add_argument("--phaser_feedback", type=float, default=0)
-    parser.add_argument("--phaser_mix", type=float, default=0.5)
-    parser.add_argument("--treble_bass_boost", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--bass_boost_db", type=int, default=0)
-    parser.add_argument("--bass_boost_frequency", type=int, default=100)
-    parser.add_argument("--treble_boost_db", type=int, default=0)
-    parser.add_argument("--treble_boost_frequency", type=int, default=3000)
-    parser.add_argument("--fade_in_out", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--fade_in_duration", type=float, default=2000)
-    parser.add_argument("--fade_out_duration", type=float, default=2000)
-    parser.add_argument("--audio_combination", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--audio_combination_input", type=str)
-
-    return parser.parse_args()
-
-def process_audio(input_path, output_path, resample, resample_sr, chorus_depth, chorus_rate, chorus_mix, chorus_delay, chorus_feedback, distortion_drive, reverb_room_size, reverb_damping, reverb_wet_level, reverb_dry_level, reverb_width, reverb_freeze_mode, pitch_shift, delay_seconds, delay_feedback, delay_mix, compressor_threshold, compressor_ratio, compressor_attack_ms, compressor_release_ms, limiter_threshold, limiter_release, gain_db, bitcrush_bit_depth, clipping_threshold, phaser_rate_hz, phaser_depth, phaser_centre_frequency_hz, phaser_feedback, phaser_mix, bass_boost_db, bass_boost_frequency, treble_boost_db, treble_boost_frequency, fade_in_duration, fade_out_duration, export_format, chorus, distortion, reverb, pitchshift, delay, compressor, limiter, gain, bitcrush, clipping, phaser, treble_bass_boost, fade_in_out, audio_combination, audio_combination_input):
-    def bass_boost(audio, gain_db, frequency, sample_rate):
-        if gain_db >= 1:
-            b, a = butter(4, frequency / (0.5 * sample_rate), btype='low')
-
-            return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
-        else: return audio
-
-    def treble_boost(audio, gain_db, frequency, sample_rate):
-        if gain_db >=1:
-            b, a = butter(4, frequency / (0.5 * sample_rate), btype='high')
-
-            return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
-        else: return audio
-
-    def fade_out_effect(audio, sr, duration=3.0):
-        length = int(duration * sr)
-        end = audio.shape[0]
-
-        if length > end: length = end
-        start = end - length
-
-        audio[start:end] = audio[start:end] * np.linspace(1.0, 0.0, length)
-        return audio
-
-    def fade_in_effect(audio, sr, duration=3.0):
-        length = int(duration * sr)
-        start = 0
-
-        if length > audio.shape[0]: length = audio.shape[0]
-        end = length
-
-        audio[start:end] = audio[start:end] * np.linspace(0.0, 1.0, length)
-        return audio
-
-    if not input_path or not os.path.exists(input_path):
-        print(translations["input_not_valid"])
-        sys.exit(1)
-
-    if not output_path:
-        print(translations["output_not_valid"])
-        sys.exit(1)
-
-    if os.path.exists(output_path): os.remove(output_path)
-
-    try:
-        input_path = input_path.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
-
-        try:
-            audio, sample_rate = sf.read(input_path, dtype=np.float32)
-        except:
-            audio, sample_rate = librosa.load(input_path, sr=None)
-    except Exception as e:
-        raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")
-
-    audio = audio.flatten()
-
-    try:
-        board = Pedalboard([HighpassFilter()])
-
-        if chorus: board.append(Chorus(depth=chorus_depth, rate_hz=chorus_rate, mix=chorus_mix, centre_delay_ms=chorus_delay, feedback=chorus_feedback))
-        if distortion: board.append(Distortion(drive_db=distortion_drive))
-        if reverb: board.append(Reverb(room_size=reverb_room_size, damping=reverb_damping, wet_level=reverb_wet_level, dry_level=reverb_dry_level, width=reverb_width, freeze_mode=1 if reverb_freeze_mode else 0))
-        if pitchshift: board.append(PitchShift(semitones=pitch_shift))
-        if delay: board.append(Delay(delay_seconds=delay_seconds, feedback=delay_feedback, mix=delay_mix))
-        if compressor: board.append(Compressor(threshold_db=compressor_threshold, ratio=compressor_ratio, attack_ms=compressor_attack_ms, release_ms=compressor_release_ms))
-        if limiter: board.append(Limiter(threshold_db=limiter_threshold, release_ms=limiter_release))
-        if gain: board.append(Gain(gain_db=gain_db))
-        if bitcrush: board.append(Bitcrush(bit_depth=bitcrush_bit_depth))
-        if clipping: board.append(Clipping(threshold_db=clipping_threshold))
-        if phaser: board.append(Phaser(rate_hz=phaser_rate_hz, depth=phaser_depth, centre_frequency_hz=phaser_centre_frequency_hz, feedback=phaser_feedback, mix=phaser_mix))
-
-        processed_audio = board(audio, sample_rate)
-
-        if treble_bass_boost:
-            processed_audio = bass_boost(processed_audio, bass_boost_db, bass_boost_frequency, sample_rate)
-            processed_audio = treble_boost(processed_audio, treble_boost_db, treble_boost_frequency, sample_rate)
-
-        if fade_in_out:
-            processed_audio = fade_in_effect(processed_audio, sample_rate, fade_in_duration)
-            processed_audio = fade_out_effect(processed_audio, sample_rate, fade_out_duration)
-
-        if resample_sr != sample_rate and resample_sr > 0 and resample:
-            target_sr = min([8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000, 96000], key=lambda x: abs(x - resample_sr))
-            processed_audio = librosa.resample(processed_audio, orig_sr=sample_rate, target_sr=target_sr, res_type="soxr_vhq")
-            sample_rate = target_sr
-
-        sf.write(output_path.replace("wav", export_format), processed_audio, sample_rate, format=export_format)
-
-        if audio_combination: pydub_convert(pydub_load(audio_combination_input)).overlay(pydub_convert(pydub_load(output_path.replace("wav", export_format)))).export(output_path.replace("wav", export_format), format=export_format)
-    except Exception as e:
-        raise RuntimeError(translations["apply_error"].format(e=e))
-
-    return output_path
-
-def main():
-    args = parse_arguments()
-    process_audio(input_path=args.input_path, output_path=args.output_path, resample=args.resample, resample_sr=args.resample_sr, chorus_depth=args.chorus_depth, chorus_rate=args.chorus_rate, chorus_mix=args.chorus_mix, chorus_delay=args.chorus_delay, chorus_feedback=args.chorus_feedback, distortion_drive=args.drive_db, reverb_room_size=args.reverb_room_size, reverb_damping=args.reverb_damping, reverb_wet_level=args.reverb_wet_level, reverb_dry_level=args.reverb_dry_level, reverb_width=args.reverb_width, reverb_freeze_mode=args.reverb_freeze_mode, pitch_shift=args.pitch_shift, delay_seconds=args.delay_seconds, delay_feedback=args.delay_feedback, delay_mix=args.delay_mix, compressor_threshold=args.compressor_threshold, compressor_ratio=args.compressor_ratio, compressor_attack_ms=args.compressor_attack_ms, compressor_release_ms=args.compressor_release_ms, limiter_threshold=args.limiter_threshold, limiter_release=args.limiter_release, gain_db=args.gain_db, bitcrush_bit_depth=args.bitcrush_bit_depth, clipping_threshold=args.clipping_threshold, phaser_rate_hz=args.phaser_rate_hz, phaser_depth=args.phaser_depth, phaser_centre_frequency_hz=args.phaser_centre_frequency_hz, phaser_feedback=args.phaser_feedback, phaser_mix=args.phaser_mix, bass_boost_db=args.bass_boost_db, bass_boost_frequency=args.bass_boost_frequency, treble_boost_db=args.treble_boost_db, treble_boost_frequency=args.treble_boost_frequency, fade_in_duration=args.fade_in_duration, fade_out_duration=args.fade_out_duration, export_format=args.export_format, chorus=args.chorus, distortion=args.distortion, reverb=args.reverb, pitchshift=args.pitchshift, delay=args.delay, compressor=args.compressor, limiter=args.limiter, gain=args.gain, bitcrush=args.bitcrush, clipping=args.clipping, phaser=args.phaser, treble_bass_boost=args.treble_bass_boost, fade_in_out=args.fade_in_out, audio_combination=args.audio_combination, audio_combination_input=args.audio_combination_input)
-
-if __name__ == "__main__": main()
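The deleted audio_effects.py was a standalone CLI driven entirely by the argparse flags listed above. A plausible invocation, assuming it was run from the repository root (paths and values are illustrative only, not taken from the repo):

    python main/inference/audio_effects.py --input_path ./audios/input.wav --output_path ./audios/apply_effects.wav --export_format wav --reverb true --reverb_room_size 0.6 --treble_bass_boost true --bass_boost_db 6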
main/inference/audioldm2.py
DELETED
@@ -1,210 +0,0 @@
-import os
-import sys
-import time
-import tqdm
-import torch
-import logging
-import librosa
-import argparse
-import scipy.signal
-import logging.handlers
-
-import numpy as np
-import soundfile as sf
-
-from torch import inference_mode
-from distutils.util import strtobool
-
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-from main.library.audioldm2.utils import load_audio
-from main.library.audioldm2.models import load_model
-
-config = Config()
-translations = config.translations
-logger = logging.getLogger(__name__)
-logger.propagate = False
-
-for l in ["torch", "httpx", "httpcore", "diffusers", "transformers"]:
-    logging.getLogger(l).setLevel(logging.ERROR)
-
-if logger.hasHandlers(): logger.handlers.clear()
-else:
-    console_handler = logging.StreamHandler()
-    console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-    console_handler.setFormatter(console_formatter)
-    console_handler.setLevel(logging.INFO)
-    file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "audioldm2.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
-    file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-    file_handler.setFormatter(file_formatter)
-    file_handler.setLevel(logging.DEBUG)
-    logger.addHandler(console_handler)
-    logger.addHandler(file_handler)
-    logger.setLevel(logging.DEBUG)
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input_path", type=str, required=True)
-    parser.add_argument("--output_path", type=str, default="./output.wav")
-    parser.add_argument("--export_format", type=str, default="wav")
-    parser.add_argument("--sample_rate", type=int, default=44100)
-    parser.add_argument("--audioldm_model", type=str, default="audioldm2-music")
-    parser.add_argument("--source_prompt", type=str, default="")
-    parser.add_argument("--target_prompt", type=str, default="")
-    parser.add_argument("--steps", type=int, default=200)
-    parser.add_argument("--cfg_scale_src", type=float, default=3.5)
-    parser.add_argument("--cfg_scale_tar", type=float, default=12)
-    parser.add_argument("--t_start", type=int, default=45)
-    parser.add_argument("--save_compute", type=lambda x: bool(strtobool(x)), default=False)
-
-    return parser.parse_args()
-
-def main():
-    args = parse_arguments()
-    input_path, output_path, export_format, sample_rate, audioldm_model, source_prompt, target_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, save_compute = args.input_path, args.output_path, args.export_format, args.sample_rate, args.audioldm_model, args.source_prompt, args.target_prompt, args.steps, args.cfg_scale_src, args.cfg_scale_tar, args.t_start, args.save_compute
-
-    log_data = {translations['audio_path']: input_path, translations['output_path']: output_path.replace('wav', export_format), translations['model_name']: audioldm_model, translations['export_format']: export_format, translations['sample_rate']: sample_rate, translations['steps']: steps, translations['source_prompt']: source_prompt, translations['target_prompt']: target_prompt, translations['cfg_scale_src']: cfg_scale_src, translations['cfg_scale_tar']: cfg_scale_tar, translations['t_start']: t_start, translations['save_compute']: save_compute}
-
-    for key, value in log_data.items():
-        logger.debug(f"{key}: {value}")
-
-    start_time = time.time()
-    logger.info(translations["start_edit"].format(input_path=input_path))
-    pid_path = os.path.join("assets", "audioldm2_pid.txt")
-    with open(pid_path, "w") as pid_file:
-        pid_file.write(str(os.getpid()))
-
-    try:
-        edit(input_path, output_path, audioldm_model, source_prompt, target_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, save_compute, sample_rate, config.device, export_format=export_format)
-    except Exception as e:
-        logger.error(translations["error_edit"].format(e=e))
-        import traceback
-        logger.debug(traceback.format_exc())
-
-    logger.info(translations["edit_success"].format(time=f"{(time.time() - start_time):.2f}", output_path=output_path.replace('wav', export_format)))
-
-def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src, duration, save_compute):
-    with inference_mode():
-        w0 = ldm_stable.vae_encode(x0)
-
-    _, zs, wts, extra_info = inversion_forward_process(ldm_stable, w0, etas=1, prompts=[prompt_src], cfg_scales=[cfg_scale_src], num_inference_steps=num_diffusion_steps, numerical_fix=True, duration=duration, save_compute=save_compute)
-    return zs, wts, extra_info
-
-def low_pass_filter(audio, cutoff=7500, sr=16000):
-    b, a = scipy.signal.butter(4, cutoff / (sr / 2), btype='low')
-    return scipy.signal.filtfilt(b, a, audio)
-
-def sample(output_audio, sr, ldm_stable, zs, wts, extra_info, prompt_tar, tstart, cfg_scale_tar, duration, save_compute, export_format = "wav"):
-    tstart = torch.tensor(tstart, dtype=torch.int32)
-    w0, _ = inversion_reverse_process(ldm_stable, xT=wts, tstart=tstart, etas=1., prompts=[prompt_tar], neg_prompts=[""], cfg_scales=[cfg_scale_tar], zs=zs[:int(tstart)], duration=duration, extra_info=extra_info, save_compute=save_compute)
-
-    with inference_mode():
-        x0_dec = ldm_stable.vae_decode(w0.to(torch.float16 if config.is_half else torch.float32))
-
-    if x0_dec.dim() < 4: x0_dec = x0_dec[None, :, :, :]
-
-    with torch.no_grad():
-        audio = ldm_stable.decode_to_mel(x0_dec.to(torch.float16 if config.is_half else torch.float32))
-
-    audio = audio.float().squeeze().cpu().numpy()
-    orig_sr = 16000
-
-    if sr != 16000 and sr > 0:
-        audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr, res_type="soxr_vhq")
-        orig_sr = sr
-
-    audio = low_pass_filter(audio, 7500, orig_sr)
-
-    sf.write(output_audio, np.tile(audio, (2, 1)).T, orig_sr, format=export_format)
-    return output_audio
-
-def edit(input_audio, output_audio, model_id, source_prompt = "", target_prompt = "", steps = 200, cfg_scale_src = 3.5, cfg_scale_tar = 12, t_start = 45, save_compute = True, sr = 44100, device = "cpu", export_format = "wav"):
-    ldm_stable = load_model(model_id, device=device)
-    ldm_stable.model.scheduler.set_timesteps(steps, device=device)
-    x0, duration = load_audio(input_audio, ldm_stable.get_melspectrogram(), device=device)
-    zs_tensor, wts_tensor, extra_info_list = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src, duration=duration, save_compute=save_compute)
-
-    return sample(output_audio, sr, ldm_stable, zs_tensor, wts_tensor, extra_info_list, prompt_tar=target_prompt, tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar, duration=duration, save_compute=save_compute, export_format=export_format)
-
-def inversion_forward_process(model, x0, etas = None, prompts = [""], cfg_scales = [3.5], num_inference_steps = 50, numerical_fix = False, duration = None, first_order = False, save_compute = True):
-    if len(prompts) > 1 or prompts[0] != "":
-        text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
-        uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text([""], negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1] if text_embeddings_class_labels is not None else None)
-    else: uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text([""], negative=True, save_compute=False)
-
-    timesteps = model.model.scheduler.timesteps.to(model.device)
-    variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)
-
-    if type(etas) in [int, float]: etas = [etas]*model.model.scheduler.num_inference_steps
-
-    xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
-    zs = torch.zeros(size=variance_noise_shape, device=model.device)
-    extra_info = [None] * len(zs)
-
-    if timesteps[0].dtype == torch.int64: t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
-    elif timesteps[0].dtype == torch.float32: t_to_idx = {float(v): k for k, v in enumerate(timesteps)}
-
-    xt = x0
-    model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration, save_compute=save_compute and prompts[0] != "")
-
-    for t in tqdm.tqdm(timesteps, desc=translations["inverting"], ncols=100, unit="a"):
-        idx = num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - 1
-        xt = xts[idx + 1][None]
-        xt_inp = model.model.scheduler.scale_model_input(xt, t).to(torch.float16 if config.is_half else torch.float32)
-
-        with torch.no_grad():
-            if save_compute and prompts[0] != "":
-                comb_out, _, _ = model.unet_forward(xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1), timestep=t, encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0) if uncond_embeddings_hidden_states is not None else None, class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0) if uncond_embeddings_class_lables is not None else None, encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0) if uncond_boolean_prompt_mask is not None else None)
-                out, cond_out = comb_out.sample.chunk(2, dim=0)
-            else:
-                out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=uncond_embeddings_hidden_states, class_labels=uncond_embeddings_class_lables, encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
-                if len(prompts) > 1 or prompts[0] != "": cond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=text_embeddings_hidden_states, class_labels=text_embeddings_class_labels, encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample
-
-        if len(prompts) > 1 or prompts[0] != "": noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
-        else: noise_pred = out
-
-        xtm1 = xts[idx][None]
-        z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t, eta=etas[idx], numerical_fix=numerical_fix, first_order=first_order)
-        zs[idx] = z
-        xts[idx] = xtm1
-        extra_info[idx] = extra
-
-    if zs is not None: zs[0] = torch.zeros_like(zs[0])
-    return xt, zs, xts, extra_info
-
-def inversion_reverse_process(model, xT, tstart, etas = 0, prompts = [""], neg_prompts = [""], cfg_scales = None, zs = None, duration = None, first_order = False, extra_info = None, save_compute = True):
-    text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
-    uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text(neg_prompts, negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1] if text_embeddings_class_labels is not None else None)
-    xt = xT[tstart.max()].unsqueeze(0)
-
-    if etas is None: etas = 0
-    if type(etas) in [int, float]: etas = [etas]*model.model.scheduler.num_inference_steps
-
-    assert len(etas) == model.model.scheduler.num_inference_steps
-    timesteps = model.model.scheduler.timesteps.to(model.device)
-
-    if timesteps[0].dtype == torch.int64: t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
-    elif timesteps[0].dtype == torch.float32: t_to_idx = {float(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
-
-    model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=timesteps[-zs.shape[0]], audio_end_in_s=duration, save_compute=save_compute)
-
-    for t in tqdm.tqdm(timesteps[-zs.shape[0]:], desc=translations["editing"], ncols=100, unit="a"):
-        idx = model.model.scheduler.num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - (model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
-        xt_inp = model.model.scheduler.scale_model_input(xt, t).to(torch.float16 if config.is_half else torch.float32)
-
-        with torch.no_grad():
-            if save_compute:
-                comb_out, _, _ = model.unet_forward(xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1), timestep=t, encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0) if uncond_embeddings_hidden_states is not None else None, class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0) if uncond_embeddings_class_lables is not None else None, encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0) if uncond_boolean_prompt_mask is not None else None)
-                uncond_out, cond_out = comb_out.sample.chunk(2, dim=0)
-            else:
-                uncond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=uncond_embeddings_hidden_states, class_labels=uncond_embeddings_class_lables, encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
-                cond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=text_embeddings_hidden_states, class_labels=text_embeddings_class_labels, encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample
-
-        z = zs[idx] if zs is not None else None
-        noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)
-        xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z.unsqueeze(0), eta=etas[idx], first_order=first_order)
-
-    return xt, zs
-
-if __name__ == "__main__": main()
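Likewise, the deleted audioldm2.py exposed its prompt-based editing through the flags defined in parse_arguments above. A plausible invocation, assuming it was run from the repository root (prompts and paths are illustrative only, not taken from the repo):

    python main/inference/audioldm2.py --input_path ./audios/input.wav --output_path ./audios/output.wav --audioldm_model audioldm2-music --source_prompt "a piano melody" --target_prompt "an acoustic guitar melody" --steps 200 --t_start 45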
main/inference/convert.py
DELETED
|
@@ -1,590 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
import os
|
| 3 |
-
import gc
|
| 4 |
-
import sys
|
| 5 |
-
import time
|
| 6 |
-
import faiss
|
| 7 |
-
import torch
|
| 8 |
-
import librosa
|
| 9 |
-
import logging
|
| 10 |
-
import argparse
|
| 11 |
-
import warnings
|
| 12 |
-
import onnxruntime
|
| 13 |
-
import logging.handlers
|
| 14 |
-
|
| 15 |
-
import numpy as np
|
| 16 |
-
import soundfile as sf
|
| 17 |
-
import torch.nn.functional as F
|
| 18 |
-
|
| 19 |
-
from tqdm import tqdm
|
| 20 |
-
from scipy import signal
|
| 21 |
-
from distutils.util import strtobool
|
| 22 |
-
|
| 23 |
-
warnings.filterwarnings("ignore")
|
| 24 |
-
sys.path.append(os.getcwd())
|
| 25 |
-
|
| 26 |
-
from main.configs.config import Config
|
| 27 |
-
from main.library.algorithm.synthesizers import Synthesizer
|
| 28 |
-
from main.library.utils import check_predictors, check_embedders, load_audio, load_embedders_model, cut, restore
|
| 29 |
-
|
| 30 |
-
bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
|
| 31 |
-
config = Config()
|
| 32 |
-
translations = config.translations
|
| 33 |
-
logger = logging.getLogger(__name__)
|
| 34 |
-
logger.propagate = False
|
| 35 |
-
|
| 36 |
-
for l in ["torch", "faiss", "httpx", "fairseq", "httpcore", "faiss.loader", "numba.core", "urllib3", "transformers", "matplotlib"]:
|
| 37 |
-
logging.getLogger(l).setLevel(logging.ERROR)
|
| 38 |
-
|
| 39 |
-
if logger.hasHandlers(): logger.handlers.clear()
|
| 40 |
-
else:
|
| 41 |
-
console_handler = logging.StreamHandler()
|
| 42 |
-
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
| 43 |
-
console_handler.setFormatter(console_formatter)
|
| 44 |
-
console_handler.setLevel(logging.INFO)
|
| 45 |
-
file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "convert.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
|
| 46 |
-
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
| 47 |
-
file_handler.setFormatter(file_formatter)
|
| 48 |
-
file_handler.setLevel(logging.DEBUG)
|
| 49 |
-
logger.addHandler(console_handler)
|
| 50 |
-
logger.addHandler(file_handler)
|
| 51 |
-
logger.setLevel(logging.DEBUG)
|
| 52 |
-
|
| 53 |
-
def parse_arguments():
|
| 54 |
-
parser = argparse.ArgumentParser()
|
| 55 |
-
parser.add_argument("--pitch", type=int, default=0)
|
| 56 |
-
parser.add_argument("--filter_radius", type=int, default=3)
|
| 57 |
-
parser.add_argument("--index_rate", type=float, default=0.5)
|
| 58 |
-
parser.add_argument("--volume_envelope", type=float, default=1)
|
| 59 |
-
parser.add_argument("--protect", type=float, default=0.33)
|
| 60 |
-
parser.add_argument("--hop_length", type=int, default=64)
|
| 61 |
-
parser.add_argument("--f0_method", type=str, default="rmvpe")
|
| 62 |
-
parser.add_argument("--embedder_model", type=str, default="contentvec_base")
|
| 63 |
-
parser.add_argument("--input_path", type=str, required=True)
|
| 64 |
-
parser.add_argument("--output_path", type=str, default="./audios/output.wav")
|
| 65 |
-
parser.add_argument("--export_format", type=str, default="wav")
|
| 66 |
-
parser.add_argument("--pth_path", type=str, required=True)
|
| 67 |
-
parser.add_argument("--index_path", type=str)
|
| 68 |
-
parser.add_argument("--f0_autotune", type=lambda x: bool(strtobool(x)), default=False)
|
| 69 |
-
parser.add_argument("--f0_autotune_strength", type=float, default=1)
|
| 70 |
-
parser.add_argument("--clean_audio", type=lambda x: bool(strtobool(x)), default=False)
|
| 71 |
-
parser.add_argument("--clean_strength", type=float, default=0.7)
|
| 72 |
-
parser.add_argument("--resample_sr", type=int, default=0)
|
| 73 |
-
parser.add_argument("--split_audio", type=lambda x: bool(strtobool(x)), default=False)
|
| 74 |
-
parser.add_argument("--checkpointing", type=lambda x: bool(strtobool(x)), default=False)
|
| 75 |
-
parser.add_argument("--f0_file", type=str, default="")
|
| 76 |
-
parser.add_argument("--f0_onnx", type=lambda x: bool(strtobool(x)), default=False)
|
| 77 |
-
parser.add_argument("--embedders_mode", type=str, default="fairseq")
|
| 78 |
-
parser.add_argument("--formant_shifting", type=lambda x: bool(strtobool(x)), default=False)
|
| 79 |
-
parser.add_argument("--formant_qfrency", type=float, default=0.8)
|
| 80 |
-
parser.add_argument("--formant_timbre", type=float, default=0.8)
|
| 81 |
-
|
| 82 |
-
return parser.parse_args()
|
| 83 |
-
|
| 84 |
-
def main():
|
| 85 |
-
args = parse_arguments()
|
| 86 |
-
pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0_method, input_path, output_path, pth_path, index_path, f0_autotune, f0_autotune_strength, clean_audio, clean_strength, export_format, embedder_model, resample_sr, split_audio, checkpointing, f0_file, f0_onnx, embedders_mode, formant_shifting, formant_qfrency, formant_timbre = args.pitch, args.filter_radius, args.index_rate, args.volume_envelope,args.protect, args.hop_length, args.f0_method, args.input_path, args.output_path, args.pth_path, args.index_path, args.f0_autotune, args.f0_autotune_strength, args.clean_audio, args.clean_strength, args.export_format, args.embedder_model, args.resample_sr, args.split_audio, args.checkpointing, args.f0_file, args.f0_onnx, args.embedders_mode, args.formant_shifting, args.formant_qfrency, args.formant_timbre
|
| 87 |
-
|
| 88 |
-
log_data = {translations['pitch']: pitch, translations['filter_radius']: filter_radius, translations['index_strength']: index_rate, translations['volume_envelope']: volume_envelope, translations['protect']: protect, "Hop length": hop_length, translations['f0_method']: f0_method, translations['audio_path']: input_path, translations['output_path']: output_path.replace('wav', export_format), translations['model_path']: pth_path, translations['indexpath']: index_path, translations['autotune']: f0_autotune, translations['clear_audio']: clean_audio, translations['export_format']: export_format, translations['hubert_model']: embedder_model, translations['split_audio']: split_audio, translations['memory_efficient_training']: checkpointing, translations["f0_onnx_mode"]: f0_onnx, translations["embed_mode"]: embedders_mode}
|
| 89 |
-
|
| 90 |
-
if clean_audio: log_data[translations['clean_strength']] = clean_strength
|
| 91 |
-
if resample_sr != 0: log_data[translations['sample_rate']] = resample_sr
|
| 92 |
-
|
| 93 |
-
if f0_autotune: log_data[translations['autotune_rate_info']] = f0_autotune_strength
|
| 94 |
-
if os.path.isfile(f0_file): log_data[translations['f0_file']] = f0_file
|
| 95 |
-
|
| 96 |
-
if formant_shifting:
|
| 97 |
-
log_data[translations['formant_qfrency']] = formant_qfrency
|
| 98 |
-
log_data[translations['formant_timbre']] = formant_timbre
|
| 99 |
-
|
| 100 |
-
for key, value in log_data.items():
|
| 101 |
-
logger.debug(f"{key}: {value}")
|
| 102 |
-
|
| 103 |
-
run_convert_script(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, input_path=input_path, output_path=output_path, pth_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, split_audio=split_audio, checkpointing=checkpointing, f0_file=f0_file, f0_onnx=f0_onnx, embedders_mode=embedders_mode, formant_shifting=formant_shifting, formant_qfrency=formant_qfrency, formant_timbre=formant_timbre)
|
| 104 |
-
|
| 105 |
-
def run_convert_script(pitch=0, filter_radius=3, index_rate=0.5, volume_envelope=1, protect=0.5, hop_length=64, f0_method="rmvpe", input_path=None, output_path="./output.wav", pth_path=None, index_path=None, f0_autotune=False, f0_autotune_strength=1, clean_audio=False, clean_strength=0.7, export_format="wav", embedder_model="contentvec_base", resample_sr=0, split_audio=False, checkpointing=False, f0_file=None, f0_onnx=False, embedders_mode="fairseq", formant_shifting=False, formant_qfrency=0.8, formant_timbre=0.8):
|
| 106 |
-
check_predictors(f0_method, f0_onnx); check_embedders(embedder_model, embedders_mode)
|
| 107 |
-
|
| 108 |
-
if not pth_path or not os.path.exists(pth_path) or os.path.isdir(pth_path) or not pth_path.endswith((".pth", ".onnx")):
|
| 109 |
-
logger.warning(translations["provide_file"].format(filename=translations["model"]))
|
| 110 |
-
sys.exit(1)
|
| 111 |
-
|
| 112 |
-
cvt = VoiceConverter(pth_path, 0)
|
| 113 |
-
start_time = time.time()
|
| 114 |
-
|
| 115 |
-
pid_path = os.path.join("assets", "convert_pid.txt")
|
| 116 |
-
with open(pid_path, "w") as pid_file:
|
| 117 |
-
pid_file.write(str(os.getpid()))
|
| 118 |
-
|
| 119 |
-
if os.path.isdir(input_path):
|
| 120 |
-
logger.info(translations["convert_batch"])
|
| 121 |
-
audio_files = [f for f in os.listdir(input_path) if f.lower().endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"))]
|
| 122 |
-
|
| 123 |
-
if not audio_files:
|
| 124 |
-
logger.warning(translations["not_found_audio"])
|
| 125 |
-
sys.exit(1)
|
| 126 |
-
|
| 127 |
-
logger.info(translations["found_audio"].format(audio_files=len(audio_files)))
|
| 128 |
-
|
| 129 |
-
for audio in audio_files:
|
| 130 |
-
audio_path = os.path.join(input_path, audio)
|
| 131 |
-
output_audio = os.path.join(input_path, os.path.splitext(audio)[0] + f"_output.{export_format}")
|
| 132 |
-
|
| 133 |
-
logger.info(f"{translations['convert_audio']} '{audio_path}'...")
|
| 134 |
-
if os.path.exists(output_audio): os.remove(output_audio)
|
| 135 |
-
cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=audio_path, audio_output_path=output_audio, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, checkpointing=checkpointing, f0_file=f0_file, f0_onnx=f0_onnx, embedders_mode=embedders_mode, formant_shifting=formant_shifting, formant_qfrency=formant_qfrency, formant_timbre=formant_timbre, split_audio=split_audio)
|
| 136 |
-
|
| 137 |
-
logger.info(translations["convert_batch_success"].format(elapsed_time=f"{(time.time() - start_time):.2f}", output_path=output_path.replace('wav', export_format)))
|
| 138 |
-
else:
|
| 139 |
-
if not os.path.exists(input_path):
|
| 140 |
-
logger.warning(translations["not_found_audio"])
|
| 141 |
-
sys.exit(1)
|
| 142 |
-
|
| 143 |
-
logger.info(f"{translations['convert_audio']} '{input_path}'...")
|
| 144 |
-
if os.path.exists(output_path): os.remove(output_path)
|
| 145 |
-
cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=input_path, audio_output_path=output_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, checkpointing=checkpointing, f0_file=f0_file, f0_onnx=f0_onnx, embedders_mode=embedders_mode, formant_shifting=formant_shifting, formant_qfrency=formant_qfrency, formant_timbre=formant_timbre, split_audio=split_audio)
|
| 146 |
-
|
| 147 |
-
if os.path.exists(pid_path): os.remove(pid_path)
|
| 148 |
-
logger.info(translations["convert_audio_success"].format(input_path=input_path, elapsed_time=f"{(time.time() - start_time):.2f}", output_path=output_path.replace('wav', export_format)))
|
| 149 |
-
|
| 150 |
-
def change_rms(source_audio, source_rate, target_audio, target_rate, rate):
|
| 151 |
-
rms2 = F.interpolate(torch.from_numpy(librosa.feature.rms(y=target_audio, frame_length=target_rate // 2 * 2, hop_length=target_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze()
|
| 152 |
-
return (target_audio * (torch.pow(F.interpolate(torch.from_numpy(librosa.feature.rms(y=source_audio, frame_length=source_rate // 2 * 2, hop_length=source_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze(), 1 - rate) * torch.pow(torch.maximum(rms2, torch.zeros_like(rms2) + 1e-6), rate - 1)).numpy())
|
| 153 |
-
|
| 154 |
-
def clear_gpu_cache():
|
| 155 |
-
gc.collect()
|
| 156 |
-
if torch.cuda.is_available(): torch.cuda.empty_cache()
|
| 157 |
-
elif torch.backends.mps.is_available(): torch.mps.empty_cache()
|
| 158 |
-
|
| 159 |
-
def get_providers():
|
| 160 |
-
ort_providers = onnxruntime.get_available_providers()
|
| 161 |
-
|
| 162 |
-
if "CUDAExecutionProvider" in ort_providers: providers = ["CUDAExecutionProvider"]
|
| 163 |
-
elif "CoreMLExecutionProvider" in ort_providers: providers = ["CoreMLExecutionProvider"]
|
| 164 |
-
else: providers = ["CPUExecutionProvider"]
|
| 165 |
-
|
| 166 |
-
return providers
|
| 167 |
-
|
| 168 |
-
class Autotune:
|
| 169 |
-
def __init__(self, ref_freqs):
|
| 170 |
-
self.ref_freqs = ref_freqs
|
| 171 |
-
self.note_dict = self.ref_freqs
|
| 172 |
-
|
| 173 |
-
def autotune_f0(self, f0, f0_autotune_strength):
|
| 174 |
-
autotuned_f0 = np.zeros_like(f0)
|
| 175 |
-
|
| 176 |
-
for i, freq in enumerate(f0):
|
| 177 |
-
autotuned_f0[i] = freq + (min(self.note_dict, key=lambda x: abs(x - freq)) - freq) * f0_autotune_strength
|
| 178 |
-
|
| 179 |
-
return autotuned_f0
|
| 180 |
-
|
| 181 |
-
class VC:
|
| 182 |
-
def __init__(self, tgt_sr, config):
|
| 183 |
-
self.x_pad = config.x_pad
|
| 184 |
-
self.x_query = config.x_query
|
| 185 |
-
self.x_center = config.x_center
|
| 186 |
-
self.x_max = config.x_max
|
| 187 |
-
self.sample_rate = 16000
|
| 188 |
-
self.window = 160
|
| 189 |
-
self.t_pad = self.sample_rate * self.x_pad
|
| 190 |
-
self.t_pad_tgt = tgt_sr * self.x_pad
|
| 191 |
-
self.t_pad2 = self.t_pad * 2
|
| 192 |
-
self.t_query = self.sample_rate * self.x_query
|
| 193 |
-
self.t_center = self.sample_rate * self.x_center
|
| 194 |
-
self.t_max = self.sample_rate * self.x_max
|
| 195 |
-
self.time_step = self.window / self.sample_rate * 1000
|
| 196 |
-
self.f0_min = 50
|
| 197 |
-
self.f0_max = 1100
|
| 198 |
-
self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
|
| 199 |
-
self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
|
| 200 |
-
self.device = config.device
|
| 201 |
-
self.is_half = config.is_half
|
| 202 |
-
self.ref_freqs = [49.00, 51.91, 55.00, 58.27, 61.74, 65.41, 69.30, 73.42, 77.78, 82.41, 87.31, 92.50, 98.00, 103.83, 110.00, 116.54, 123.47, 130.81, 138.59, 146.83, 155.56, 164.81, 174.61, 185.00, 196.00, 207.65, 220.00, 233.08, 246.94, 261.63, 277.18, 293.66, 311.13, 329.63, 349.23, 369.99, 392.00, 415.30, 440.00, 466.16, 493.88, 523.25, 554.37, 587.33, 622.25, 659.25, 698.46, 739.99, 783.99, 830.61, 880.00, 932.33, 987.77, 1046.50]
|
| 203 |
-
self.autotune = Autotune(self.ref_freqs)
|
| 204 |
-
self.note_dict = self.autotune.note_dict
|
| 205 |
-
|
| 206 |
-
def get_f0_pm(self, x, p_len):
|
| 207 |
-
import parselmouth
|
| 208 |
-
|
| 209 |
-
f0 = (parselmouth.Sound(x, self.sample_rate).to_pitch_ac(time_step=self.window / self.sample_rate * 1000 / 1000, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"])
|
| 210 |
-
pad_size = (p_len - len(f0) + 1) // 2
|
| 211 |
-
|
| 212 |
-
if pad_size > 0 or p_len - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
|
| 213 |
-
return f0
|
| 214 |
-
|
| 215 |
-
def get_f0_mangio_crepe(self, x, p_len, hop_length, model="full", onnx=False):
|
| 216 |
-
from main.library.predictors.CREPE import predict
|
| 217 |
-
|
| 218 |
-
x = x.astype(np.float32)
|
| 219 |
-
x /= np.quantile(np.abs(x), 0.999)
|
| 220 |
-
|
| 221 |
-
audio = torch.unsqueeze(torch.from_numpy(x).to(self.device, copy=True), dim=0)
|
| 222 |
-
if audio.ndim == 2 and audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True).detach()
|
| 223 |
-
|
| 224 |
-
p_len = p_len or x.shape[0] // hop_length
|
| 225 |
-
source = np.array(predict(audio.detach(), self.sample_rate, hop_length, self.f0_min, self.f0_max, model, batch_size=hop_length * 2, device=self.device, pad=True, providers=get_providers(), onnx=onnx).squeeze(0).cpu().float().numpy())
|
| 226 |
-
source[source < 0.001] = np.nan
|
| 227 |
-
|
| 228 |
-
return np.nan_to_num(np.interp(np.arange(0, len(source) * p_len, len(source)) / p_len, np.arange(0, len(source)), source))
|
| 229 |
-
|
| 230 |
-
def get_f0_crepe(self, x, model="full", onnx=False):
|
| 231 |
-
from main.library.predictors.CREPE import predict, mean, median
|
| 232 |
-
|
| 233 |
-
f0, pd = predict(torch.tensor(np.copy(x))[None].float(), self.sample_rate, self.window, self.f0_min, self.f0_max, model, batch_size=512, device=self.device, return_periodicity=True, providers=get_providers(), onnx=onnx)
|
| 234 |
-
f0, pd = mean(f0, 3), median(pd, 3)
|
| 235 |
-
f0[pd < 0.1] = 0
|
| 236 |
-
|
| 237 |
-
return f0[0].cpu().numpy()
|
| 238 |
-
|
| 239 |
-
def get_f0_fcpe(self, x, p_len, hop_length, onnx=False, legacy=False):
|
| 240 |
-
from main.library.predictors.FCPE import FCPE
|
| 241 |
-
|
| 242 |
-
model_fcpe = FCPE(os.path.join("assets", "models", "predictors", ("fcpe_legacy" if legacy else "fcpe") + (".onnx" if onnx else ".pt")), hop_length=int(hop_length), f0_min=int(self.f0_min), f0_max=int(self.f0_max), dtype=torch.float32, device=self.device, sample_rate=self.sample_rate, threshold=0.03 if legacy else 0.006, providers=get_providers(), onnx=onnx, legacy=legacy)
|
| 243 |
-
f0 = model_fcpe.compute_f0(x, p_len=p_len)
|
| 244 |
-
|
| 245 |
-
del model_fcpe
|
| 246 |
-
return f0
|
| 247 |
-
|
| 248 |
-
def get_f0_rmvpe(self, x, legacy=False, onnx=False):
|
| 249 |
-
from main.library.predictors.RMVPE import RMVPE
|
| 250 |
-
|
| 251 |
-
rmvpe_model = RMVPE(os.path.join("assets", "models", "predictors", "rmvpe" + (".onnx" if onnx else ".pt")), is_half=self.is_half, device=self.device, onnx=onnx, providers=get_providers())
|
| 252 |
-
f0 = rmvpe_model.infer_from_audio_with_pitch(x, thred=0.03, f0_min=self.f0_min, f0_max=self.f0_max) if legacy else rmvpe_model.infer_from_audio(x, thred=0.03)
|
| 253 |
-
|
| 254 |
-
del rmvpe_model
|
| 255 |
-
return f0
|
| 256 |
-
|
| 257 |
-
def get_f0_pyworld(self, x, filter_radius, model="harvest"):
|
| 258 |
-
from main.library.predictors.WORLD_WRAPPER import PYWORLD
|
| 259 |
-
|
| 260 |
-
pw = PYWORLD()
|
| 261 |
-
x = x.astype(np.double)
|
| 262 |
-
|
| 263 |
-
if model == "harvest": f0, t = pw.harvest(x, fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
|
| 264 |
-
elif model == "dio": f0, t = pw.dio(x, fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
|
| 265 |
-
else: raise ValueError(translations["method_not_valid"])
|
| 266 |
-
|
| 267 |
-
f0 = pw.stonemask(x, self.sample_rate, t, f0)
|
| 268 |
-
|
| 269 |
-
if filter_radius > 2 or model == "dio": f0 = signal.medfilt(f0, filter_radius)
|
| 270 |
-
return f0
|
| 271 |
-
|
| 272 |
-
def get_f0_swipe(self, x):
|
| 273 |
-
from main.library.predictors.SWIPE import swipe
|
| 274 |
-
|
| 275 |
-
f0, _ = swipe(x.astype(np.float32), self.sample_rate, f0_floor=self.f0_min, f0_ceil=self.f0_max, frame_period=10)
|
| 276 |
-
return f0
|
| 277 |
-
|
| 278 |
-
def get_f0_yin(self, x, hop_length, p_len, mode="yin"):
|
| 279 |
-
source = np.array(librosa.yin(x.astype(np.float32), sr=self.sample_rate, fmin=self.f0_min, fmax=self.f0_max, hop_length=hop_length) if mode == "yin" else librosa.pyin(x.astype(np.float32), fmin=self.f0_min, fmax=self.f0_max, sr=self.sample_rate, hop_length=hop_length)[0])
|
| 280 |
-
source[source < 0.001] = np.nan
|
| 281 |
-
return np.nan_to_num(np.interp(np.arange(0, len(source) * p_len, len(source)) / p_len, np.arange(0, len(source)), source))
|
| 282 |
-
|
| 283 |
-
def get_f0_hybrid(self, methods_str, x, p_len, hop_length, filter_radius, onnx_mode):
|
| 284 |
-
methods_str = re.search("hybrid\[(.+)\]", methods_str)
|
| 285 |
-
if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]
|
| 286 |
-
|
| 287 |
-
f0_computation_stack, resampled_stack = [], []
|
| 288 |
-
logger.debug(translations["hybrid_methods"].format(methods=methods))
|
| 289 |
-
|
| 290 |
-
x = x.astype(np.float32)
|
| 291 |
-
x /= np.quantile(np.abs(x), 0.999)
|
| 292 |
-
|
| 293 |
-
for method in methods:
|
| 294 |
-
f0 = None
|
| 295 |
-
f0_methods = {"pm": lambda: self.get_f0_pm(x, p_len), "dio": lambda: self.get_f0_pyworld(x, filter_radius, "dio"), "mangio-crepe-tiny": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny", onnx=onnx_mode), "mangio-crepe-small": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small", onnx=onnx_mode), "mangio-crepe-medium": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium", onnx=onnx_mode), "mangio-crepe-large": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large", onnx=onnx_mode), "mangio-crepe-full": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full", onnx=onnx_mode), "crepe-tiny": lambda: self.get_f0_crepe(x, "tiny", onnx=onnx_mode), "crepe-small": lambda: self.get_f0_crepe(x, "small", onnx=onnx_mode), "crepe-medium": lambda: self.get_f0_crepe(x, "medium", onnx=onnx_mode), "crepe-large": lambda: self.get_f0_crepe(x, "large", onnx=onnx_mode), "crepe-full": lambda: self.get_f0_crepe(x, "full", onnx=onnx_mode), "fcpe": lambda: self.get_f0_fcpe(x, p_len, int(hop_length), onnx=onnx_mode), "fcpe-legacy": lambda: self.get_f0_fcpe(x, p_len, int(hop_length), legacy=True, onnx=onnx_mode), "rmvpe": lambda: self.get_f0_rmvpe(x, onnx=onnx_mode), "rmvpe-legacy": lambda: self.get_f0_rmvpe(x, legacy=True, onnx=onnx_mode), "harvest": lambda: self.get_f0_pyworld(x, filter_radius, "harvest"), "yin": lambda: self.get_f0_yin(x, int(hop_length), p_len, mode="yin"), "pyin": lambda: self.get_f0_yin(x, int(hop_length), p_len, mode="pyin"), "swipe": lambda: self.get_f0_swipe(x)}
|
| 296 |
-
f0 = f0_methods.get(method, lambda: ValueError(translations["method_not_valid"]))()
|
| 297 |
-
f0_computation_stack.append(f0)
|
| 298 |
-
|
| 299 |
-
for f0 in f0_computation_stack:
|
| 300 |
-
resampled_stack.append(np.interp(np.linspace(0, len(f0), p_len), np.arange(len(f0)), f0))
|
| 301 |
-
|
| 302 |
-
return resampled_stack[0] if len(resampled_stack) == 1 else np.nanmedian(np.vstack(resampled_stack), axis=0)
|
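The hybrid path above resamples every estimator's pitch track to a common frame count and then takes the per-frame median. A minimal standalone sketch of that combination step (function name and inputs are illustrative, not part of the repository):

import numpy as np

def combine_f0_tracks(f0_tracks, target_len):
    # Stretch or shrink each track to target_len frames with linear interpolation,
    # then take the per-frame median across estimators to reject outliers.
    resampled = [np.interp(np.linspace(0, len(f0), target_len), np.arange(len(f0)), f0) for f0 in f0_tracks]
    return resampled[0] if len(resampled) == 1 else np.nanmedian(np.vstack(resampled), axis=0)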
| 303 |
-
|
| 304 |
-
def get_f0(self, x, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength, inp_f0=None, onnx_mode=False):
|
| 305 |
-
f0_methods = {"pm": lambda: self.get_f0_pm(x, p_len), "dio": lambda: self.get_f0_pyworld(x, filter_radius, "dio"), "mangio-crepe-tiny": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny", onnx=onnx_mode), "mangio-crepe-small": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small", onnx=onnx_mode), "mangio-crepe-medium": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium", onnx=onnx_mode), "mangio-crepe-large": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large", onnx=onnx_mode), "mangio-crepe-full": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full", onnx=onnx_mode), "crepe-tiny": lambda: self.get_f0_crepe(x, "tiny", onnx=onnx_mode), "crepe-small": lambda: self.get_f0_crepe(x, "small", onnx=onnx_mode), "crepe-medium": lambda: self.get_f0_crepe(x, "medium", onnx=onnx_mode), "crepe-large": lambda: self.get_f0_crepe(x, "large", onnx=onnx_mode), "crepe-full": lambda: self.get_f0_crepe(x, "full", onnx=onnx_mode), "fcpe": lambda: self.get_f0_fcpe(x, p_len, int(hop_length), onnx=onnx_mode), "fcpe-legacy": lambda: self.get_f0_fcpe(x, p_len, int(hop_length), legacy=True, onnx=onnx_mode), "rmvpe": lambda: self.get_f0_rmvpe(x, onnx=onnx_mode), "rmvpe-legacy": lambda: self.get_f0_rmvpe(x, legacy=True, onnx=onnx_mode), "harvest": lambda: self.get_f0_pyworld(x, filter_radius, "harvest"), "yin": lambda: self.get_f0_yin(x, int(hop_length), p_len, mode="yin"), "pyin": lambda: self.get_f0_yin(x, int(hop_length), p_len, mode="pyin"), "swipe": lambda: self.get_f0_swipe(x)}
|
| 306 |
-
if "hybrid" in f0_method: f0 = self.get_f0_hybrid(f0_method, x, p_len, hop_length, filter_radius, onnx_mode)
elif f0_method in f0_methods: f0 = f0_methods[f0_method]()
else: raise ValueError(translations["method_not_valid"])
|
| 307 |
-
|
| 308 |
-
if f0_autotune: f0 = Autotune.autotune_f0(self, f0, f0_autotune_strength)
|
| 309 |
-
if isinstance(f0, tuple): f0 = f0[0]
|
| 310 |
-
|
| 311 |
-
f0 *= pow(2, pitch / 12)
|
| 312 |
-
tf0 = self.sample_rate // self.window
|
| 313 |
-
|
| 314 |
-
if inp_f0 is not None:
|
| 315 |
-
replace_f0 = np.interp(list(range(np.round((inp_f0[:, 0].max() - inp_f0[:, 0].min()) * tf0 + 1).astype(np.int16))), inp_f0[:, 0] * 100, inp_f0[:, 1])
|
| 316 |
-
f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)] = replace_f0[:f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)].shape[0]]
|
| 317 |
-
|
| 318 |
-
f0_mel = 1127 * np.log(1 + f0 / 700)
|
| 319 |
-
f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / (self.f0_mel_max - self.f0_mel_min) + 1
|
| 320 |
-
f0_mel[f0_mel <= 1] = 1
|
| 321 |
-
f0_mel[f0_mel > 255] = 255
|
| 322 |
-
|
| 323 |
-
return np.rint(f0_mel).astype(np.int32), f0.copy()
|
| 324 |
-
|
| 325 |
-
def extract_features(self, model, feats, version):
|
| 326 |
-
return torch.as_tensor(model.run([model.get_outputs()[0].name, model.get_outputs()[1].name], {"feats": feats.detach().cpu().numpy()})[0 if version == "v1" else 1], dtype=torch.float32, device=feats.device)
|
| 327 |
-
|
| 328 |
-
def voice_conversion(self, model, net_g, sid, audio0, pitch, pitchf, index, big_npy, index_rate, version, protect):
|
| 329 |
-
pitch_guidance = pitch is not None and pitchf is not None
|
| 330 |
-
feats = (torch.from_numpy(audio0).half() if self.is_half else torch.from_numpy(audio0).float())
|
| 331 |
-
|
| 332 |
-
if feats.dim() == 2: feats = feats.mean(-1)
|
| 333 |
-
assert feats.dim() == 1, feats.dim()
|
| 334 |
-
feats = feats.view(1, -1)
|
| 335 |
-
|
| 336 |
-
with torch.no_grad():
|
| 337 |
-
if self.embed_suffix == ".pt":
|
| 338 |
-
padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
|
| 339 |
-
logits = model.extract_features(**{"source": feats.to(self.device), "padding_mask": padding_mask, "output_layer": 9 if version == "v1" else 12})
|
| 340 |
-
feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
|
| 341 |
-
elif self.embed_suffix == ".onnx": feats = self.extract_features(model, feats.to(self.device), version).to(self.device)
|
| 342 |
-
elif self.embed_suffix == ".safetensors":
|
| 343 |
-
logits = model(feats.to(self.device))["last_hidden_state"]
|
| 344 |
-
feats = (model.final_proj(logits[0]).unsqueeze(0) if version == "v1" else logits)
|
| 345 |
-
else: raise ValueError(translations["option_not_valid"])
|
| 346 |
-
|
| 347 |
-
if protect < 0.5 and pitch_guidance: feats0 = feats.clone()
|
| 348 |
-
|
| 349 |
-
if index is not None and big_npy is not None and index_rate != 0:
|
| 350 |
-
npy = feats[0].cpu().numpy()
|
| 351 |
-
if self.is_half: npy = npy.astype(np.float32)
|
| 352 |
-
|
| 353 |
-
score, ix = index.search(npy, k=8)
|
| 354 |
-
weight = np.square(1 / score)
|
| 355 |
-
|
| 356 |
-
npy = np.sum(big_npy[ix] * np.expand_dims(weight / weight.sum(axis=1, keepdims=True), axis=2), axis=1)
|
| 357 |
-
if self.is_half: npy = npy.astype(np.float16)
|
| 358 |
-
|
| 359 |
-
feats = (torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats)
|
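The block above is the index-retrieval blend: for every frame it looks up the k nearest training features in the FAISS index, averages them with inverse-squared-distance weights, and mixes the result back in at strength index_rate. A self-contained sketch, assuming index is a trained faiss index and big_npy holds its reconstructed vectors (names are illustrative):

import numpy as np

def blend_with_index(feats, index, big_npy, index_rate, k=8):
    # feats: (T, d) float32 frame features; returns the same shape with the
    # retrieved neighbours mixed in at strength index_rate.
    score, ix = index.search(feats.astype(np.float32), k)
    weight = np.square(1 / score)
    weight /= weight.sum(axis=1, keepdims=True)
    retrieved = np.sum(big_npy[ix] * weight[..., None], axis=1)
    return retrieved * index_rate + feats * (1 - index_rate)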
| 360 |
-
|
| 361 |
-
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
|
| 362 |
-
if protect < 0.5 and pitch_guidance: feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
|
| 363 |
-
|
| 364 |
-
p_len = audio0.shape[0] // self.window
|
| 365 |
-
|
| 366 |
-
if feats.shape[1] < p_len:
|
| 367 |
-
p_len = feats.shape[1]
|
| 368 |
-
if pitch_guidance:
|
| 369 |
-
pitch = pitch[:, :p_len]
|
| 370 |
-
pitchf = pitchf[:, :p_len]
|
| 371 |
-
|
| 372 |
-
if protect < 0.5 and pitch_guidance:
|
| 373 |
-
pitchff = pitchf.clone()
|
| 374 |
-
pitchff[pitchf > 0] = 1
|
| 375 |
-
pitchff[pitchf < 1] = protect
|
| 376 |
-
pitchff = pitchff.unsqueeze(-1)
|
| 377 |
-
|
| 378 |
-
feats = (feats * pitchff + feats0 * (1 - pitchff)).to(feats0.dtype)
|
| 379 |
-
|
| 380 |
-
p_len = torch.tensor([p_len], device=self.device).long()
|
| 381 |
-
audio1 = ((net_g.infer(feats.half() if self.is_half else feats.float(), p_len, pitch if pitch_guidance else None, (pitchf.half() if self.is_half else pitchf.float()) if pitch_guidance else None, sid)[0][0, 0]).data.cpu().float().numpy()) if self.suffix == ".pth" else (net_g.run([net_g.get_outputs()[0].name], ({net_g.get_inputs()[0].name: feats.cpu().numpy().astype(np.float32), net_g.get_inputs()[1].name: p_len.cpu().numpy(), net_g.get_inputs()[2].name: np.array([sid.cpu().item()], dtype=np.int64), net_g.get_inputs()[3].name: np.random.randn(1, 192, p_len).astype(np.float32), net_g.get_inputs()[4].name: pitch.cpu().numpy().astype(np.int64), net_g.get_inputs()[5].name: pitchf.cpu().numpy().astype(np.float32)} if pitch_guidance else {net_g.get_inputs()[0].name: feats.cpu().numpy().astype(np.float32), net_g.get_inputs()[1].name: p_len.cpu().numpy(), net_g.get_inputs()[2].name: np.array([sid.cpu().item()], dtype=np.int64), net_g.get_inputs()[3].name: np.random.randn(1, 192, p_len).astype(np.float32)}))[0][0, 0])
|
| 382 |
-
|
| 383 |
-
if self.embed_suffix == ".pt": del padding_mask
|
| 384 |
-
del feats, p_len, net_g
|
| 385 |
-
clear_gpu_cache()
|
| 386 |
-
return audio1
|
| 387 |
-
|
| 388 |
-
def pipeline(self, model, net_g, sid, audio, pitch, f0_method, file_index, index_rate, pitch_guidance, filter_radius, volume_envelope, version, protect, hop_length, f0_autotune, f0_autotune_strength, suffix, embed_suffix, f0_file=None, f0_onnx=False, pbar=None):
|
| 389 |
-
self.suffix = suffix
|
| 390 |
-
self.embed_suffix = embed_suffix
|
| 391 |
-
|
| 392 |
-
if file_index != "" and os.path.exists(file_index) and index_rate != 0:
|
| 393 |
-
try:
|
| 394 |
-
index = faiss.read_index(file_index)
|
| 395 |
-
big_npy = index.reconstruct_n(0, index.ntotal)
|
| 396 |
-
except Exception as e:
|
| 397 |
-
logger.error(translations["read_faiss_index_error"].format(e=e))
|
| 398 |
-
index = big_npy = None
|
| 399 |
-
else: index = big_npy = None
|
| 400 |
-
|
| 401 |
-
pbar.update(1)
|
| 402 |
-
opt_ts, audio_opt = [], []
|
| 403 |
-
audio = signal.filtfilt(bh, ah, audio)
|
| 404 |
-
audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")
|
| 405 |
-
|
| 406 |
-
if audio_pad.shape[0] > self.t_max:
|
| 407 |
-
audio_sum = np.zeros_like(audio)
|
| 408 |
-
for i in range(self.window):
|
| 409 |
-
audio_sum += audio_pad[i : i - self.window]
|
| 410 |
-
|
| 411 |
-
for t in range(self.t_center, audio.shape[0], self.t_center):
|
| 412 |
-
opt_ts.append(t - self.t_query + np.where(np.abs(audio_sum[t - self.t_query : t + self.t_query]) == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min())[0][0])
|
| 413 |
-
|
| 414 |
-
s = 0
|
| 415 |
-
t, inp_f0 = None, None
|
| 416 |
-
audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
|
| 417 |
-
sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
|
| 418 |
-
p_len = audio_pad.shape[0] // self.window
|
| 419 |
-
|
| 420 |
-
if hasattr(f0_file, "name"):
|
| 421 |
-
try:
|
| 422 |
-
with open(f0_file.name, "r") as f:
|
| 423 |
-
raw_lines = f.read()
|
| 424 |
-
if len(raw_lines) > 0:
|
| 425 |
-
inp_f0 = []
|
| 426 |
-
for line in raw_lines.strip("\n").split("\n"):
|
| 427 |
-
inp_f0.append([float(i) for i in line.split(",")])
|
| 428 |
-
|
| 429 |
-
inp_f0 = np.array(inp_f0, dtype=np.float32)
|
| 430 |
-
except:
|
| 431 |
-
logger.error(translations["error_readfile"])
|
| 432 |
-
inp_f0 = None
|
| 433 |
-
|
| 434 |
-
pbar.update(1)
|
| 435 |
-
if pitch_guidance:
|
| 436 |
-
pitch, pitchf = self.get_f0(audio_pad, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength, inp_f0, onnx_mode=f0_onnx)
|
| 437 |
-
pitch, pitchf = pitch[:p_len], pitchf[:p_len]
|
| 438 |
-
if self.device == "mps": pitchf = pitchf.astype(np.float32)
|
| 439 |
-
pitch, pitchf = torch.tensor(pitch, device=self.device).unsqueeze(0).long(), torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
|
| 440 |
-
|
| 441 |
-
pbar.update(1)
|
| 442 |
-
for t in opt_ts:
|
| 443 |
-
t = t // self.window * self.window
|
| 444 |
-
audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[s : t + self.t_pad2 + self.window], pitch[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None, pitchf[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
|
| 445 |
-
s = t
|
| 446 |
-
|
| 447 |
-
audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[t:], (pitch[:, t // self.window :] if t is not None else pitch) if pitch_guidance else None, (pitchf[:, t // self.window :] if t is not None else pitchf) if pitch_guidance else None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
|
| 448 |
-
audio_opt = np.concatenate(audio_opt)
|
| 449 |
-
if volume_envelope != 1: audio_opt = change_rms(audio, self.sample_rate, audio_opt, self.sample_rate, volume_envelope)
|
| 450 |
-
audio_max = np.abs(audio_opt).max() / 0.99
|
| 451 |
-
if audio_max > 1: audio_opt /= audio_max
|
| 452 |
-
|
| 453 |
-
if pitch_guidance: del pitch, pitchf
|
| 454 |
-
del sid
|
| 455 |
-
clear_gpu_cache()
|
| 456 |
-
pbar.update(1)
|
| 457 |
-
|
| 458 |
-
return audio_opt
|
| 459 |
-
|
| 460 |
-
class VoiceConverter:
|
| 461 |
-
def __init__(self, model_path, sid = 0):
|
| 462 |
-
self.config = config
|
| 463 |
-
self.device = config.device
|
| 464 |
-
self.hubert_model = None
|
| 465 |
-
self.tgt_sr = None
|
| 466 |
-
self.net_g = None
|
| 467 |
-
self.vc = None
|
| 468 |
-
self.cpt = None
|
| 469 |
-
self.version = None
|
| 470 |
-
self.n_spk = None
|
| 471 |
-
self.use_f0 = None
|
| 472 |
-
self.loaded_model = None
|
| 473 |
-
self.vocoder = "Default"
|
| 474 |
-
self.checkpointing = False
|
| 475 |
-
self.sample_rate = 16000
|
| 476 |
-
self.sid = sid
|
| 477 |
-
self.get_vc(model_path, sid)
|
| 478 |
-
|
| 479 |
-
def convert_audio(self, audio_input_path, audio_output_path, index_path, embedder_model, pitch, f0_method, index_rate, volume_envelope, protect, hop_length, f0_autotune, f0_autotune_strength, filter_radius, clean_audio, clean_strength, export_format, resample_sr = 0, checkpointing = False, f0_file = None, f0_onnx = False, embedders_mode = "fairseq", formant_shifting = False, formant_qfrency = 0.8, formant_timbre = 0.8, split_audio = False):
|
| 480 |
-
try:
|
| 481 |
-
with tqdm(total=10, desc=translations["convert_audio"], ncols=100, unit="a") as pbar:
|
| 482 |
-
audio = load_audio(logger, audio_input_path, self.sample_rate, formant_shifting=formant_shifting, formant_qfrency=formant_qfrency, formant_timbre=formant_timbre)
|
| 483 |
-
self.checkpointing = checkpointing
|
| 484 |
-
audio_max = np.abs(audio).max() / 0.95
|
| 485 |
-
if audio_max > 1: audio /= audio_max
|
| 486 |
-
|
| 487 |
-
pbar.update(1)
|
| 488 |
-
if not self.hubert_model:
|
| 489 |
-
models, _, embed_suffix = load_embedders_model(embedder_model, embedders_mode, providers=get_providers())
|
| 490 |
-
self.hubert_model = (models.to(self.device).half() if self.config.is_half else models.to(self.device).float()).eval() if embed_suffix in [".pt", ".safetensors"] else models
|
| 491 |
-
self.embed_suffix = embed_suffix
|
| 492 |
-
|
| 493 |
-
pbar.update(1)
|
| 494 |
-
if resample_sr >= self.sample_rate and resample_sr != self.tgt_sr: self.tgt_sr = resample_sr
|
| 495 |
-
target_sr = min([8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000, 96000], key=lambda x: abs(x - self.tgt_sr))
|
| 496 |
-
|
| 497 |
-
if split_audio:
|
| 498 |
-
chunks = cut(audio, self.sample_rate, db_thresh=-60, min_interval=500)
|
| 499 |
-
pbar.total = len(chunks) * 4 + 6
|
| 500 |
-
logger.info(f"{translations['split_total']}: {len(chunks)}")
|
| 501 |
-
else: chunks = [(audio, 0, 0)]
|
| 502 |
-
|
| 503 |
-
converted_chunks = []
|
| 504 |
-
pbar.update(1)
|
| 505 |
-
|
| 506 |
-
for waveform, start, end in chunks:
|
| 507 |
-
converted_chunks.append((start, end, self.vc.pipeline(model=self.hubert_model, net_g=self.net_g, sid=self.sid, audio=waveform, pitch=pitch, f0_method=f0_method, file_index=(index_path.strip().strip('"').strip("\n").strip('"').strip().replace("trained", "added")), index_rate=index_rate, pitch_guidance=self.use_f0, filter_radius=filter_radius, volume_envelope=volume_envelope, version=self.version, protect=protect, hop_length=hop_length, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, suffix=self.suffix, embed_suffix=self.embed_suffix, f0_file=f0_file, f0_onnx=f0_onnx, pbar=pbar)))
|
| 508 |
-
|
| 509 |
-
pbar.update(1)
|
| 510 |
-
audio_output = restore(converted_chunks, total_len=len(audio), dtype=converted_chunks[0][2].dtype) if split_audio else converted_chunks[0][2]
|
| 511 |
-
if target_sr >= self.sample_rate and self.tgt_sr != target_sr: audio_output = librosa.resample(audio_output, orig_sr=self.tgt_sr, target_sr=target_sr, res_type="soxr_vhq")
|
| 512 |
-
|
| 513 |
-
pbar.update(1)
|
| 514 |
-
if clean_audio:
|
| 515 |
-
from main.tools.noisereduce import reduce_noise
|
| 516 |
-
audio_output = reduce_noise(y=audio_output, sr=target_sr, prop_decrease=clean_strength, device=self.device)
|
| 517 |
-
|
| 518 |
-
sf.write(audio_output_path, audio_output, target_sr, format=export_format)
|
| 519 |
-
pbar.update(1)
|
| 520 |
-
except Exception as e:
|
| 521 |
-
logger.error(translations["error_convert"].format(e=e))
|
| 522 |
-
import traceback
|
| 523 |
-
logger.debug(traceback.format_exc())
|
| 524 |
-
|
| 525 |
-
def get_vc(self, weight_root, sid):
|
| 526 |
-
if sid == "" or sid == []:
|
| 527 |
-
self.cleanup()
|
| 528 |
-
clear_gpu_cache()
|
| 529 |
-
|
| 530 |
-
if not self.loaded_model or self.loaded_model != weight_root:
|
| 531 |
-
self.loaded_model = weight_root
|
| 532 |
-
self.load_model()
|
| 533 |
-
if self.cpt is not None: self.setup()
|
| 534 |
-
|
| 535 |
-
def cleanup(self):
|
| 536 |
-
if self.hubert_model is not None:
|
| 537 |
-
del self.net_g, self.n_spk, self.vc, self.hubert_model, self.tgt_sr
|
| 538 |
-
self.hubert_model = self.net_g = self.n_spk = self.vc = self.tgt_sr = None
|
| 539 |
-
clear_gpu_cache()
|
| 540 |
-
|
| 541 |
-
del self.net_g, self.cpt
|
| 542 |
-
clear_gpu_cache()
|
| 543 |
-
self.cpt = None
|
| 544 |
-
|
| 545 |
-
def load_model(self):
|
| 546 |
-
if os.path.isfile(self.loaded_model):
|
| 547 |
-
if self.loaded_model.endswith(".pth"): self.cpt = torch.load(self.loaded_model, map_location="cpu")
|
| 548 |
-
else:
|
| 549 |
-
sess_options = onnxruntime.SessionOptions()
|
| 550 |
-
sess_options.log_severity_level = 3
|
| 551 |
-
self.cpt = onnxruntime.InferenceSession(self.loaded_model, sess_options=sess_options, providers=get_providers())
|
| 552 |
-
else: self.cpt = None
|
| 553 |
-
|
| 554 |
-
def setup(self):
|
| 555 |
-
if self.cpt is not None:
|
| 556 |
-
if self.loaded_model.endswith(".pth"):
|
| 557 |
-
self.tgt_sr = self.cpt["config"][-1]
|
| 558 |
-
self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0]
|
| 559 |
-
self.use_f0 = self.cpt.get("f0", 1)
|
| 560 |
-
self.version = self.cpt.get("version", "v1")
|
| 561 |
-
self.vocoder = self.cpt.get("vocoder", "Default")
|
| 562 |
-
if self.vocoder != "Default": self.config.is_half = False
|
| 563 |
-
|
| 564 |
-
self.net_g = Synthesizer(*self.cpt["config"], use_f0=self.use_f0, text_enc_hidden_dim=768 if self.version == "v2" else 256, vocoder=self.vocoder, checkpointing=self.checkpointing)
|
| 565 |
-
del self.net_g.enc_q
|
| 566 |
-
|
| 567 |
-
self.net_g.load_state_dict(self.cpt["weight"], strict=False)
|
| 568 |
-
self.net_g.eval().to(self.device)
|
| 569 |
-
self.net_g = (self.net_g.half() if self.config.is_half else self.net_g.float())
|
| 570 |
-
self.n_spk = self.cpt["config"][-3]
|
| 571 |
-
self.suffix = ".pth"
|
| 572 |
-
else:
|
| 573 |
-
import json
|
| 574 |
-
import onnx
|
| 575 |
-
|
| 576 |
-
metadata_dict = None
|
| 577 |
-
for prop in onnx.load(self.loaded_model).metadata_props:
|
| 578 |
-
if prop.key == "model_info":
|
| 579 |
-
metadata_dict = json.loads(prop.value)
|
| 580 |
-
break
|
| 581 |
-
|
| 582 |
-
self.net_g = self.cpt
|
| 583 |
-
self.tgt_sr = metadata_dict.get("sr", 32000)
|
| 584 |
-
self.use_f0 = metadata_dict.get("f0", 1)
|
| 585 |
-
self.version = metadata_dict.get("version", "v1")
|
| 586 |
-
self.suffix = ".onnx"
|
| 587 |
-
|
| 588 |
-
self.vc = VC(self.tgt_sr, self.config)
|
| 589 |
-
|
| 590 |
-
if __name__ == "__main__": main()
|
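Before synthesis, get_f0 above quantizes the pitch track onto 255 coarse bins on a mel-like scale (1127 * ln(1 + f0/700)). A minimal sketch of just that quantization, assuming the same 50–1100 Hz range (helper name is illustrative):

import numpy as np

def coarse_f0(f0_hz, f0_min=50.0, f0_max=1100.0):
    f0_mel = 1127 * np.log(1 + f0_hz / 700)          # Hz -> mel-like scale
    mel_min = 1127 * np.log(1 + f0_min / 700)
    mel_max = 1127 * np.log(1 + f0_max / 700)
    # Voiced frames map to 1..255; unvoiced frames (0 Hz) stay at bin 1
    f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - mel_min) * 254 / (mel_max - mel_min) + 1
    return np.rint(np.clip(f0_mel, 1, 255)).astype(np.int32)   # bin indices in 1..255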
main/inference/create_dataset.py
DELETED
|
@@ -1,230 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import time
|
| 4 |
-
import yt_dlp
|
| 5 |
-
import shutil
|
| 6 |
-
import librosa
|
| 7 |
-
import logging
|
| 8 |
-
import argparse
|
| 9 |
-
import warnings
|
| 10 |
-
import logging.handlers
|
| 11 |
-
|
| 12 |
-
from soundfile import read, write
|
| 13 |
-
from distutils.util import strtobool
|
| 14 |
-
|
| 15 |
-
sys.path.append(os.getcwd())
|
| 16 |
-
|
| 17 |
-
from main.configs.config import Config
|
| 18 |
-
from main.library.algorithm.separator import Separator
|
| 19 |
-
|
| 20 |
-
config = Config()
|
| 21 |
-
translations = config.translations
|
| 22 |
-
dataset_temp = os.path.join("dataset_temp")
|
| 23 |
-
logger = logging.getLogger(__name__)
|
| 24 |
-
|
| 25 |
-
if logger.hasHandlers(): logger.handlers.clear()
|
| 26 |
-
else:
|
| 27 |
-
console_handler = logging.StreamHandler()
|
| 28 |
-
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
| 29 |
-
console_handler.setFormatter(console_formatter)
|
| 30 |
-
console_handler.setLevel(logging.INFO)
|
| 31 |
-
file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "create_dataset.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
|
| 32 |
-
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
| 33 |
-
file_handler.setFormatter(file_formatter)
|
| 34 |
-
file_handler.setLevel(logging.DEBUG)
|
| 35 |
-
logger.addHandler(console_handler)
|
| 36 |
-
logger.addHandler(file_handler)
|
| 37 |
-
logger.setLevel(logging.DEBUG)
|
| 38 |
-
|
| 39 |
-
def parse_arguments():
|
| 40 |
-
parser = argparse.ArgumentParser()
|
| 41 |
-
parser.add_argument("--input_audio", type=str, required=True)
|
| 42 |
-
parser.add_argument("--output_dataset", type=str, default="./dataset")
|
| 43 |
-
parser.add_argument("--sample_rate", type=int, default=44100)
|
| 44 |
-
parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
|
| 45 |
-
parser.add_argument("--clean_strength", type=float, default=0.7)
|
| 46 |
-
parser.add_argument("--separator_reverb", type=lambda x: bool(strtobool(x)), default=False)
|
| 47 |
-
parser.add_argument("--kim_vocal_version", type=int, default=2)
|
| 48 |
-
parser.add_argument("--overlap", type=float, default=0.25)
|
| 49 |
-
parser.add_argument("--segments_size", type=int, default=256)
|
| 50 |
-
parser.add_argument("--mdx_hop_length", type=int, default=1024)
|
| 51 |
-
parser.add_argument("--mdx_batch_size", type=int, default=1)
|
| 52 |
-
parser.add_argument("--denoise_mdx", type=lambda x: bool(strtobool(x)), default=False)
|
| 53 |
-
parser.add_argument("--skip", type=lambda x: bool(strtobool(x)), default=False)
|
| 54 |
-
parser.add_argument("--skip_start_audios", type=str, default="0")
|
| 55 |
-
parser.add_argument("--skip_end_audios", type=str, default="0")
|
| 56 |
-
|
| 57 |
-
return parser.parse_args()
|
| 58 |
-
|
| 59 |
-
def main():
|
| 60 |
-
pid_path = os.path.join("assets", "create_dataset_pid.txt")
|
| 61 |
-
with open(pid_path, "w") as pid_file:
|
| 62 |
-
pid_file.write(str(os.getpid()))
|
| 63 |
-
|
| 64 |
-
args = parse_arguments()
|
| 65 |
-
input_audio, output_dataset, sample_rate, clean_dataset, clean_strength, separator_reverb, kim_vocal_version, overlap, segments_size, hop_length, batch_size, denoise_mdx, skip, skip_start_audios, skip_end_audios = args.input_audio, args.output_dataset, args.sample_rate, args.clean_dataset, args.clean_strength, args.separator_reverb, args.kim_vocal_version, args.overlap, args.segments_size, args.mdx_hop_length, args.mdx_batch_size, args.denoise_mdx, args.skip, args.skip_start_audios, args.skip_end_audios
|
| 66 |
-
log_data = {translations['audio_path']: input_audio, translations['output_path']: output_dataset, translations['sr']: sample_rate, translations['clear_dataset']: clean_dataset, translations['dereveb_audio']: separator_reverb, translations['segments_size']: segments_size, translations['overlap']: overlap, "Hop length": hop_length, translations['batch_size']: batch_size, translations['denoise_mdx']: denoise_mdx, translations['skip']: skip}
|
| 67 |
-
|
| 68 |
-
if clean_dataset: log_data[translations['clean_strength']] = clean_strength
|
| 69 |
-
if skip:
|
| 70 |
-
log_data[translations['skip_start']] = skip_start_audios
|
| 71 |
-
log_data[translations['skip_end']] = skip_end_audios
|
| 72 |
-
|
| 73 |
-
for key, value in log_data.items():
|
| 74 |
-
logger.debug(f"{key}: {value}")
|
| 75 |
-
|
| 76 |
-
if kim_vocal_version not in [1, 2]: raise ValueError(translations["version_not_valid"])
|
| 77 |
-
start_time = time.time()
|
| 78 |
-
|
| 79 |
-
try:
|
| 80 |
-
paths = []
|
| 81 |
-
|
| 82 |
-
if not os.path.exists(dataset_temp): os.makedirs(dataset_temp, exist_ok=True)
|
| 83 |
-
urls = input_audio.replace(", ", ",").split(",")
|
| 84 |
-
|
| 85 |
-
for url in urls:
|
| 86 |
-
path = downloader(url, urls.index(url))
|
| 87 |
-
paths.append(path)
|
| 88 |
-
|
| 89 |
-
if skip:
|
| 90 |
-
skip_start_audios, skip_end_audios = skip_start_audios.replace(", ", ",").split(","), skip_end_audios.replace(", ", ",").split(",")
|
| 91 |
-
|
| 92 |
-
if len(skip_start_audios) < len(paths) or len(skip_end_audios) < len(paths):
|
| 93 |
-
logger.warning(translations["skip<audio"])
|
| 94 |
-
sys.exit(1)
|
| 95 |
-
elif len(skip_start_audios) > len(paths) or len(skip_end_audios) > len(paths):
|
| 96 |
-
logger.warning(translations["skip>audio"])
|
| 97 |
-
sys.exit(1)
|
| 98 |
-
else:
|
| 99 |
-
for audio, skip_start_audio, skip_end_audio in zip(paths, skip_start_audios, skip_end_audios):
|
| 100 |
-
skip_start(audio, float(skip_start_audio))
|
| 101 |
-
skip_end(audio, float(skip_end_audio))
|
| 102 |
-
|
| 103 |
-
separator_paths = []
|
| 104 |
-
|
| 105 |
-
for audio in paths:
|
| 106 |
-
vocals = separator_music_main(audio, dataset_temp, segments_size, overlap, denoise_mdx, kim_vocal_version, hop_length, batch_size, sample_rate)
|
| 107 |
-
if separator_reverb: vocals = separator_reverb_audio(vocals, dataset_temp, segments_size, overlap, denoise_mdx, hop_length, batch_size, sample_rate)
|
| 108 |
-
separator_paths.append(vocals)
|
| 109 |
-
|
| 110 |
-
paths = separator_paths
|
| 111 |
-
|
| 112 |
-
for audio_path in paths:
|
| 113 |
-
data, sample_rate = read(audio_path)
|
| 114 |
-
data = librosa.to_mono(data.T)
|
| 115 |
-
|
| 116 |
-
if clean_dataset:
|
| 117 |
-
from main.tools.noisereduce import reduce_noise
|
| 118 |
-
data = reduce_noise(y=data, prop_decrease=clean_strength, device=config.device)
|
| 119 |
-
|
| 120 |
-
write(audio_path, data, sample_rate)
|
| 121 |
-
except Exception as e:
|
| 122 |
-
logger.error(f"{translations['create_dataset_error']}: {e}")
|
| 123 |
-
import traceback
|
| 124 |
-
logger.error(traceback.format_exc())
|
| 125 |
-
finally:
|
| 126 |
-
for audio in paths:
|
| 127 |
-
shutil.move(audio, output_dataset)
|
| 128 |
-
|
| 129 |
-
if os.path.exists(dataset_temp): shutil.rmtree(dataset_temp, ignore_errors=True)
|
| 130 |
-
|
| 131 |
-
elapsed_time = time.time() - start_time
|
| 132 |
-
if os.path.exists(pid_path): os.remove(pid_path)
|
| 133 |
-
logger.info(translations["create_dataset_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
|
| 134 |
-
|
| 135 |
-
def downloader(url, name):
|
| 136 |
-
with warnings.catch_warnings():
|
| 137 |
-
warnings.simplefilter("ignore")
|
| 138 |
-
|
| 139 |
-
ydl_opts = {"format": "bestaudio/best", "outtmpl": os.path.join(dataset_temp, f"{name}"), "postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": "wav", "preferredquality": "192"}], "no_warnings": True, "noplaylist": True, "noplaylist": True, "verbose": False}
|
| 140 |
-
logger.info(f"{translations['starting_download']}: {url}...")
|
| 141 |
-
|
| 142 |
-
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
| 143 |
-
ydl.extract_info(url)
|
| 144 |
-
logger.info(f"{translations['download_success']}: {url}")
|
| 145 |
-
|
| 146 |
-
return os.path.join(dataset_temp, f"{name}" + ".wav")
|
| 147 |
-
|
| 148 |
-
def skip_start(input_file, seconds):
|
| 149 |
-
data, sr = read(input_file)
|
| 150 |
-
total_duration = len(data) / sr
|
| 151 |
-
|
| 152 |
-
if seconds <= 0: logger.warning(translations["=<0"])
|
| 153 |
-
elif seconds >= total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
|
| 154 |
-
else:
|
| 155 |
-
logger.info(f"{translations['skip_start']}: {input_file}...")
|
| 156 |
-
write(input_file, data[int(seconds * sr):], sr)
|
| 157 |
-
|
| 158 |
-
logger.info(translations["skip_start_audio"].format(input_file=input_file))
|
| 159 |
-
|
| 160 |
-
def skip_end(input_file, seconds):
|
| 161 |
-
data, sr = read(input_file)
|
| 162 |
-
total_duration = len(data) / sr
|
| 163 |
-
|
| 164 |
-
if seconds <= 0: logger.warning(translations["=<0"])
|
| 165 |
-
elif seconds > total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
|
| 166 |
-
else:
|
| 167 |
-
logger.info(f"{translations['skip_end']}: {input_file}...")
|
| 168 |
-
write(input_file, data[:-int(seconds * sr)], sr)
|
| 169 |
-
|
| 170 |
-
logger.info(translations["skip_end_audio"].format(input_file=input_file))
|
| 171 |
-
|
| 172 |
-
def separator_music_main(input, output, segments_size, overlap, denoise, version, hop_length, batch_size, sample_rate):
|
| 173 |
-
if not os.path.exists(input):
|
| 174 |
-
logger.warning(translations["input_not_valid"])
|
| 175 |
-
return None
|
| 176 |
-
|
| 177 |
-
if not os.path.exists(output):
|
| 178 |
-
logger.warning(translations["output_not_valid"])
|
| 179 |
-
return None
|
| 180 |
-
|
| 181 |
-
model = f"Kim_Vocal_{version}.onnx"
|
| 182 |
-
output_separator = separator_main(audio_file=input, model_filename=model, output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
|
| 183 |
-
|
| 184 |
-
for f in output_separator:
|
| 185 |
-
path = os.path.join(output, f)
|
| 186 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
| 187 |
-
|
| 188 |
-
if '_(Instrumental)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
|
| 189 |
-
elif '_(Vocals)_' in f:
|
| 190 |
-
rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
|
| 191 |
-
os.rename(path, rename_file)
|
| 192 |
-
|
| 193 |
-
return rename_file
|
| 194 |
-
|
| 195 |
-
def separator_reverb_audio(input, output, segments_size, overlap, denoise, hop_length, batch_size, sample_rate):
|
| 196 |
-
if not os.path.exists(input):
|
| 197 |
-
logger.warning(translations["input_not_valid"])
|
| 198 |
-
return None
|
| 199 |
-
|
| 200 |
-
if not os.path.exists(output):
|
| 201 |
-
logger.warning(translations["output_not_valid"])
|
| 202 |
-
return None
|
| 203 |
-
|
| 204 |
-
logger.info(f"{translations['dereverb']}: {input}...")
|
| 205 |
-
output_dereverb = separator_main(audio_file=input, model_filename="Reverb_HQ_By_FoxJoy.onnx", output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
|
| 206 |
-
|
| 207 |
-
for f in output_dereverb:
|
| 208 |
-
path = os.path.join(output, f)
|
| 209 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
| 210 |
-
|
| 211 |
-
if '_(Reverb)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
|
| 212 |
-
elif '_(No Reverb)_' in f:
|
| 213 |
-
rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
|
| 214 |
-
os.rename(path, rename_file)
|
| 215 |
-
|
| 216 |
-
logger.info(f"{translations['dereverb_success']}: {rename_file}")
|
| 217 |
-
return rename_file
|
| 218 |
-
|
| 219 |
-
def separator_main(audio_file=None, model_filename="Kim_Vocal_1.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True, sample_rate=44100):
|
| 220 |
-
try:
|
| 221 |
-
separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=sample_rate, mdx_params={"hop_length": mdx_hop_length, "segment_size": mdx_segment_size, "overlap": mdx_overlap, "batch_size": mdx_batch_size, "enable_denoise": mdx_enable_denoise})
|
| 222 |
-
separator.load_model(model_filename=model_filename)
|
| 223 |
-
return separator.separate(audio_file)
|
| 224 |
-
except:
|
| 225 |
-
logger.debug(translations["default_setting"])
|
| 226 |
-
separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": mdx_enable_denoise})
|
| 227 |
-
separator.load_model(model_filename=model_filename)
|
| 228 |
-
return separator.separate(audio_file)
|
| 229 |
-
|
| 230 |
-
if __name__ == "__main__": main()
|
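For reference, the download step performed by the deleted downloader() above can be reproduced with a short standalone helper; this is a sketch under the same yt-dlp options (it needs ffmpeg on PATH for the WAV conversion) and is not part of the repository:

import os
import yt_dlp

def download_wav(url, out_dir, name):
    ydl_opts = {
        "format": "bestaudio/best",
        "outtmpl": os.path.join(out_dir, str(name)),
        "postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": "wav", "preferredquality": "192"}],
        "noplaylist": True,
        "no_warnings": True,
    }
    # Download the best audio stream and let the postprocessor extract it to WAV
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.extract_info(url)
    return os.path.join(out_dir, f"{name}.wav")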
main/inference/create_index.py
DELETED
|
@@ -1,90 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import faiss
|
| 4 |
-
import logging
|
| 5 |
-
import argparse
|
| 6 |
-
import logging.handlers
|
| 7 |
-
|
| 8 |
-
import numpy as np
|
| 9 |
-
|
| 10 |
-
from multiprocessing import cpu_count
|
| 11 |
-
from sklearn.cluster import MiniBatchKMeans
|
| 12 |
-
|
| 13 |
-
sys.path.append(os.getcwd())
|
| 14 |
-
|
| 15 |
-
from main.configs.config import Config
|
| 16 |
-
translations = Config().translations
|
| 17 |
-
|
| 18 |
-
def parse_arguments():
|
| 19 |
-
parser = argparse.ArgumentParser()
|
| 20 |
-
parser.add_argument("--model_name", type=str, required=True)
|
| 21 |
-
parser.add_argument("--rvc_version", type=str, default="v2")
|
| 22 |
-
parser.add_argument("--index_algorithm", type=str, default="Auto")
|
| 23 |
-
|
| 24 |
-
return parser.parse_args()
|
| 25 |
-
|
| 26 |
-
def main():
|
| 27 |
-
args = parse_arguments()
|
| 28 |
-
exp_dir = os.path.join("assets", "logs", args.model_name)
|
| 29 |
-
version, index_algorithm = args.rvc_version, args.index_algorithm
|
| 30 |
-
logger = logging.getLogger(__name__)
|
| 31 |
-
|
| 32 |
-
if logger.hasHandlers(): logger.handlers.clear()
|
| 33 |
-
else:
|
| 34 |
-
console_handler = logging.StreamHandler()
|
| 35 |
-
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
| 36 |
-
console_handler.setFormatter(console_formatter)
|
| 37 |
-
console_handler.setLevel(logging.INFO)
|
| 38 |
-
file_handler = logging.handlers.RotatingFileHandler(os.path.join(exp_dir, "create_index.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
|
| 39 |
-
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
| 40 |
-
file_handler.setFormatter(file_formatter)
|
| 41 |
-
file_handler.setLevel(logging.DEBUG)
|
| 42 |
-
logger.addHandler(console_handler)
|
| 43 |
-
logger.addHandler(file_handler)
|
| 44 |
-
logger.setLevel(logging.DEBUG)
|
| 45 |
-
|
| 46 |
-
log_data = {translations['modelname']: args.model_name, translations['model_path']: exp_dir, translations['training_version']: version, translations['index_algorithm_info']: index_algorithm}
|
| 47 |
-
for key, value in log_data.items():
|
| 48 |
-
logger.debug(f"{key}: {value}")
|
| 49 |
-
|
| 50 |
-
try:
|
| 51 |
-
npys = []
|
| 52 |
-
feature_dir = os.path.join(exp_dir, f"{version}_extracted")
|
| 53 |
-
model_name = os.path.basename(exp_dir)
|
| 54 |
-
|
| 55 |
-
for name in sorted(os.listdir(feature_dir)):
|
| 56 |
-
npys.append(np.load(os.path.join(feature_dir, name)))
|
| 57 |
-
|
| 58 |
-
big_npy = np.concatenate(npys, axis=0)
|
| 59 |
-
big_npy_idx = np.arange(big_npy.shape[0])
|
| 60 |
-
np.random.shuffle(big_npy_idx)
|
| 61 |
-
big_npy = big_npy[big_npy_idx]
|
| 62 |
-
|
| 63 |
-
if big_npy.shape[0] > 2e5 and (index_algorithm == "Auto" or index_algorithm == "KMeans"): big_npy = (MiniBatchKMeans(n_clusters=10000, verbose=True, batch_size=256 * cpu_count(), compute_labels=False, init="random").fit(big_npy).cluster_centers_)
|
| 64 |
-
np.save(os.path.join(exp_dir, "total_fea.npy"), big_npy)
|
| 65 |
-
|
| 66 |
-
n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
|
| 67 |
-
index_trained = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")
|
| 68 |
-
index_ivf_trained = faiss.extract_index_ivf(index_trained)
|
| 69 |
-
index_ivf_trained.nprobe = 1
|
| 70 |
-
index_trained.train(big_npy)
|
| 71 |
-
faiss.write_index(index_trained, os.path.join(exp_dir, f"trained_IVF{n_ivf}_Flat_nprobe_{index_ivf_trained.nprobe}_{model_name}_{version}.index"))
|
| 72 |
-
|
| 73 |
-
index_added = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")
|
| 74 |
-
index_ivf_added = faiss.extract_index_ivf(index_added)
|
| 75 |
-
index_ivf_added.nprobe = 1
|
| 76 |
-
index_added.train(big_npy)
|
| 77 |
-
batch_size_add = 8192
|
| 78 |
-
|
| 79 |
-
for i in range(0, big_npy.shape[0], batch_size_add):
|
| 80 |
-
index_added.add(big_npy[i : i + batch_size_add])
|
| 81 |
-
|
| 82 |
-
index_filepath_added = os.path.join(exp_dir, f"added_IVF{n_ivf}_Flat_nprobe_{index_ivf_added.nprobe}_{model_name}_{version}.index")
|
| 83 |
-
faiss.write_index(index_added, index_filepath_added)
|
| 84 |
-
logger.info(f"{translations['save_index']} '{index_filepath_added}'")
|
| 85 |
-
except Exception as e:
|
| 86 |
-
logger.error(f"{translations['create_index_error']}: {e}")
|
| 87 |
-
import traceback
|
| 88 |
-
logger.debug(traceback.format_exc())
|
| 89 |
-
|
| 90 |
-
if __name__ == "__main__": main()
|
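The deleted create_index.py trains an IVF,Flat FAISS index over the extracted features and writes both a trained and an added variant. A minimal sketch of the core steps, assuming features is an (N, d) float32 array (d = 256 for v1, 768 for v2):

import faiss
import numpy as np

def build_ivf_index(features, out_path):
    # Number of inverted lists, bounded as in the script above
    n_ivf = max(1, min(int(16 * np.sqrt(features.shape[0])), features.shape[0] // 39))
    index = faiss.index_factory(features.shape[1], f"IVF{n_ivf},Flat")
    faiss.extract_index_ivf(index).nprobe = 1
    index.train(features)
    for i in range(0, features.shape[0], 8192):   # add in batches, as the script does
        index.add(features[i : i + 8192])
    faiss.write_index(index, out_path)
    return index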
main/inference/extract.py
DELETED
|
@@ -1,360 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import re
|
| 3 |
-
import sys
|
| 4 |
-
import time
|
| 5 |
-
import tqdm
|
| 6 |
-
import torch
|
| 7 |
-
import shutil
|
| 8 |
-
import logging
|
| 9 |
-
import argparse
|
| 10 |
-
import warnings
|
| 11 |
-
import onnxruntime
|
| 12 |
-
import logging.handlers
|
| 13 |
-
|
| 14 |
-
import numpy as np
|
| 15 |
-
import soundfile as sf
|
| 16 |
-
import torch.nn.functional as F
|
| 17 |
-
|
| 18 |
-
from random import shuffle
|
| 19 |
-
from distutils.util import strtobool
|
| 20 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 21 |
-
|
| 22 |
-
sys.path.append(os.getcwd())
|
| 23 |
-
|
| 24 |
-
from main.configs.config import Config
|
| 25 |
-
from main.library.utils import check_predictors, check_embedders, load_audio, load_embedders_model
|
| 26 |
-
|
| 27 |
-
logger = logging.getLogger(__name__)
|
| 28 |
-
config = Config()
|
| 29 |
-
translations = config.translations
|
| 30 |
-
logger.propagate = False
|
| 31 |
-
|
| 32 |
-
warnings.filterwarnings("ignore")
|
| 33 |
-
for l in ["torch", "faiss", "httpx", "fairseq", "httpcore", "faiss.loader", "numba.core", "urllib3", "matplotlib"]:
|
| 34 |
-
logging.getLogger(l).setLevel(logging.ERROR)
|
| 35 |
-
|
| 36 |
-
def parse_arguments():
|
| 37 |
-
parser = argparse.ArgumentParser()
|
| 38 |
-
parser.add_argument("--model_name", type=str, required=True)
|
| 39 |
-
parser.add_argument("--rvc_version", type=str, default="v2")
|
| 40 |
-
parser.add_argument("--f0_method", type=str, default="rmvpe")
|
| 41 |
-
parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
|
| 42 |
-
parser.add_argument("--hop_length", type=int, default=128)
|
| 43 |
-
parser.add_argument("--cpu_cores", type=int, default=2)
|
| 44 |
-
parser.add_argument("--gpu", type=str, default="-")
|
| 45 |
-
parser.add_argument("--sample_rate", type=int, required=True)
|
| 46 |
-
parser.add_argument("--embedder_model", type=str, default="contentvec_base")
|
| 47 |
-
parser.add_argument("--f0_onnx", type=lambda x: bool(strtobool(x)), default=False)
|
| 48 |
-
parser.add_argument("--embedders_mode", type=str, default="fairseq")
|
| 49 |
-
|
| 50 |
-
return parser.parse_args()
|
| 51 |
-
|
| 52 |
-
def generate_config(rvc_version, sample_rate, model_path):
|
| 53 |
-
config_save_path = os.path.join(model_path, "config.json")
|
| 54 |
-
if not os.path.exists(config_save_path): shutil.copy(os.path.join("main", "configs", rvc_version, f"{sample_rate}.json"), config_save_path)
|
| 55 |
-
|
| 56 |
-
def generate_filelist(pitch_guidance, model_path, rvc_version, sample_rate):
|
| 57 |
-
gt_wavs_dir, feature_dir = os.path.join(model_path, "sliced_audios"), os.path.join(model_path, f"{rvc_version}_extracted")
|
| 58 |
-
f0_dir, f0nsf_dir = None, None
|
| 59 |
-
|
| 60 |
-
if pitch_guidance: f0_dir, f0nsf_dir = os.path.join(model_path, "f0"), os.path.join(model_path, "f0_voiced")
|
| 61 |
-
|
| 62 |
-
gt_wavs_files, feature_files = set(name.split(".")[0] for name in os.listdir(gt_wavs_dir)), set(name.split(".")[0] for name in os.listdir(feature_dir))
|
| 63 |
-
names = gt_wavs_files & feature_files & set(name.split(".")[0] for name in os.listdir(f0_dir)) & set(name.split(".")[0] for name in os.listdir(f0nsf_dir)) if pitch_guidance else gt_wavs_files & feature_files
|
| 64 |
-
|
| 65 |
-
options = []
|
| 66 |
-
mute_base_path = os.path.join("assets", "logs", "mute")
|
| 67 |
-
|
| 68 |
-
for name in names:
|
| 69 |
-
options.append(f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|{f0_dir}/{name}.wav.npy|{f0nsf_dir}/{name}.wav.npy|0" if pitch_guidance else f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|0")
|
| 70 |
-
|
| 71 |
-
mute_audio_path, mute_feature_path = os.path.join(mute_base_path, "sliced_audios", f"mute{sample_rate}.wav"), os.path.join(mute_base_path, f"{rvc_version}_extracted", "mute.npy")
|
| 72 |
-
for _ in range(2):
|
| 73 |
-
options.append(f"{mute_audio_path}|{mute_feature_path}|{os.path.join(mute_base_path, 'f0', 'mute.wav.npy')}|{os.path.join(mute_base_path, 'f0_voiced', 'mute.wav.npy')}|0" if pitch_guidance else f"{mute_audio_path}|{mute_feature_path}|0")
|
| 74 |
-
|
| 75 |
-
shuffle(options)
|
| 76 |
-
with open(os.path.join(model_path, "filelist.txt"), "w") as f:
|
| 77 |
-
f.write("\n".join(options))
|
| 78 |
-
|
| 79 |
-
def setup_paths(exp_dir, version = None):
|
| 80 |
-
wav_path = os.path.join(exp_dir, "sliced_audios_16k")
|
| 81 |
-
|
| 82 |
-
if version:
|
| 83 |
-
out_path = os.path.join(exp_dir, f"{version}_extracted")
|
| 84 |
-
os.makedirs(out_path, exist_ok=True)
|
| 85 |
-
return wav_path, out_path
|
| 86 |
-
else:
|
| 87 |
-
output_root1, output_root2 = os.path.join(exp_dir, "f0"), os.path.join(exp_dir, "f0_voiced")
|
| 88 |
-
os.makedirs(output_root1, exist_ok=True); os.makedirs(output_root2, exist_ok=True)
|
| 89 |
-
return wav_path, output_root1, output_root2
|
| 90 |
-
|
| 91 |
-
def read_wave(wav_path, normalize = False, is_half = False):
|
| 92 |
-
wav, sr = sf.read(wav_path, dtype=np.float32)
|
| 93 |
-
assert sr == 16000, translations["sr_not_16000"]
|
| 94 |
-
|
| 95 |
-
feats = torch.from_numpy(wav).float()
|
| 96 |
-
if feats.dim() == 2: feats = feats.mean(-1)
|
| 97 |
-
feats = feats.view(1, -1)
|
| 98 |
-
|
| 99 |
-
if normalize: feats = F.layer_norm(feats, feats.shape)
|
| 100 |
-
return feats.half() if is_half else feats.float()
|
| 101 |
-
|
| 102 |
-
def get_device(gpu_index):
|
| 103 |
-
try:
|
| 104 |
-
index = int(gpu_index)
|
| 105 |
-
if index < torch.cuda.device_count(): return f"cuda:{index}"
|
| 106 |
-
else: logger.warning(translations["gpu_not_valid"])
|
| 107 |
-
except ValueError:
|
| 108 |
-
logger.warning(translations["gpu_not_valid"])
|
| 109 |
-
return "cpu"
|
| 110 |
-
|
| 111 |
-
def get_providers():
|
| 112 |
-
ort_providers = onnxruntime.get_available_providers()
|
| 113 |
-
|
| 114 |
-
if "CUDAExecutionProvider" in ort_providers: providers = ["CUDAExecutionProvider"]
|
| 115 |
-
elif "CoreMLExecutionProvider" in ort_providers: providers = ["CoreMLExecutionProvider"]
|
| 116 |
-
else: providers = ["CPUExecutionProvider"]
|
| 117 |
-
|
| 118 |
-
return providers
|
| 119 |
-
|
| 120 |
-
class FeatureInput:
|
| 121 |
-
def __init__(self, sample_rate=16000, hop_size=160, is_half=False, device=config.device):
|
| 122 |
-
self.fs = sample_rate
|
| 123 |
-
self.hop = hop_size
|
| 124 |
-
self.f0_bin = 256
|
| 125 |
-
self.f0_max = 1100.0
|
| 126 |
-
self.f0_min = 50.0
|
| 127 |
-
self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
|
| 128 |
-
self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
|
| 129 |
-
self.device = device
|
| 130 |
-
self.is_half = is_half
|
| 131 |
-
|
| 132 |
-
def compute_f0_hybrid(self, methods_str, np_arr, hop_length, f0_onnx):
|
| 133 |
-
methods_str = re.search(r"hybrid\[(.+)\]", methods_str)
|
| 134 |
-
if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]
|
| 135 |
-
f0_computation_stack, resampled_stack = [], []
|
| 136 |
-
logger.debug(translations["hybrid_methods"].format(methods=methods))
|
| 137 |
-
|
| 138 |
-
for method in methods:
|
| 139 |
-
f0 = None
|
| 140 |
-
f0_methods = {"pm": lambda: self.get_pm(np_arr), "dio": lambda: self.get_pyworld(np_arr, "dio"), "mangio-crepe-full": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "full", onnx=f0_onnx), "mangio-crepe-large": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "large", onnx=f0_onnx), "mangio-crepe-medium": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "medium", onnx=f0_onnx), "mangio-crepe-small": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "small", onnx=f0_onnx), "mangio-crepe-tiny": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "tiny", onnx=f0_onnx), "crepe-full": lambda: self.get_crepe(np_arr, "full", onnx=f0_onnx), "crepe-large": lambda: self.get_crepe(np_arr, "large", onnx=f0_onnx), "crepe-medium": lambda: self.get_crepe(np_arr, "medium", onnx=f0_onnx), "crepe-small": lambda: self.get_crepe(np_arr, "small", onnx=f0_onnx), "crepe-tiny": lambda: self.get_crepe(np_arr, "tiny", onnx=f0_onnx), "fcpe": lambda: self.get_fcpe(np_arr, int(hop_length), onnx=f0_onnx), "fcpe-legacy": lambda: self.get_fcpe(np_arr, int(hop_length), legacy=True, onnx=f0_onnx), "rmvpe": lambda: self.get_rmvpe(np_arr, onnx=f0_onnx), "rmvpe-legacy": lambda: self.get_rmvpe(np_arr, legacy=True, onnx=f0_onnx), "harvest": lambda: self.get_pyworld(np_arr, "harvest"), "swipe": lambda: self.get_swipe(np_arr), "yin": lambda: self.get_yin(np_arr, int(hop_length), mode="yin"), "pyin": lambda: self.get_yin(np_arr, int(hop_length), mode="pyin")}
|
| 141 |
-
if method not in f0_methods: raise ValueError(translations["method_not_valid"])
f0 = f0_methods[method]()
|
| 142 |
-
f0_computation_stack.append(f0)
|
| 143 |
-
|
| 144 |
-
for f0 in f0_computation_stack:
|
| 145 |
-
resampled_stack.append(np.interp(np.linspace(0, len(f0), (np_arr.size // self.hop)), np.arange(len(f0)), f0))
|
| 146 |
-
|
| 147 |
-
return resampled_stack[0] if len(resampled_stack) == 1 else np.nanmedian(np.vstack(resampled_stack), axis=0)
|
| 148 |
-
|
| 149 |
-
def compute_f0(self, np_arr, f0_method, hop_length, f0_onnx=False):
|
| 150 |
-
f0_methods = {"pm": lambda: self.get_pm(np_arr), "dio": lambda: self.get_pyworld(np_arr, "dio"), "mangio-crepe-full": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "full", onnx=f0_onnx), "mangio-crepe-large": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "large", onnx=f0_onnx), "mangio-crepe-medium": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "medium", onnx=f0_onnx), "mangio-crepe-small": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "small", onnx=f0_onnx), "mangio-crepe-tiny": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "tiny", onnx=f0_onnx), "crepe-full": lambda: self.get_crepe(np_arr, "full", onnx=f0_onnx), "crepe-large": lambda: self.get_crepe(np_arr, "large", onnx=f0_onnx), "crepe-medium": lambda: self.get_crepe(np_arr, "medium", onnx=f0_onnx), "crepe-small": lambda: self.get_crepe(np_arr, "small", onnx=f0_onnx), "crepe-tiny": lambda: self.get_crepe(np_arr, "tiny", onnx=f0_onnx), "fcpe": lambda: self.get_fcpe(np_arr, int(hop_length), onnx=f0_onnx), "fcpe-legacy": lambda: self.get_fcpe(np_arr, int(hop_length), legacy=True, onnx=f0_onnx), "rmvpe": lambda: self.get_rmvpe(np_arr, onnx=f0_onnx), "rmvpe-legacy": lambda: self.get_rmvpe(np_arr, legacy=True, onnx=f0_onnx), "harvest": lambda: self.get_pyworld(np_arr, "harvest"), "swipe": lambda: self.get_swipe(np_arr), "yin": lambda: self.get_yin(np_arr, int(hop_length), mode="yin"), "pyin": lambda: self.get_yin(np_arr, int(hop_length), mode="pyin")}
|
| 151 |
-
if "hybrid" in f0_method: return self.compute_f0_hybrid(f0_method, np_arr, int(hop_length), f0_onnx)
if f0_method not in f0_methods: raise ValueError(translations["method_not_valid"])
return f0_methods[f0_method]()
|
| 152 |
-
|
| 153 |
-
def get_pm(self, x):
|
| 154 |
-
import parselmouth
|
| 155 |
-
|
| 156 |
-
f0 = (parselmouth.Sound(x, self.fs).to_pitch_ac(time_step=(160 / 16000 * 1000) / 1000, voicing_threshold=0.6, pitch_floor=50, pitch_ceiling=1100).selected_array["frequency"])
|
| 157 |
-
pad_size = ((x.size // self.hop) - len(f0) + 1) // 2
|
| 158 |
-
|
| 159 |
-
if pad_size > 0 or (x.size // self.hop) - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, (x.size // self.hop) - len(f0) - pad_size]], mode="constant")
|
| 160 |
-
return f0
|
| 161 |
-
|
| 162 |
-
def get_mangio_crepe(self, x, hop_length, model="full", onnx=False):
|
| 163 |
-
from main.library.predictors.CREPE import predict
|
| 164 |
-
|
| 165 |
-
audio = torch.from_numpy(x.astype(np.float32)).to(self.device)
|
| 166 |
-
audio /= torch.quantile(torch.abs(audio), 0.999)
|
| 167 |
-
audio = audio.unsqueeze(0)
|
| 168 |
-
source = predict(audio, self.fs, hop_length, self.f0_min, self.f0_max, model=model, batch_size=hop_length * 2, device=self.device, pad=True, providers=get_providers(), onnx=onnx).squeeze(0).cpu().float().numpy()
|
| 169 |
-
source[source < 0.001] = np.nan
|
| 170 |
-
|
| 171 |
-
return np.nan_to_num(np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source))
|
| 172 |
-
|
| 173 |
-
def get_crepe(self, x, model="full", onnx=False):
|
| 174 |
-
from main.library.predictors.CREPE import predict, mean, median
|
| 175 |
-
|
| 176 |
-
f0, pd = predict(torch.tensor(np.copy(x))[None].float(), self.fs, 160, self.f0_min, self.f0_max, model, batch_size=512, device=self.device, return_periodicity=True, providers=get_providers(), onnx=onnx)
|
| 177 |
-
f0, pd = mean(f0, 3), median(pd, 3)
|
| 178 |
-
f0[pd < 0.1] = 0
|
| 179 |
-
|
| 180 |
-
return f0[0].cpu().numpy()
|
| 181 |
-
|
| 182 |
-
def get_fcpe(self, x, hop_length, legacy=False, onnx=False):
|
| 183 |
-
from main.library.predictors.FCPE import FCPE
|
| 184 |
-
|
| 185 |
-
model_fcpe = FCPE(os.path.join("assets", "models", "predictors", ("fcpe_legacy" if legacy else "fcpe") + (".onnx" if onnx else ".pt")), hop_length=int(hop_length), f0_min=int(self.f0_min), f0_max=int(self.f0_max), dtype=torch.float32, device=self.device, sample_rate=self.fs, threshold=0.03 if legacy else 0.006, providers=get_providers(), onnx=onnx, legacy=legacy)
|
| 186 |
-
f0 = model_fcpe.compute_f0(x, p_len=(x.size // self.hop))
|
| 187 |
-
|
| 188 |
-
del model_fcpe
|
| 189 |
-
return f0
|
| 190 |
-
|
| 191 |
-
def get_rmvpe(self, x, legacy=False, onnx=False):
|
| 192 |
-
from main.library.predictors.RMVPE import RMVPE
|
| 193 |
-
|
| 194 |
-
rmvpe_model = RMVPE(os.path.join("assets", "models", "predictors", "rmvpe" + (".onnx" if onnx else ".pt")), is_half=self.is_half, device=self.device, onnx=onnx, providers=get_providers())
|
| 195 |
-
f0 = rmvpe_model.infer_from_audio_with_pitch(x, thred=0.03, f0_min=self.f0_min, f0_max=self.f0_max) if legacy else rmvpe_model.infer_from_audio(x, thred=0.03)
|
| 196 |
-
|
| 197 |
-
del rmvpe_model
|
| 198 |
-
return f0
|
| 199 |
-
|
| 200 |
-
def get_pyworld(self, x, model="harvest"):
|
| 201 |
-
from main.library.predictors.WORLD_WRAPPER import PYWORLD
|
| 202 |
-
|
| 203 |
-
pw = PYWORLD()
|
| 204 |
-
x = x.astype(np.double)
|
| 205 |
-
|
| 206 |
-
if model == "harvest": f0, t = pw.harvest(x, fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
|
| 207 |
-
elif model == "dio": f0, t = pw.dio(x, fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
|
| 208 |
-
else: raise ValueError(translations["method_not_valid"])
|
| 209 |
-
|
| 210 |
-
return pw.stonemask(x, self.fs, t, f0)
|
| 211 |
-
|
| 212 |
-
def get_swipe(self, x):
|
| 213 |
-
from main.library.predictors.SWIPE import swipe
|
| 214 |
-
|
| 215 |
-
f0, _ = swipe(x.astype(np.float32), self.fs, f0_floor=self.f0_min, f0_ceil=self.f0_max, frame_period=1000 * self.hop / self.fs)
|
| 216 |
-
return f0
|
| 217 |
-
|
| 218 |
-
def get_yin(self, x, hop_length, mode="yin"):
|
| 219 |
-
import librosa
|
| 220 |
-
|
| 221 |
-
source = np.array(librosa.yin(x.astype(np.float32), sr=self.fs, fmin=self.f0_min, fmax=self.f0_max, hop_length=hop_length) if mode == "yin" else librosa.pyin(x.astype(np.float32), fmin=self.f0_min, fmax=self.f0_max, sr=self.fs, hop_length=hop_length)[0])
|
| 222 |
-
source[source < 0.001] = np.nan
|
| 223 |
-
return np.nan_to_num(np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source))
|
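get_mangio_crepe and get_yin above both stretch the predictor's frame-rate f0 track onto the target frame count with np.interp. A compact, self-contained sketch of that resampling step on synthetic data (the values are made up for illustration):

```python
import numpy as np

source = np.array([100.0, 110.0, 0.0, 120.0])  # f0 per predictor frame (0 = unvoiced)
target_len = 10                                # frame count expected downstream

source[source < 0.001] = np.nan                # mask unvoiced frames before interpolating
resampled = np.nan_to_num(
    np.interp(
        np.arange(0, len(source) * target_len, len(source)) / target_len,  # query positions
        np.arange(0, len(source)),                                          # original positions
        source,
    )
)
print(resampled.shape)  # (10,)
```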
| 224 |
-
|
| 225 |
-
def coarse_f0(self, f0):
|
| 226 |
-
return np.rint(np.clip(((1127 * np.log(1 + f0 / 700)) - self.f0_mel_min) * (self.f0_bin - 2) / (self.f0_mel_max - self.f0_mel_min) + 1, 1, self.f0_bin - 1)).astype(int)
|
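The deleted coarse_f0 above quantizes a Hz-valued f0 curve onto mel-spaced integer bins before it is saved. A minimal standalone sketch of the same mapping; the bounds and bin count below (50 Hz, 1100 Hz, 256 bins) are illustrative assumptions, since the real values come from the FeatureInput constructor, which is not part of this excerpt:

```python
import numpy as np

F0_MIN, F0_MAX, F0_BIN = 50.0, 1100.0, 256          # assumed bounds for illustration
f0_mel_min = 1127 * np.log(1 + F0_MIN / 700)         # mel value of the lower bound
f0_mel_max = 1127 * np.log(1 + F0_MAX / 700)         # mel value of the upper bound

def coarse_f0(f0):
    # Hz -> mel, rescale into [1, F0_BIN - 1], round to integer bin indices.
    mel = 1127 * np.log(1 + f0 / 700)
    bins = (mel - f0_mel_min) * (F0_BIN - 2) / (f0_mel_max - f0_mel_min) + 1
    return np.rint(np.clip(bins, 1, F0_BIN - 1)).astype(int)

print(coarse_f0(np.array([0.0, 100.0, 440.0, 1100.0])))  # roughly [  1  20 122 255]
```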
| 227 |
-
|
| 228 |
-
def process_file(self, file_info, f0_method, hop_length, f0_onnx):
|
| 229 |
-
inp_path, opt_path1, opt_path2, np_arr = file_info
|
| 230 |
-
if os.path.exists(opt_path1 + ".npy") and os.path.exists(opt_path2 + ".npy"): return
|
| 231 |
-
|
| 232 |
-
try:
|
| 233 |
-
feature_pit = self.compute_f0(np_arr, f0_method, hop_length, f0_onnx)
|
| 234 |
-
if isinstance(feature_pit, tuple): feature_pit = feature_pit[0]
|
| 235 |
-
np.save(opt_path2, feature_pit, allow_pickle=False)
|
| 236 |
-
np.save(opt_path1, self.coarse_f0(feature_pit), allow_pickle=False)
|
| 237 |
-
except Exception as e:
|
| 238 |
-
raise RuntimeError(f"{translations['extract_file_error']} {inp_path}: {e}")
|
| 239 |
-
|
| 240 |
-
def process_files(self, files, f0_method, hop_length, f0_onnx, pbar):
|
| 241 |
-
for file_info in files:
|
| 242 |
-
self.process_file(file_info, f0_method, hop_length, f0_onnx)
|
| 243 |
-
pbar.update()
|
| 244 |
-
|
| 245 |
-
def run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus, f0_onnx, is_half):
|
| 246 |
-
input_root, *output_roots = setup_paths(exp_dir)
|
| 247 |
-
output_root1, output_root2 = output_roots if len(output_roots) == 2 else (output_roots[0], None)
|
| 248 |
-
|
| 249 |
-
paths = [(os.path.join(input_root, name), os.path.join(output_root1, name) if output_root1 else None, os.path.join(output_root2, name) if output_root2 else None, load_audio(logger, os.path.join(input_root, name), 16000)) for name in sorted(os.listdir(input_root)) if "spec" not in name]
|
| 250 |
-
logger.info(translations["extract_f0_method"].format(num_processes=num_processes, f0_method=f0_method))
|
| 251 |
-
|
| 252 |
-
start_time = time.time()
|
| 253 |
-
gpus = gpus.split("-")
|
| 254 |
-
process_partials = []
|
| 255 |
-
|
| 256 |
-
pbar = tqdm.tqdm(total=len(paths), ncols=100, unit="p")
|
| 257 |
-
for idx, gpu in enumerate(gpus):
|
| 258 |
-
feature_input = FeatureInput(device=get_device(gpu) if gpu != "" else "cpu", is_half=is_half)
|
| 259 |
-
process_partials.append((feature_input, paths[idx::len(gpus)]))
|
| 260 |
-
|
| 261 |
-
with ThreadPoolExecutor(max_workers=num_processes) as executor:
|
| 262 |
-
for future in as_completed([executor.submit(FeatureInput.process_files, feature_input, part_paths, f0_method, hop_length, f0_onnx, pbar) for feature_input, part_paths in process_partials]):
|
| 263 |
-
pbar.update(1)
|
| 264 |
-
logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
|
| 265 |
-
future.result()
|
| 266 |
-
|
| 267 |
-
pbar.close()
|
| 268 |
-
logger.info(translations["extract_f0_success"].format(elapsed_time=f"{(time.time() - start_time):.2f}"))
|
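run_pitch_extraction above shards the file list across devices with the stride slice paths[idx::len(gpus)]. A small self-contained sketch of that round-robin partitioning; the file and device names are made up for illustration:

```python
# Round-robin sharding: worker i gets every len(workers)-th item starting at index i.
files = [f"chunk_{n}.wav" for n in range(7)]   # hypothetical file names
workers = ["cuda:0", "cuda:1", "cpu"]          # hypothetical device list

shards = {dev: files[idx::len(workers)] for idx, dev in enumerate(workers)}
for dev, part in shards.items():
    print(dev, part)
# cuda:0 ['chunk_0.wav', 'chunk_3.wav', 'chunk_6.wav']
# cuda:1 ['chunk_1.wav', 'chunk_4.wav']
# cpu    ['chunk_2.wav', 'chunk_5.wav']
```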
| 269 |
-
|
| 270 |
-
def extract_features(model, feats, version):
|
| 271 |
-
return torch.as_tensor(model.run([model.get_outputs()[0].name, model.get_outputs()[1].name], {"feats": feats.detach().cpu().numpy()})[0 if version == "v1" else 1], dtype=torch.float32, device=feats.device)
|
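extract_features above feeds a float array into an onnxruntime session and keeps output 0 for "v1" models or output 1 for "v2". A hedged sketch of the same call pattern with onnxruntime directly; the "feats" input name mirrors the snippet above, while the model path and input shape are purely illustrative:

```python
import numpy as np
import onnxruntime as ort

# "hubert_base.onnx" is a placeholder path; any embedder exported with a "feats" input would do.
session = ort.InferenceSession("hubert_base.onnx", providers=["CPUExecutionProvider"])
feats = np.random.randn(1, 16000).astype(np.float32)  # one second of fake 16 kHz audio

output_names = [o.name for o in session.get_outputs()]
outputs = session.run(output_names[:2], {"feats": feats})
embedding = outputs[1]  # index 0 for "v1" models, 1 for "v2", as in the deleted helper
print(embedding.shape)
```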
| 272 |
-
|
| 273 |
-
def process_file_embedding(file, wav_path, out_path, model, device, version, saved_cfg, embed_suffix, is_half):
|
| 274 |
-
out_file_path = os.path.join(out_path, file.replace("wav", "npy"))
|
| 275 |
-
if os.path.exists(out_file_path): return
|
| 276 |
-
feats = read_wave(os.path.join(wav_path, file), normalize=saved_cfg.task.normalize if saved_cfg else False, is_half=is_half).to(device)
|
| 277 |
-
|
| 278 |
-
with torch.no_grad():
|
| 279 |
-
if embed_suffix == ".pt":
|
| 280 |
-
model = model.to(device).to(torch.float16 if is_half else torch.float32).eval()
|
| 281 |
-
logits = model.extract_features(**{"source": feats, "padding_mask": torch.BoolTensor(feats.shape).fill_(False).to(device), "output_layer": 9 if version == "v1" else 12})
|
| 282 |
-
feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
|
| 283 |
-
elif embed_suffix == ".onnx": feats = extract_features(model, feats, version).to(device)
|
| 284 |
-
elif embed_suffix == ".safetensors":
|
| 285 |
-
model = model.to(device).to(torch.float16 if is_half else torch.float32).eval()
|
| 286 |
-
logits = model(feats)["last_hidden_state"]
|
| 287 |
-
feats = (model.final_proj(logits[0]).unsqueeze(0) if version == "v1" else logits)
|
| 288 |
-
else: raise ValueError(translations["option_not_valid"])
|
| 289 |
-
|
| 290 |
-
feats = feats.squeeze(0).float().cpu().numpy()
|
| 291 |
-
if not np.isnan(feats).any(): np.save(out_file_path, feats, allow_pickle=False)
|
| 292 |
-
else: logger.warning(f"{file} {translations['NaN']}")
|
| 293 |
-
|
| 294 |
-
def run_embedding_extraction(exp_dir, version, gpus, embedder_model, embedders_mode, is_half):
|
| 295 |
-
wav_path, out_path = setup_paths(exp_dir, version)
|
| 296 |
-
logger.info(translations["start_extract_hubert"])
|
| 297 |
-
start_time = time.time()
|
| 298 |
-
models, saved_cfg, embed_suffix = load_embedders_model(embedder_model, embedders_mode, providers=get_providers())
|
| 299 |
-
devices = [get_device(gpu) for gpu in (gpus.split("-") if gpus != "-" else ["cpu"])]
|
| 300 |
-
paths = sorted([file for file in os.listdir(wav_path) if file.endswith(".wav")])
|
| 301 |
-
|
| 302 |
-
if not paths:
|
| 303 |
-
logger.warning(translations["not_found_audio_file"])
|
| 304 |
-
sys.exit(1)
|
| 305 |
-
|
| 306 |
-
pbar = tqdm.tqdm(total=len(paths) * len(devices), ncols=100, unit="p")
|
| 307 |
-
for task in [(file, wav_path, out_path, models, device, version, saved_cfg, embed_suffix, is_half) for file in paths for device in devices]:
|
| 308 |
-
try:
|
| 309 |
-
process_file_embedding(*task)
|
| 310 |
-
except Exception as e:
|
| 311 |
-
raise RuntimeError(f"{translations['process_error']} {task[0]}: {e}")
|
| 312 |
-
|
| 313 |
-
pbar.update(1)
|
| 314 |
-
logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
|
| 315 |
-
|
| 316 |
-
pbar.close()
|
| 317 |
-
logger.info(translations["extract_hubert_success"].format(elapsed_time=f"{(time.time() - start_time):.2f}"))
|
| 318 |
-
|
| 319 |
-
def main():
|
| 320 |
-
args = parse_arguments()
|
| 321 |
-
exp_dir = os.path.join("assets", "logs", args.model_name)
|
| 322 |
-
f0_method, hop_length, num_processes, gpus, version, pitch_guidance, sample_rate, embedder_model, f0_onnx, embedders_mode = args.f0_method, args.hop_length, args.cpu_cores, args.gpu, args.rvc_version, args.pitch_guidance, args.sample_rate, args.embedder_model, args.f0_onnx, args.embedders_mode
|
| 323 |
-
|
| 324 |
-
check_predictors(f0_method, f0_onnx); check_embedders(embedder_model, embedders_mode)
|
| 325 |
-
if logger.hasHandlers(): logger.handlers.clear()
|
| 326 |
-
else:
|
| 327 |
-
console_handler = logging.StreamHandler()
|
| 328 |
-
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
| 329 |
-
console_handler.setFormatter(console_formatter)
|
| 330 |
-
console_handler.setLevel(logging.INFO)
|
| 331 |
-
file_handler = logging.handlers.RotatingFileHandler(os.path.join(exp_dir, "extract.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
|
| 332 |
-
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
| 333 |
-
file_handler.setFormatter(file_formatter)
|
| 334 |
-
file_handler.setLevel(logging.DEBUG)
|
| 335 |
-
logger.addHandler(console_handler)
|
| 336 |
-
logger.addHandler(file_handler)
|
| 337 |
-
logger.setLevel(logging.DEBUG)
|
| 338 |
-
|
| 339 |
-
log_data = {translations['modelname']: args.model_name, translations['export_process']: exp_dir, translations['f0_method']: f0_method, translations['pretrain_sr']: sample_rate, translations['cpu_core']: num_processes, "Gpu": gpus, "Hop length": hop_length, translations['training_version']: version, translations['extract_f0']: pitch_guidance, translations['hubert_model']: embedder_model, translations["f0_onnx_mode"]: f0_onnx, translations["embed_mode"]: embedders_mode}
|
| 340 |
-
for key, value in log_data.items():
|
| 341 |
-
logger.debug(f"{key}: {value}")
|
| 342 |
-
|
| 343 |
-
pid_path = os.path.join(exp_dir, "extract_pid.txt")
|
| 344 |
-
with open(pid_path, "w") as pid_file:
|
| 345 |
-
pid_file.write(str(os.getpid()))
|
| 346 |
-
|
| 347 |
-
try:
|
| 348 |
-
run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus, f0_onnx, config.is_half)
|
| 349 |
-
run_embedding_extraction(exp_dir, version, gpus, embedder_model, embedders_mode, config.is_half)
|
| 350 |
-
generate_config(version, sample_rate, exp_dir)
|
| 351 |
-
generate_filelist(pitch_guidance, exp_dir, version, sample_rate)
|
| 352 |
-
except Exception as e:
|
| 353 |
-
logger.error(f"{translations['extract_error']}: {e}")
|
| 354 |
-
import traceback
|
| 355 |
-
logger.debug(traceback.format_exc())
|
| 356 |
-
|
| 357 |
-
if os.path.exists(pid_path): os.remove(pid_path)
|
| 358 |
-
logger.info(f"{translations['extract_success']} {args.model_name}.")
|
| 359 |
-
|
| 360 |
-
if __name__ == "__main__": main()
|
main/inference/preprocess.py
DELETED
|
@@ -1,270 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import time
|
| 4 |
-
import logging
|
| 5 |
-
import librosa
|
| 6 |
-
import argparse
|
| 7 |
-
import logging.handlers
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
|
| 11 |
-
from tqdm import tqdm
|
| 12 |
-
from scipy import signal
|
| 13 |
-
from scipy.io import wavfile
|
| 14 |
-
from distutils.util import strtobool
|
| 15 |
-
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 16 |
-
|
| 17 |
-
sys.path.append(os.getcwd())
|
| 18 |
-
|
| 19 |
-
from main.library.utils import load_audio
|
| 20 |
-
from main.configs.config import Config
|
| 21 |
-
|
| 22 |
-
logger = logging.getLogger(__name__)
|
| 23 |
-
for l in ["numba.core.byteflow", "numba.core.ssa", "numba.core.interpreter"]:
|
| 24 |
-
logging.getLogger(l).setLevel(logging.ERROR)
|
| 25 |
-
|
| 26 |
-
OVERLAP, MAX_AMPLITUDE, ALPHA, HIGH_PASS_CUTOFF, SAMPLE_RATE_16K = 0.3, 0.9, 0.75, 48, 16000
|
| 27 |
-
|
| 28 |
-
config = Config()
|
| 29 |
-
translations = config.translations
|
| 30 |
-
|
| 31 |
-
def parse_arguments():
|
| 32 |
-
parser = argparse.ArgumentParser()
|
| 33 |
-
parser.add_argument("--model_name", type=str, required=True)
|
| 34 |
-
parser.add_argument("--dataset_path", type=str, default="./dataset")
|
| 35 |
-
parser.add_argument("--sample_rate", type=int, required=True)
|
| 36 |
-
parser.add_argument("--cpu_cores", type=int, default=2)
|
| 37 |
-
parser.add_argument("--cut_preprocess", type=lambda x: bool(strtobool(x)), default=True)
|
| 38 |
-
parser.add_argument("--process_effects", type=lambda x: bool(strtobool(x)), default=False)
|
| 39 |
-
parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
|
| 40 |
-
parser.add_argument("--clean_strength", type=float, default=0.7)
|
| 41 |
-
|
| 42 |
-
return parser.parse_args()
|
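The parser above uses type=lambda x: bool(strtobool(x)) because argparse's plain type=bool treats any non-empty string, including "False", as True. A small sketch of the difference; the flag names are invented for illustration:

```python
import argparse
from distutils.util import strtobool  # deprecated in newer Pythons, but matches the deleted script

parser = argparse.ArgumentParser()
parser.add_argument("--naive", type=bool, default=False)                        # broken: bool("False") is True
parser.add_argument("--proper", type=lambda x: bool(strtobool(x)), default=False)

args = parser.parse_args(["--naive", "False", "--proper", "False"])
print(args.naive, args.proper)  # True False
```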
| 43 |
-
|
| 44 |
-
class Slicer:
|
| 45 |
-
def __init__(self, sr, threshold = -40.0, min_length = 5000, min_interval = 300, hop_size = 20, max_sil_kept = 5000):
|
| 46 |
-
if not min_length >= min_interval >= hop_size: raise ValueError(translations["min_length>=min_interval>=hop_size"])
|
| 47 |
-
if not max_sil_kept >= hop_size: raise ValueError(translations["max_sil_kept>=hop_size"])
|
| 48 |
-
|
| 49 |
-
min_interval = sr * min_interval / 1000
|
| 50 |
-
self.threshold = 10 ** (threshold / 20.0)
|
| 51 |
-
self.hop_size = round(sr * hop_size / 1000)
|
| 52 |
-
self.win_size = min(round(min_interval), 4 * self.hop_size)
|
| 53 |
-
self.min_length = round(sr * min_length / 1000 / self.hop_size)
|
| 54 |
-
self.min_interval = round(min_interval / self.hop_size)
|
| 55 |
-
self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
|
| 56 |
-
|
| 57 |
-
def _apply_slice(self, waveform, begin, end):
|
| 58 |
-
start_idx = begin * self.hop_size
|
| 59 |
-
|
| 60 |
-
if len(waveform.shape) > 1: return waveform[:, start_idx:min(waveform.shape[1], end * self.hop_size)]
|
| 61 |
-
else: return waveform[start_idx:min(waveform.shape[0], end * self.hop_size)]
|
| 62 |
-
|
| 63 |
-
def slice(self, waveform):
|
| 64 |
-
samples = waveform.mean(axis=0) if len(waveform.shape) > 1 else waveform
|
| 65 |
-
if samples.shape[0] <= self.min_length: return [waveform]
|
| 66 |
-
rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
|
| 67 |
-
sil_tags = []
|
| 68 |
-
silence_start, clip_start = None, 0
|
| 69 |
-
|
| 70 |
-
for i, rms in enumerate(rms_list):
|
| 71 |
-
if rms < self.threshold:
|
| 72 |
-
if silence_start is None: silence_start = i
|
| 73 |
-
continue
|
| 74 |
-
|
| 75 |
-
if silence_start is None: continue
|
| 76 |
-
|
| 77 |
-
is_leading_silence = silence_start == 0 and i > self.max_sil_kept
|
| 78 |
-
need_slice_middle = (i - silence_start >= self.min_interval and i - clip_start >= self.min_length)
|
| 79 |
-
|
| 80 |
-
if not is_leading_silence and not need_slice_middle:
|
| 81 |
-
silence_start = None
|
| 82 |
-
continue
|
| 83 |
-
|
| 84 |
-
if i - silence_start <= self.max_sil_kept:
|
| 85 |
-
pos = rms_list[silence_start : i + 1].argmin() + silence_start
|
| 86 |
-
sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
|
| 87 |
-
clip_start = pos
|
| 88 |
-
elif i - silence_start <= self.max_sil_kept * 2:
|
| 89 |
-
pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
|
| 90 |
-
pos += i - self.max_sil_kept
|
| 91 |
-
pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
|
| 92 |
-
|
| 93 |
-
if silence_start == 0:
|
| 94 |
-
sil_tags.append((0, pos_r))
|
| 95 |
-
clip_start = pos_r
|
| 96 |
-
else:
|
| 97 |
-
sil_tags.append((min((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos), max(pos_r, pos)))
|
| 98 |
-
clip_start = max(pos_r, pos)
|
| 99 |
-
else:
|
| 100 |
-
pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
|
| 101 |
-
sil_tags.append((0, pos_r) if silence_start == 0 else ((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos_r))
|
| 102 |
-
clip_start = pos_r
|
| 103 |
-
|
| 104 |
-
silence_start = None
|
| 105 |
-
total_frames = rms_list.shape[0]
|
| 106 |
-
if (silence_start is not None and total_frames - silence_start >= self.min_interval): sil_tags.append((rms_list[silence_start : min(total_frames, silence_start + self.max_sil_kept) + 1].argmin() + silence_start, total_frames + 1))
|
| 107 |
-
|
| 108 |
-
if not sil_tags: return [waveform]
|
| 109 |
-
else:
|
| 110 |
-
chunks = []
|
| 111 |
-
if sil_tags[0][0] > 0: chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))
|
| 112 |
-
|
| 113 |
-
for i in range(len(sil_tags) - 1):
|
| 114 |
-
chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]))
|
| 115 |
-
|
| 116 |
-
if sil_tags[-1][1] < total_frames: chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames))
|
| 117 |
-
return chunks
|
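A short usage sketch of the deleted Slicer: it takes a mono (or multi-channel) NumPy waveform and returns the non-silent chunks. The constructor arguments below mirror the ones PreProcess passes further down in this file; the test signal is synthetic, and the exact number of chunks depends on the RMS thresholding:

```python
import numpy as np

sr = 16000
tone = 0.5 * np.sin(2 * np.pi * 220 * np.arange(sr * 2) / sr)   # 2 s of tone
silence = np.zeros(sr)                                          # 1 s of silence
audio = np.concatenate([tone, silence, tone]).astype(np.float32)

slicer = Slicer(sr=sr, threshold=-42, min_length=1500, min_interval=400, hop_size=15, max_sil_kept=500)
chunks = slicer.slice(audio)
print(len(chunks), [round(len(c) / sr, 2) for c in chunks])  # expect the two tone bursts back
```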
| 118 |
-
|
| 119 |
-
def get_rms(y, frame_length=2048, hop_length=512, pad_mode="constant"):
|
| 120 |
-
y = np.pad(y, (int(frame_length // 2), int(frame_length // 2)), mode=pad_mode)
|
| 121 |
-
axis = -1
|
| 122 |
-
x_shape_trimmed = list(y.shape)
|
| 123 |
-
x_shape_trimmed[axis] -= frame_length - 1
|
| 124 |
-
xw = np.moveaxis(np.lib.stride_tricks.as_strided(y, shape=tuple(x_shape_trimmed) + tuple([frame_length]), strides=y.strides + tuple([y.strides[axis]])), -1, axis - 1 if axis < 0 else axis + 1)
|
| 125 |
-
slices = [slice(None)] * xw.ndim
|
| 126 |
-
slices[axis] = slice(0, None, hop_length)
|
| 127 |
-
return np.sqrt(np.mean(np.abs(xw[tuple(slices)]) ** 2, axis=-2, keepdims=True))
|
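get_rms above is a stride-tricks implementation of frame-wise RMS energy; it returns shape (1, n_frames) and the caller squeezes the leading axis. A plain-loop sketch that computes the same quantity (up to numerical noise) and can be used to sanity-check it:

```python
import numpy as np

def rms_frames(y, frame_length=2048, hop_length=512):
    # Pad half a frame on each side, then take the RMS of every hop-spaced window.
    y = np.pad(y, (frame_length // 2, frame_length // 2), mode="constant")
    out = []
    for start in range(0, len(y) - frame_length + 1, hop_length):
        frame = y[start:start + frame_length]
        out.append(np.sqrt(np.mean(frame ** 2)))
    return np.array(out)  # shape (n_frames,), vs. (1, n_frames) from get_rms

y = np.random.randn(16000).astype(np.float32)
print(rms_frames(y, 2048, 512).shape)
```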
| 128 |
-
|
| 129 |
-
class PreProcess:
|
| 130 |
-
def __init__(self, sr, exp_dir, per):
|
| 131 |
-
self.slicer = Slicer(sr=sr, threshold=-42, min_length=1500, min_interval=400, hop_size=15, max_sil_kept=500)
|
| 132 |
-
self.sr = sr
|
| 133 |
-
self.b_high, self.a_high = signal.butter(N=5, Wn=HIGH_PASS_CUTOFF, btype="high", fs=self.sr)
|
| 134 |
-
self.per = per
|
| 135 |
-
self.exp_dir = exp_dir
|
| 136 |
-
self.device = "cpu"
|
| 137 |
-
self.gt_wavs_dir = os.path.join(exp_dir, "sliced_audios")
|
| 138 |
-
self.wavs16k_dir = os.path.join(exp_dir, "sliced_audios_16k")
|
| 139 |
-
os.makedirs(self.gt_wavs_dir, exist_ok=True)
|
| 140 |
-
os.makedirs(self.wavs16k_dir, exist_ok=True)
|
| 141 |
-
|
| 142 |
-
def _normalize_audio(self, audio):
|
| 143 |
-
tmp_max = np.abs(audio).max()
|
| 144 |
-
if tmp_max > 2.5: return None
|
| 145 |
-
return (audio / tmp_max * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio
|
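_normalize_audio rescales the clip so its peak would sit at MAX_AMPLITUDE (0.9), but blends only ALPHA (75%) of that rescaled signal with 25% of the original, keeping some of the source dynamics. A tiny numeric check of the resulting peak, using the constants defined near the top of this deleted file:

```python
import numpy as np

MAX_AMPLITUDE, ALPHA = 0.9, 0.75

audio = np.array([0.1, -0.4, 0.2], dtype=np.float32)   # peak 0.4
peak = np.abs(audio).max()
mixed = (audio / peak * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio

# New peak for the loudest sample: 0.9 * 0.75 + 0.25 * 0.4 = 0.775
print(np.abs(mixed).max())
```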
| 146 |
-
|
| 147 |
-
def process_audio_segment(self, normalized_audio, sid, idx0, idx1):
|
| 148 |
-
if normalized_audio is None:
|
| 149 |
-
logger.debug(f"{sid}-{idx0}-{idx1}-filtered")
|
| 150 |
-
return
|
| 151 |
-
|
| 152 |
-
wavfile.write(os.path.join(self.gt_wavs_dir, f"{sid}_{idx0}_{idx1}.wav"), self.sr, normalized_audio.astype(np.float32))
|
| 153 |
-
wavfile.write(os.path.join(self.wavs16k_dir, f"{sid}_{idx0}_{idx1}.wav"), SAMPLE_RATE_16K, librosa.resample(normalized_audio, orig_sr=self.sr, target_sr=SAMPLE_RATE_16K, res_type="soxr_vhq").astype(np.float32))
|
| 154 |
-
|
| 155 |
-
def process_audio(self, path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength):
|
| 156 |
-
try:
|
| 157 |
-
audio = load_audio(logger, path, self.sr)
|
| 158 |
-
|
| 159 |
-
if process_effects:
|
| 160 |
-
audio = signal.lfilter(self.b_high, self.a_high, audio)
|
| 161 |
-
audio = self._normalize_audio(audio)
|
| 162 |
-
|
| 163 |
-
if clean_dataset:
|
| 164 |
-
from main.tools.noisereduce import reduce_noise
|
| 165 |
-
audio = reduce_noise(y=audio, sr=self.sr, prop_decrease=clean_strength, device=config.device)
|
| 166 |
-
|
| 167 |
-
idx1 = 0
|
| 168 |
-
if cut_preprocess:
|
| 169 |
-
for audio_segment in self.slicer.slice(audio):
|
| 170 |
-
i = 0
|
| 171 |
-
|
| 172 |
-
while 1:
|
| 173 |
-
start = int(self.sr * (self.per - OVERLAP) * i)
|
| 174 |
-
i += 1
|
| 175 |
-
|
| 176 |
-
if len(audio_segment[start:]) > (self.per + OVERLAP) * self.sr:
|
| 177 |
-
self.process_audio_segment(audio_segment[start : start + int(self.per * self.sr)], sid, idx0, idx1)
|
| 178 |
-
idx1 += 1
|
| 179 |
-
else:
|
| 180 |
-
self.process_audio_segment(audio_segment[start:], sid, idx0, idx1)
|
| 181 |
-
idx1 += 1
|
| 182 |
-
break
|
| 183 |
-
else: self.process_audio_segment(audio, sid, idx0, idx1)
|
| 184 |
-
except Exception as e:
|
| 185 |
-
raise RuntimeError(f"{translations['process_audio_error']}: {e}")
|
| 186 |
-
|
| 187 |
-
def process_file(args):
|
| 188 |
-
pp, file, cut_preprocess, process_effects, clean_dataset, clean_strength = (args)
|
| 189 |
-
file_path, idx0, sid = file
|
| 190 |
-
pp.process_audio(file_path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength)
|
| 191 |
-
|
| 192 |
-
def preprocess_training_set(input_root, sr, num_processes, exp_dir, per, cut_preprocess, process_effects, clean_dataset, clean_strength):
|
| 193 |
-
start_time = time.time()
|
| 194 |
-
|
| 195 |
-
pp = PreProcess(sr, exp_dir, per)
|
| 196 |
-
logger.info(translations["start_preprocess"].format(num_processes=num_processes))
|
| 197 |
-
files = []
|
| 198 |
-
idx = 0
|
| 199 |
-
|
| 200 |
-
for root, _, filenames in os.walk(input_root):
|
| 201 |
-
try:
|
| 202 |
-
sid = 0 if root == input_root else int(os.path.basename(root))
|
| 203 |
-
|
| 204 |
-
for f in filenames:
|
| 205 |
-
if f.lower().endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3")):
|
| 206 |
-
files.append((os.path.join(root, f), idx, sid))
|
| 207 |
-
idx += 1
|
| 208 |
-
except ValueError:
|
| 209 |
-
raise ValueError(f"{translations['not_integer']} '{os.path.basename(root)}'.")
|
| 210 |
-
|
| 211 |
-
with tqdm(total=len(files), ncols=100, unit="f") as pbar:
|
| 212 |
-
with ProcessPoolExecutor(max_workers=num_processes) as executor:
|
| 213 |
-
futures = [executor.submit(process_file, (pp, file, cut_preprocess, process_effects, clean_dataset, clean_strength)) for file in files]
|
| 214 |
-
for future in as_completed(futures):
|
| 215 |
-
try:
|
| 216 |
-
future.result()
|
| 217 |
-
except Exception as e:
|
| 218 |
-
raise RuntimeError(f"{translations['process_error']}: {e}")
|
| 219 |
-
pbar.update(1)
|
| 220 |
-
logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
|
| 221 |
-
|
| 222 |
-
elapsed_time = time.time() - start_time
|
| 223 |
-
logger.info(translations["preprocess_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
|
| 224 |
-
|
| 225 |
-
def main():
|
| 226 |
-
args = parse_arguments()
|
| 227 |
-
experiment_directory = os.path.join("assets", "logs", args.model_name)
|
| 228 |
-
|
| 229 |
-
num_processes = args.cpu_cores
|
| 230 |
-
num_processes = 2 if num_processes is None else int(num_processes)
|
| 231 |
-
|
| 232 |
-
dataset, sample_rate, cut_preprocess, preprocess_effects, clean_dataset, clean_strength = args.dataset_path, args.sample_rate, args.cut_preprocess, args.process_effects, args.clean_dataset, args.clean_strength
|
| 233 |
-
|
| 234 |
-
os.makedirs(experiment_directory, exist_ok=True)
|
| 235 |
-
|
| 236 |
-
if logger.hasHandlers(): logger.handlers.clear()
|
| 237 |
-
else:
|
| 238 |
-
console_handler = logging.StreamHandler()
|
| 239 |
-
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
| 240 |
-
console_handler.setFormatter(console_formatter)
|
| 241 |
-
console_handler.setLevel(logging.INFO)
|
| 242 |
-
file_handler = logging.handlers.RotatingFileHandler(os.path.join(experiment_directory, "preprocess.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
|
| 243 |
-
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
| 244 |
-
file_handler.setFormatter(file_formatter)
|
| 245 |
-
file_handler.setLevel(logging.DEBUG)
|
| 246 |
-
logger.addHandler(console_handler)
|
| 247 |
-
logger.addHandler(file_handler)
|
| 248 |
-
logger.setLevel(logging.DEBUG)
|
| 249 |
-
|
| 250 |
-
log_data = {translations['modelname']: args.model_name, translations['export_process']: experiment_directory, translations['dataset_folder']: dataset, translations['pretrain_sr']: sample_rate, translations['cpu_core']: num_processes, translations['split_audio']: cut_preprocess, translations['preprocess_effect']: preprocess_effects, translations['clear_audio']: clean_dataset}
|
| 251 |
-
if clean_dataset: log_data[translations['clean_strength']] = clean_strength
|
| 252 |
-
|
| 253 |
-
for key, value in log_data.items():
|
| 254 |
-
logger.debug(f"{key}: {value}")
|
| 255 |
-
|
| 256 |
-
pid_path = os.path.join(experiment_directory, "preprocess_pid.txt")
|
| 257 |
-
with open(pid_path, "w") as pid_file:
|
| 258 |
-
pid_file.write(str(os.getpid()))
|
| 259 |
-
|
| 260 |
-
try:
|
| 261 |
-
preprocess_training_set(dataset, sample_rate, num_processes, experiment_directory, config.per_preprocess, cut_preprocess, preprocess_effects, clean_dataset, clean_strength)
|
| 262 |
-
except Exception as e:
|
| 263 |
-
logger.error(f"{translations['process_audio_error']} {e}")
|
| 264 |
-
import traceback
|
| 265 |
-
logger.debug(traceback.format_exc())
|
| 266 |
-
|
| 267 |
-
if os.path.exists(pid_path): os.remove(pid_path)
|
| 268 |
-
logger.info(f"{translations['preprocess_model_success']} {args.model_name}")
|
| 269 |
-
|
| 270 |
-
if __name__ == "__main__": main()
|
main/inference/separator_music.py
DELETED
|
@@ -1,310 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import time
|
| 4 |
-
import logging
|
| 5 |
-
import argparse
|
| 6 |
-
import logging.handlers
|
| 7 |
-
|
| 8 |
-
import numpy as np
|
| 9 |
-
|
| 10 |
-
from distutils.util import strtobool
|
| 11 |
-
|
| 12 |
-
sys.path.append(os.getcwd())
|
| 13 |
-
|
| 14 |
-
from main.configs.config import Config
|
| 15 |
-
from main.library.algorithm.separator import Separator
|
| 16 |
-
from main.library.utils import pydub_convert, pydub_load
|
| 17 |
-
|
| 18 |
-
config = Config()
|
| 19 |
-
translations = config.translations
|
| 20 |
-
logger = logging.getLogger(__name__)
|
| 21 |
-
|
| 22 |
-
if logger.hasHandlers(): logger.handlers.clear()
|
| 23 |
-
else:
|
| 24 |
-
console_handler = logging.StreamHandler()
|
| 25 |
-
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
| 26 |
-
console_handler.setFormatter(console_formatter)
|
| 27 |
-
console_handler.setLevel(logging.INFO)
|
| 28 |
-
file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "separator.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
|
| 29 |
-
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
| 30 |
-
file_handler.setFormatter(file_formatter)
|
| 31 |
-
file_handler.setLevel(logging.DEBUG)
|
| 32 |
-
logger.addHandler(console_handler)
|
| 33 |
-
logger.addHandler(file_handler)
|
| 34 |
-
logger.setLevel(logging.DEBUG)
|
| 35 |
-
|
| 36 |
-
demucs_models = {"HT-Tuned": "htdemucs_ft.yaml", "HT-Normal": "htdemucs.yaml", "HD_MMI": "hdemucs_mmi.yaml", "HT_6S": "htdemucs_6s.yaml"}
|
| 37 |
-
mdx_models = {"Main_340": "UVR-MDX-NET_Main_340.onnx", "Main_390": "UVR-MDX-NET_Main_390.onnx", "Main_406": "UVR-MDX-NET_Main_406.onnx", "Main_427": "UVR-MDX-NET_Main_427.onnx","Main_438": "UVR-MDX-NET_Main_438.onnx", "Inst_full_292": "UVR-MDX-NET-Inst_full_292.onnx", "Inst_HQ_1": "UVR-MDX-NET-Inst_HQ_1.onnx", "Inst_HQ_2": "UVR-MDX-NET-Inst_HQ_2.onnx", "Inst_HQ_3": "UVR-MDX-NET-Inst_HQ_3.onnx", "Inst_HQ_4": "UVR-MDX-NET-Inst_HQ_4.onnx", "Inst_HQ_5": "UVR-MDX-NET-Inst_HQ_5.onnx", "Kim_Vocal_1": "Kim_Vocal_1.onnx", "Kim_Vocal_2": "Kim_Vocal_2.onnx", "Kim_Inst": "Kim_Inst.onnx", "Inst_187_beta": "UVR-MDX-NET_Inst_187_beta.onnx", "Inst_82_beta": "UVR-MDX-NET_Inst_82_beta.onnx", "Inst_90_beta": "UVR-MDX-NET_Inst_90_beta.onnx", "Voc_FT": "UVR-MDX-NET-Voc_FT.onnx", "Crowd_HQ": "UVR-MDX-NET_Crowd_HQ_1.onnx", "MDXNET_9482": "UVR_MDXNET_9482.onnx", "Inst_1": "UVR-MDX-NET-Inst_1.onnx", "Inst_2": "UVR-MDX-NET-Inst_2.onnx", "Inst_3": "UVR-MDX-NET-Inst_3.onnx", "MDXNET_1_9703": "UVR_MDXNET_1_9703.onnx", "MDXNET_2_9682": "UVR_MDXNET_2_9682.onnx", "MDXNET_3_9662": "UVR_MDXNET_3_9662.onnx", "Inst_Main": "UVR-MDX-NET-Inst_Main.onnx", "MDXNET_Main": "UVR_MDXNET_Main.onnx"}
|
| 38 |
-
kara_models = {"Version-1": "UVR_MDXNET_KARA.onnx", "Version-2": "UVR_MDXNET_KARA_2.onnx"}
|
| 39 |
-
|
| 40 |
-
def parse_arguments():
|
| 41 |
-
parser = argparse.ArgumentParser()
|
| 42 |
-
parser.add_argument("--input_path", type=str, required=True)
|
| 43 |
-
parser.add_argument("--output_path", type=str, default="./audios")
|
| 44 |
-
parser.add_argument("--format", type=str, default="wav")
|
| 45 |
-
parser.add_argument("--shifts", type=int, default=2)
|
| 46 |
-
parser.add_argument("--segments_size", type=int, default=256)
|
| 47 |
-
parser.add_argument("--overlap", type=float, default=0.25)
|
| 48 |
-
parser.add_argument("--mdx_hop_length", type=int, default=1024)
|
| 49 |
-
parser.add_argument("--mdx_batch_size", type=int, default=1)
|
| 50 |
-
parser.add_argument("--clean_audio", type=lambda x: bool(strtobool(x)), default=False)
|
| 51 |
-
parser.add_argument("--clean_strength", type=float, default=0.7)
|
| 52 |
-
parser.add_argument("--model_name", type=str, default="HT-Normal")
|
| 53 |
-
parser.add_argument("--kara_model", type=str, default="Version-1")
|
| 54 |
-
parser.add_argument("--backing", type=lambda x: bool(strtobool(x)), default=False)
|
| 55 |
-
parser.add_argument("--mdx_denoise", type=lambda x: bool(strtobool(x)), default=False)
|
| 56 |
-
parser.add_argument("--reverb", type=lambda x: bool(strtobool(x)), default=False)
|
| 57 |
-
parser.add_argument("--backing_reverb", type=lambda x: bool(strtobool(x)), default=False)
|
| 58 |
-
parser.add_argument("--sample_rate", type=int, default=44100)
|
| 59 |
-
|
| 60 |
-
return parser.parse_args()
|
| 61 |
-
|
| 62 |
-
def main():
|
| 63 |
-
start_time = time.time()
|
| 64 |
-
pid_path = os.path.join("assets", "separate_pid.txt")
|
| 65 |
-
|
| 66 |
-
with open(pid_path, "w") as pid_file:
|
| 67 |
-
pid_file.write(str(os.getpid()))
|
| 68 |
-
|
| 69 |
-
try:
|
| 70 |
-
args = parse_arguments()
|
| 71 |
-
input_path, output_path, export_format, shifts, segments_size, overlap, hop_length, batch_size, clean_audio, clean_strength, model_name, kara_model, backing, mdx_denoise, reverb, backing_reverb, sample_rate = args.input_path, args.output_path, args.format, args.shifts, args.segments_size, args.overlap, args.mdx_hop_length, args.mdx_batch_size, args.clean_audio, args.clean_strength, args.model_name, args.kara_model, args.backing, args.mdx_denoise, args.reverb, args.backing_reverb, args.sample_rate
|
| 72 |
-
|
| 73 |
-
if backing_reverb and not reverb:
|
| 74 |
-
logger.warning(translations["turn_on_dereverb"])
|
| 75 |
-
sys.exit(1)
|
| 76 |
-
|
| 77 |
-
if backing_reverb and not backing:
|
| 78 |
-
logger.warning(translations["turn_on_separator_backing"])
|
| 79 |
-
sys.exit(1)
|
| 80 |
-
|
| 81 |
-
input_path = input_path.strip(' "\n')
|
| 82 |
-
output_path = os.path.dirname(output_path) or output_path
|
| 83 |
-
|
| 84 |
-
log_data = {translations['audio_path']: input_path, translations['output_path']: output_path, translations['export_format']: export_format, translations['shift']: shifts, translations['segments_size']: segments_size, translations['overlap']: overlap, translations['modelname']: model_name, translations['denoise_mdx']: mdx_denoise, "Hop length": hop_length, translations['batch_size']: batch_size, translations['sr']: sample_rate}
|
| 85 |
-
|
| 86 |
-
if clean_audio:
|
| 87 |
-
log_data[translations['clear_audio']] = clean_audio
|
| 88 |
-
log_data[translations['clean_strength']] = clean_strength
|
| 89 |
-
|
| 90 |
-
if backing:
|
| 91 |
-
log_data[translations['backing_model_ver']] = kara_model
|
| 92 |
-
log_data[translations['separator_backing']] = backing
|
| 93 |
-
|
| 94 |
-
if reverb:
|
| 95 |
-
log_data[translations['dereveb_audio']] = reverb
|
| 96 |
-
log_data[translations['dereveb_backing']] = backing_reverb
|
| 97 |
-
|
| 98 |
-
for key, value in log_data.items():
|
| 99 |
-
logger.debug(f"{key}: {value}")
|
| 100 |
-
|
| 101 |
-
if os.path.isdir(input_path):
|
| 102 |
-
for f in [os.path.join(input_path, name) for name in os.listdir(input_path)]:
|
| 103 |
-
separation(f, output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate, mdx_denoise, hop_length, batch_size, backing, reverb, kara_model, backing_reverb, clean_audio, clean_strength)
|
| 104 |
-
else: separation(input_path, output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate, mdx_denoise, hop_length, batch_size, backing, reverb, kara_model, backing_reverb, clean_audio, clean_strength)
|
| 105 |
-
|
| 106 |
-
except Exception as e:
|
| 107 |
-
logger.error(f"{translations['separator_error']}: {e}")
|
| 108 |
-
import traceback
|
| 109 |
-
logger.debug(traceback.format_exc())
|
| 110 |
-
|
| 111 |
-
if os.path.exists(pid_path): os.remove(pid_path)
|
| 112 |
-
elapsed_time = time.time() - start_time
|
| 113 |
-
logger.info(translations["separator_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
|
| 114 |
-
|
| 115 |
-
def separation(input_path, output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate, mdx_denoise, hop_length, batch_size, backing, reverb, kara_model, backing_reverb, clean_audio, clean_strength):
|
| 116 |
-
filename, _ = os.path.splitext(os.path.basename(input_path))
|
| 117 |
-
output_path = os.path.join(output_path, filename)
|
| 118 |
-
os.makedirs(output_path, exist_ok=True)
|
| 119 |
-
|
| 120 |
-
if model_name in ["HT-Tuned", "HT-Normal", "HD_MMI", "HT_6S"]: vocals, _ = separator_music_demucs(input_path, output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate)
|
| 121 |
-
else: vocals, _ = separator_music_mdx(input_path, output_path, export_format, segments_size, overlap, mdx_denoise, model_name, hop_length, batch_size, sample_rate)
|
| 122 |
-
|
| 123 |
-
if backing: main_vocals, backing_vocals = separator_backing(vocals, output_path, export_format, segments_size, overlap, mdx_denoise, kara_model, hop_length, batch_size, sample_rate)
|
| 124 |
-
if reverb: vocals_no_reverb, main_vocals_no_reverb, backing_vocals_no_reverb = separator_reverb(output_path, export_format, segments_size, overlap, mdx_denoise, reverb, backing_reverb, hop_length, batch_size, sample_rate)
|
| 125 |
-
|
| 126 |
-
original_output = os.path.join(output_path, f"Original_Vocals_No_Reverb.{export_format}") if reverb else os.path.join(output_path, f"Original_Vocals.{export_format}")
|
| 127 |
-
main_output = os.path.join(output_path, f"Main_Vocals_No_Reverb.{export_format}") if reverb and backing_reverb else os.path.join(output_path, f"Main_Vocals.{export_format}")
|
| 128 |
-
backing_output = os.path.join(output_path, f"Backing_Vocals_No_Reverb.{export_format}") if reverb and backing_reverb else os.path.join(output_path, f"Backing_Vocals.{export_format}")
|
| 129 |
-
|
| 130 |
-
if clean_audio:
|
| 131 |
-
import soundfile as sf
|
| 132 |
-
|
| 133 |
-
logger.info(f"{translations['clear_audio']}...")
|
| 134 |
-
vocal_data, vocal_sr = sf.read(vocals_no_reverb if reverb else vocals, dtype=np.float32)
|
| 135 |
-
|
| 136 |
-
from main.tools.noisereduce import reduce_noise
|
| 137 |
-
sf.write(original_output, reduce_noise(y=vocal_data, sr=vocal_sr, prop_decrease=clean_strength, device=config.device), vocal_sr, format=export_format)
|
| 138 |
-
|
| 139 |
-
if backing:
|
| 140 |
-
main_data, main_sr = sf.read(main_vocals_no_reverb if reverb and backing else main_vocals, dtype=np.float32)
|
| 141 |
-
backing_data, backing_sr = sf.read(backing_vocals_no_reverb if reverb and backing_reverb else backing_vocals, dtype=np.float32)
|
| 142 |
-
|
| 143 |
-
sf.write(main_output, reduce_noise(y=main_data, sr=main_sr, prop_decrease=clean_strength, device=config.device), main_sr, format=export_format)
|
| 144 |
-
sf.write(backing_output, reduce_noise(y=backing_data, sr=backing_sr, prop_decrease=clean_strength, device=config.device), backing_sr, format=export_format)
|
| 145 |
-
|
| 146 |
-
logger.info(translations["clean_audio_success"])
|
| 147 |
-
|
| 148 |
-
def separator_music_demucs(input, output, format, shifts, overlap, segments_size, demucs_model, sample_rate):
|
| 149 |
-
if not os.path.exists(input):
|
| 150 |
-
logger.warning(translations["input_not_valid"])
|
| 151 |
-
sys.exit(1)
|
| 152 |
-
|
| 153 |
-
if not os.path.exists(output):
|
| 154 |
-
logger.warning(translations["output_not_valid"])
|
| 155 |
-
sys.exit(1)
|
| 156 |
-
|
| 157 |
-
for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
|
| 158 |
-
if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
|
| 159 |
-
|
| 160 |
-
logger.info(f"{translations['separator_process_2']}...")
|
| 161 |
-
demucs_output = separator_main(audio_file=input, model_filename=demucs_models.get(demucs_model), output_format=format, output_dir=output, demucs_segment_size=(segments_size / 2), demucs_shifts=shifts, demucs_overlap=overlap, sample_rate=sample_rate)
|
| 162 |
-
|
| 163 |
-
for f in demucs_output:
|
| 164 |
-
path = os.path.join(output, f)
|
| 165 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
| 166 |
-
|
| 167 |
-
if '_(Drums)_' in f: drums = path
|
| 168 |
-
elif '_(Bass)_' in f: bass = path
|
| 169 |
-
elif '_(Other)_' in f: other = path
|
| 170 |
-
elif '_(Vocals)_' in f: os.rename(path, os.path.join(output, f"Original_Vocals.{format}"))
|
| 171 |
-
|
| 172 |
-
pydub_convert(pydub_load(drums)).overlay(pydub_convert(pydub_load(bass))).overlay(pydub_convert(pydub_load(other))).export(os.path.join(output, f"Instruments.{format}"), format=format)
|
| 173 |
-
|
| 174 |
-
for f in [drums, bass, other]:
|
| 175 |
-
if os.path.exists(f): os.remove(f)
|
| 176 |
-
|
| 177 |
-
logger.info(translations["separator_success_2"])
|
| 178 |
-
return os.path.join(output, f"Original_Vocals.{format}"), os.path.join(output, f"Instruments.{format}")
|
| 179 |
-
|
| 180 |
-
def separator_backing(input, output, format, segments_size, overlap, denoise, kara_model, hop_length, batch_size, sample_rate):
|
| 181 |
-
if not os.path.exists(input):
|
| 182 |
-
logger.warning(translations["input_not_valid"])
|
| 183 |
-
sys.exit(1)
|
| 184 |
-
|
| 185 |
-
if not os.path.exists(output):
|
| 186 |
-
logger.warning(translations["output_not_valid"])
|
| 187 |
-
sys.exit(1)
|
| 188 |
-
|
| 189 |
-
for f in [f"Main_Vocals.{format}", f"Backing_Vocals.{format}"]:
|
| 190 |
-
if os.path.exists(os.path.join(output, f)): os.remove(os.path.join(output, f))
|
| 191 |
-
|
| 192 |
-
model_2 = kara_models.get(kara_model)
|
| 193 |
-
logger.info(f"{translations['separator_process_backing']}...")
|
| 194 |
-
|
| 195 |
-
backing_outputs = separator_main(audio_file=input, model_filename=model_2, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
|
| 196 |
-
main_output = os.path.join(output, f"Main_Vocals.{format}")
|
| 197 |
-
backing_output = os.path.join(output, f"Backing_Vocals.{format}")
|
| 198 |
-
|
| 199 |
-
for f in backing_outputs:
|
| 200 |
-
path = os.path.join(output, f)
|
| 201 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
| 202 |
-
|
| 203 |
-
if '_(Instrumental)_' in f: os.rename(path, backing_output)
|
| 204 |
-
elif '_(Vocals)_' in f: os.rename(path, main_output)
|
| 205 |
-
|
| 206 |
-
logger.info(translations["separator_process_backing_success"])
|
| 207 |
-
return main_output, backing_output
|
| 208 |
-
|
| 209 |
-
def separator_music_mdx(input, output, format, segments_size, overlap, denoise, mdx_model, hop_length, batch_size, sample_rate):
|
| 210 |
-
if not os.path.exists(input):
|
| 211 |
-
logger.warning(translations["input_not_valid"])
|
| 212 |
-
sys.exit(1)
|
| 213 |
-
|
| 214 |
-
if not os.path.exists(output):
|
| 215 |
-
logger.warning(translations["output_not_valid"])
|
| 216 |
-
sys.exit(1)
|
| 217 |
-
|
| 218 |
-
for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
|
| 219 |
-
if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
|
| 220 |
-
|
| 221 |
-
model_3 = mdx_models.get(mdx_model)
|
| 222 |
-
logger.info(f"{translations['separator_process_2']}...")
|
| 223 |
-
|
| 224 |
-
output_music = separator_main(audio_file=input, model_filename=model_3, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
|
| 225 |
-
original_output, instruments_output = os.path.join(output, f"Original_Vocals.{format}"), os.path.join(output, f"Instruments.{format}")
|
| 226 |
-
|
| 227 |
-
for f in output_music:
|
| 228 |
-
path = os.path.join(output, f)
|
| 229 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
| 230 |
-
|
| 231 |
-
if '_(Instrumental)_' in f: os.rename(path, instruments_output)
|
| 232 |
-
elif '_(Vocals)_' in f: os.rename(path, original_output)
|
| 233 |
-
|
| 234 |
-
logger.info(translations["separator_process_backing_success"])
|
| 235 |
-
return original_output, instruments_output
|
| 236 |
-
|
| 237 |
-
def separator_reverb(output, format, segments_size, overlap, denoise, original, backing_reverb, hop_length, batch_size, sample_rate):
|
| 238 |
-
if not os.path.exists(output):
|
| 239 |
-
logger.warning(translations["output_not_valid"])
|
| 240 |
-
sys.exit(1)
|
| 241 |
-
|
| 242 |
-
for i in [f"Original_Vocals_Reverb.{format}", f"Main_Vocals_Reverb.{format}", f"Original_Vocals_No_Reverb.{format}", f"Main_Vocals_No_Reverb.{format}"]:
|
| 243 |
-
if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
|
| 244 |
-
|
| 245 |
-
dereveb_path = []
|
| 246 |
-
|
| 247 |
-
if original:
|
| 248 |
-
try:
|
| 249 |
-
dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Original_Vocals' in f][0]))
|
| 250 |
-
except IndexError:
|
| 251 |
-
logger.warning(translations["not_found_original_vocal"])
|
| 252 |
-
sys.exit(1)
|
| 253 |
-
|
| 254 |
-
if backing_reverb:
|
| 255 |
-
try:
|
| 256 |
-
dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Main_Vocals' in f][0]))
|
| 257 |
-
except IndexError:
|
| 258 |
-
logger.warning(translations["not_found_main_vocal"])
|
| 259 |
-
sys.exit(1)
|
| 260 |
-
|
| 261 |
-
if backing_reverb:
|
| 262 |
-
try:
|
| 263 |
-
dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Backing_Vocals' in f][0]))
|
| 264 |
-
except IndexError:
|
| 265 |
-
logger.warning(translations["not_found_backing_vocal"])
|
| 266 |
-
sys.exit(1)
|
| 267 |
-
|
| 268 |
-
for path in dereveb_path:
|
| 269 |
-
if not os.path.exists(path):
|
| 270 |
-
logger.warning(translations["not_found"].format(name=path))
|
| 271 |
-
sys.exit(1)
|
| 272 |
-
|
| 273 |
-
if "Original_Vocals" in path:
|
| 274 |
-
reverb_path, no_reverb_path = os.path.join(output, f"Original_Vocals_Reverb.{format}"), os.path.join(output, f"Original_Vocals_No_Reverb.{format}")
|
| 275 |
-
start_title, end_title = translations["process_original"], translations["process_original_success"]
|
| 276 |
-
elif "Main_Vocals" in path:
|
| 277 |
-
reverb_path, no_reverb_path = os.path.join(output, f"Main_Vocals_Reverb.{format}"), os.path.join(output, f"Main_Vocals_No_Reverb.{format}")
|
| 278 |
-
start_title, end_title = translations["process_main"], translations["process_main_success"]
|
| 279 |
-
elif "Backing_Vocals" in path:
|
| 280 |
-
reverb_path, no_reverb_path = os.path.join(output, f"Backing_Vocals_Reverb.{format}"), os.path.join(output, f"Backing_Vocals_No_Reverb.{format}")
|
| 281 |
-
start_title, end_title = translations["process_backing"], translations["process_backing_success"]
|
| 282 |
-
|
| 283 |
-
logger.info(start_title)
|
| 284 |
-
output_dereveb = separator_main(audio_file=path, model_filename="Reverb_HQ_By_FoxJoy.onnx", output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
|
| 285 |
-
|
| 286 |
-
for f in output_dereveb:
|
| 287 |
-
path = os.path.join(output, f)
|
| 288 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
| 289 |
-
|
| 290 |
-
if '_(Reverb)_' in f: os.rename(path, reverb_path)
|
| 291 |
-
elif '_(No Reverb)_' in f: os.rename(path, no_reverb_path)
|
| 292 |
-
|
| 293 |
-
logger.info(end_title)
|
| 294 |
-
|
| 295 |
-
return (os.path.join(output, f"Original_Vocals_No_Reverb.{format}") if original else None), (os.path.join(output, f"Main_Vocals_No_Reverb.{format}") if backing_reverb else None), (os.path.join(output, f"Backing_Vocals_No_Reverb.{format}") if backing_reverb else None)
|
| 296 |
-
|
| 297 |
-
def separator_main(audio_file=None, model_filename="UVR-MDX-NET_Main_340.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True, demucs_segment_size=256, demucs_shifts=2, demucs_overlap=0.25, sample_rate=44100):
|
| 298 |
-
try:
|
| 299 |
-
separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=sample_rate, mdx_params={"hop_length": mdx_hop_length, "segment_size": mdx_segment_size, "overlap": mdx_overlap, "batch_size": mdx_batch_size, "enable_denoise": mdx_enable_denoise}, demucs_params={"segment_size": demucs_segment_size, "shifts": demucs_shifts, "overlap": demucs_overlap, "segments_enabled": True})
|
| 300 |
-
separator.load_model(model_filename=model_filename)
|
| 301 |
-
|
| 302 |
-
return separator.separate(audio_file)
|
| 303 |
-
except:
|
| 304 |
-
logger.debug(translations["default_setting"])
|
| 305 |
-
separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": mdx_enable_denoise}, demucs_params={"segment_size": 128, "shifts": 2, "overlap": 0.25, "segments_enabled": True})
|
| 306 |
-
separator.load_model(model_filename=model_filename)
|
| 307 |
-
|
| 308 |
-
return separator.separate(audio_file)
|
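A hedged usage sketch of the deleted separator_main wrapper: it resolves a UVR model file name, runs it through the bundled Separator, and returns the stem file names it wrote. The input and output paths below are placeholders:

```python
# Assumes the deleted module is importable and its model dictionaries are in scope.
model_file = mdx_models.get("Kim_Vocal_2", "UVR-MDX-NET_Main_340.onnx")

stems = separator_main(
    audio_file="song.wav",          # placeholder input path
    model_filename=model_file,
    output_format="wav",
    output_dir="separated",         # placeholder output directory
    mdx_segment_size=256,
    mdx_overlap=0.25,
    mdx_batch_size=1,
    mdx_hop_length=1024,
    mdx_enable_denoise=False,
)
print(stems)  # file names of the vocal and instrumental stems
```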
| 309 |
-
|
| 310 |
-
if __name__ == "__main__": main()
|
main/inference/train.py
DELETED
|
@@ -1,990 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import glob
|
| 4 |
-
import json
|
| 5 |
-
import torch
|
| 6 |
-
import hashlib
|
| 7 |
-
import logging
|
| 8 |
-
import argparse
|
| 9 |
-
import datetime
|
| 10 |
-
import warnings
|
| 11 |
-
import logging.handlers
|
| 12 |
-
|
| 13 |
-
import numpy as np
|
| 14 |
-
import soundfile as sf
|
| 15 |
-
import matplotlib.pyplot as plt
|
| 16 |
-
import torch.distributed as dist
|
| 17 |
-
import torch.utils.data as tdata
|
| 18 |
-
import torch.multiprocessing as mp
|
| 19 |
-
|
| 20 |
-
from tqdm import tqdm
|
| 21 |
-
from collections import OrderedDict
|
| 22 |
-
from random import randint, shuffle
|
| 23 |
-
from torch.utils.checkpoint import checkpoint
|
| 24 |
-
from torch.cuda.amp import GradScaler, autocast
|
| 25 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 26 |
-
|
| 27 |
-
from time import time as ttime
|
| 28 |
-
from torch.nn import functional as F
|
| 29 |
-
from distutils.util import strtobool
|
| 30 |
-
from librosa.filters import mel as librosa_mel_fn
|
| 31 |
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 32 |
-
from torch.nn.utils.parametrizations import spectral_norm, weight_norm
|
| 33 |
-
|
| 34 |
-
sys.path.append(os.getcwd())
|
| 35 |
-
|
| 36 |
-
from main.configs.config import Config
|
| 37 |
-
from main.library.algorithm.residuals import LRELU_SLOPE
|
| 38 |
-
from main.library.algorithm.synthesizers import Synthesizer
|
| 39 |
-
from main.library.algorithm.commons import get_padding, slice_segments, clip_grad_value
|
| 40 |
-
|
| 41 |
-
MATPLOTLIB_FLAG = False
|
| 42 |
-
main_config = Config()
|
| 43 |
-
translations = main_config.translations
|
| 44 |
-
warnings.filterwarnings("ignore")
|
| 45 |
-
logging.getLogger("torch").setLevel(logging.ERROR)
|
| 46 |
-
|
| 47 |
-
class HParams:
|
| 48 |
-
def __init__(self, **kwargs):
|
| 49 |
-
for k, v in kwargs.items():
|
| 50 |
-
self[k] = HParams(**v) if isinstance(v, dict) else v
|
| 51 |
-
|
| 52 |
-
def keys(self):
|
| 53 |
-
return self.__dict__.keys()
|
| 54 |
-
|
| 55 |
-
def items(self):
|
| 56 |
-
return self.__dict__.items()
|
| 57 |
-
|
| 58 |
-
def values(self):
|
| 59 |
-
return self.__dict__.values()
|
| 60 |
-
|
| 61 |
-
def __len__(self):
|
| 62 |
-
return len(self.__dict__)
|
| 63 |
-
|
| 64 |
-
def __getitem__(self, key):
|
| 65 |
-
return self.__dict__[key]
|
| 66 |
-
|
| 67 |
-
def __setitem__(self, key, value):
|
| 68 |
-
self.__dict__[key] = value
|
| 69 |
-
|
| 70 |
-
def __contains__(self, key):
|
| 71 |
-
return key in self.__dict__
|
| 72 |
-
|
| 73 |
-
def __repr__(self):
|
| 74 |
-
return repr(self.__dict__)
|
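HParams wraps a nested dict so config values can be read as attributes while still supporting dict-style access, which is how the training config JSON is consumed further down. A short sketch with a made-up config:

```python
raw = {"train": {"learning_rate": 1e-4, "epochs": 300}, "data": {"sample_rate": 40000}}

hps = HParams(**raw)
print(hps.train.learning_rate)      # attribute access on nested sections -> 0.0001
print(hps["data"]["sample_rate"])   # dict-style access still works -> 40000
print("train" in hps, len(hps))     # True 2
```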
| 75 |
-
|
| 76 |
-
def parse_arguments():
|
| 77 |
-
parser = argparse.ArgumentParser()
|
| 78 |
-
parser.add_argument("--model_name", type=str, required=True)
|
| 79 |
-
parser.add_argument("--rvc_version", type=str, default="v2")
|
| 80 |
-
parser.add_argument("--save_every_epoch", type=int, required=True)
|
| 81 |
-
parser.add_argument("--save_only_latest", type=lambda x: bool(strtobool(x)), default=True)
|
| 82 |
-
parser.add_argument("--save_every_weights", type=lambda x: bool(strtobool(x)), default=True)
|
| 83 |
-
parser.add_argument("--total_epoch", type=int, default=300)
|
| 84 |
-
parser.add_argument("--sample_rate", type=int, required=True)
|
| 85 |
-
parser.add_argument("--batch_size", type=int, default=8)
|
| 86 |
-
parser.add_argument("--gpu", type=str, default="0")
|
| 87 |
-
parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
|
| 88 |
-
parser.add_argument("--g_pretrained_path", type=str, default="")
|
| 89 |
-
parser.add_argument("--d_pretrained_path", type=str, default="")
|
| 90 |
-
parser.add_argument("--overtraining_detector", type=lambda x: bool(strtobool(x)), default=False)
|
| 91 |
-
parser.add_argument("--overtraining_threshold", type=int, default=50)
|
| 92 |
-
parser.add_argument("--cleanup", type=lambda x: bool(strtobool(x)), default=False)
|
| 93 |
-
parser.add_argument("--cache_data_in_gpu", type=lambda x: bool(strtobool(x)), default=False)
|
| 94 |
-
parser.add_argument("--model_author", type=str)
|
| 95 |
-
parser.add_argument("--vocoder", type=str, default="Default")
|
| 96 |
-
parser.add_argument("--checkpointing", type=lambda x: bool(strtobool(x)), default=False)
|
| 97 |
-
parser.add_argument("--deterministic", type=lambda x: bool(strtobool(x)), default=False)
|
| 98 |
-
parser.add_argument("--benchmark", type=lambda x: bool(strtobool(x)), default=False)
|
| 99 |
-
|
| 100 |
-
return parser.parse_args()
|
| 101 |
-
|
| 102 |
-
args = parse_arguments()
|
| 103 |
-
model_name, save_every_epoch, total_epoch, pretrainG, pretrainD, version, gpus, batch_size, sample_rate, pitch_guidance, save_only_latest, save_every_weights, cache_data_in_gpu, overtraining_detector, overtraining_threshold, cleanup, model_author, vocoder, checkpointing = args.model_name, args.save_every_epoch, args.total_epoch, args.g_pretrained_path, args.d_pretrained_path, args.rvc_version, args.gpu, args.batch_size, args.sample_rate, args.pitch_guidance, args.save_only_latest, args.save_every_weights, args.cache_data_in_gpu, args.overtraining_detector, args.overtraining_threshold, args.cleanup, args.model_author, args.vocoder, args.checkpointing
|
| 104 |
-
|
| 105 |
-
experiment_dir = os.path.join("assets", "logs", model_name)
|
| 106 |
-
training_file_path = os.path.join(experiment_dir, "training_data.json")
|
| 107 |
-
config_save_path = os.path.join(experiment_dir, "config.json")
|
| 108 |
-
torch.backends.cudnn.deterministic = args.deterministic
|
| 109 |
-
torch.backends.cudnn.benchmark = args.benchmark
|
| 110 |
-
lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
|
| 111 |
-
global_step, last_loss_gen_all, overtrain_save_epoch = 0, 0, 0
|
| 112 |
-
loss_gen_history, smoothed_loss_gen_history, loss_disc_history, smoothed_loss_disc_history = [], [], [], []
|
| 113 |
-
|
| 114 |
-
with open(config_save_path, "r") as f:
|
| 115 |
-
config = json.load(f)
|
| 116 |
-
|
| 117 |
-
config = HParams(**config)
|
| 118 |
-
config.data.training_files = os.path.join(experiment_dir, "filelist.txt")
|
| 119 |
-
logger = logging.getLogger(__name__)
|
| 120 |
-
|
| 121 |
-
if logger.hasHandlers(): logger.handlers.clear()
|
| 122 |
-
else:
|
| 123 |
-
console_handler = logging.StreamHandler()
|
| 124 |
-
console_handler.setFormatter(logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
|
| 125 |
-
console_handler.setLevel(logging.INFO)
|
| 126 |
-
file_handler = logging.handlers.RotatingFileHandler(os.path.join(experiment_dir, "train.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
|
| 127 |
-
file_handler.setFormatter(logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
|
| 128 |
-
file_handler.setLevel(logging.DEBUG)
|
| 129 |
-
logger.addHandler(console_handler)
|
| 130 |
-
logger.addHandler(file_handler)
|
| 131 |
-
logger.setLevel(logging.DEBUG)
|
| 132 |
-
|
| 133 |
-
log_data = {translations['modelname']: model_name, translations["save_every_epoch"]: save_every_epoch, translations["total_e"]: total_epoch, translations["dorg"].format(pretrainG=pretrainG, pretrainD=pretrainD): "", translations['training_version']: version, "Gpu": gpus, translations['batch_size']: batch_size, translations['pretrain_sr']: sample_rate, translations['training_f0']: pitch_guidance, translations['save_only_latest']: save_only_latest, translations['save_every_weights']: save_every_weights, translations['cache_in_gpu']: cache_data_in_gpu, translations['overtraining_detector']: overtraining_detector, translations['threshold']: overtraining_threshold, translations['cleanup_training']: cleanup, translations['memory_efficient_training']: checkpointing}
|
| 134 |
-
if model_author: log_data[translations["model_author"].format(model_author=model_author)] = ""
|
| 135 |
-
if vocoder != "Default": log_data[translations['vocoder']] = vocoder
|
| 136 |
-
|
| 137 |
-
for key, value in log_data.items():
|
| 138 |
-
logger.debug(f"{key}: {value}" if value != "" else f"{key} {value}")
|
| 139 |
-
|
| 140 |
-
def main():
|
| 141 |
-
global training_file_path, last_loss_gen_all, smoothed_loss_gen_history, loss_gen_history, loss_disc_history, smoothed_loss_disc_history, overtrain_save_epoch, model_author, vocoder, checkpointing, gpus
|
| 142 |
-
|
| 143 |
-
try:
|
| 144 |
-
os.environ["MASTER_ADDR"] = "localhost"
|
| 145 |
-
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
|
| 146 |
-
|
| 147 |
-
if torch.cuda.is_available():
|
| 148 |
-
device, gpus = torch.device("cuda"), [int(item) for item in gpus.split("-")]
|
| 149 |
-
n_gpus = len(gpus)
|
| 150 |
-
elif torch.backends.mps.is_available():
|
| 151 |
-
device, gpus = torch.device("mps"), [0]
|
| 152 |
-
n_gpus = 1
|
| 153 |
-
else:
|
| 154 |
-
device, gpus = torch.device("cpu"), [0]
|
| 155 |
-
n_gpus = 1
|
| 156 |
-
logger.warning(translations["not_gpu"])
|
| 157 |
-
|
| 158 |
-
def start():
|
| 159 |
-
children = []
|
| 160 |
-
pid_data = {"process_pids": []}
|
| 161 |
-
|
| 162 |
-
with open(config_save_path, "r") as pid_file:
|
| 163 |
-
try:
|
| 164 |
-
pid_data.update(json.load(pid_file))
|
| 165 |
-
except json.JSONDecodeError:
|
| 166 |
-
pass
|
| 167 |
-
|
| 168 |
-
with open(config_save_path, "w") as pid_file:
|
| 169 |
-
for rank, device_id in enumerate(gpus):
|
| 170 |
-
subproc = mp.Process(target=run, args=(rank, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, total_epoch, save_every_weights, config, device, device_id, model_author, vocoder, checkpointing))
|
| 171 |
-
children.append(subproc)
|
| 172 |
-
subproc.start()
|
| 173 |
-
pid_data["process_pids"].append(subproc.pid)
|
| 174 |
-
|
| 175 |
-
json.dump(pid_data, pid_file, indent=4)
|
| 176 |
-
|
| 177 |
-
for i in range(n_gpus):
|
| 178 |
-
children[i].join()
|
| 179 |
-
|
| 180 |
-
def load_from_json(file_path):
|
| 181 |
-
if os.path.exists(file_path):
|
| 182 |
-
with open(file_path, "r") as f:
|
| 183 |
-
data = json.load(f)
|
| 184 |
-
return (data.get("loss_disc_history", []), data.get("smoothed_loss_disc_history", []), data.get("loss_gen_history", []), data.get("smoothed_loss_gen_history", []))
|
| 185 |
-
return [], [], [], []
|
| 186 |
-
|
| 187 |
-
def continue_overtrain_detector(training_file_path):
|
| 188 |
-
# without a global declaration the loaded histories would only rebind local names inside this helper
global loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history
if overtraining_detector and os.path.exists(training_file_path): (loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history) = load_from_json(training_file_path)
|
| 189 |
-
|
| 190 |
-
if cleanup:
|
| 191 |
-
for root, dirs, files in os.walk(experiment_dir, topdown=False):
|
| 192 |
-
for name in files:
|
| 193 |
-
file_path = os.path.join(root, name)
|
| 194 |
-
_, file_extension = os.path.splitext(name)
|
| 195 |
-
if (file_extension == ".0" or (name.startswith("D_") and file_extension == ".pth") or (name.startswith("G_") and file_extension == ".pth") or (file_extension == ".index")): os.remove(file_path)
|
| 196 |
-
|
| 197 |
-
for name in dirs:
|
| 198 |
-
if name == "eval":
|
| 199 |
-
folder_path = os.path.join(root, name)
|
| 200 |
-
for item in os.listdir(folder_path):
|
| 201 |
-
item_path = os.path.join(folder_path, item)
|
| 202 |
-
if os.path.isfile(item_path): os.remove(item_path)
|
| 203 |
-
os.rmdir(folder_path)
|
| 204 |
-
|
| 205 |
-
continue_overtrain_detector(training_file_path)
|
| 206 |
-
start()
|
| 207 |
-
except Exception as e:
|
| 208 |
-
logger.error(f"{translations['training_error']} {e}")
|
| 209 |
-
import traceback
|
| 210 |
-
logger.debug(traceback.format_exc())
|
| 211 |
-
|
| 212 |
-
def plot_spectrogram_to_numpy(spectrogram):
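# Renders a spectrogram to an RGB numpy image (HWC) for TensorBoard logging via summarize();
# the Agg backend is switched on lazily the first time the function runs.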
|
| 213 |
-
global MATPLOTLIB_FLAG
|
| 214 |
-
|
| 215 |
-
if not MATPLOTLIB_FLAG:
|
| 216 |
-
plt.switch_backend("Agg")
|
| 217 |
-
MATPLOTLIB_FLAG = True
|
| 218 |
-
|
| 219 |
-
fig, ax = plt.subplots(figsize=(10, 2))
|
| 220 |
-
plt.colorbar(ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none"), ax=ax)
|
| 221 |
-
plt.xlabel("Frames")
|
| 222 |
-
plt.ylabel("Channels")
|
| 223 |
-
plt.tight_layout()
|
| 224 |
-
fig.canvas.draw()
|
| 225 |
-
plt.close(fig)
|
| 226 |
-
|
| 227 |
-
try:
|
| 228 |
-
data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3]
|
| 229 |
-
except:
|
| 230 |
-
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
| 231 |
-
|
| 232 |
-
return data
|
| 233 |
-
|
| 234 |
-
def verify_checkpoint_shapes(checkpoint_path, model):
|
| 235 |
-
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
| 236 |
-
checkpoint_state_dict = checkpoint["model"]
|
| 237 |
-
try:
|
| 238 |
-
model_state_dict = model.module.load_state_dict(checkpoint_state_dict) if hasattr(model, "module") else model.load_state_dict(checkpoint_state_dict)
|
| 239 |
-
except RuntimeError:
|
| 240 |
-
logger.warning(translations["checkpointing_err"])
|
| 241 |
-
sys.exit(1)
|
| 242 |
-
else: del checkpoint, checkpoint_state_dict, model_state_dict
|
| 243 |
-
|
| 244 |
-
def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sample_rate=22050):
|
| 245 |
-
for k, v in scalars.items():
|
| 246 |
-
writer.add_scalar(k, v, global_step)
|
| 247 |
-
|
| 248 |
-
for k, v in histograms.items():
|
| 249 |
-
writer.add_histogram(k, v, global_step)
|
| 250 |
-
|
| 251 |
-
for k, v in images.items():
|
| 252 |
-
writer.add_image(k, v, global_step, dataformats="HWC")
|
| 253 |
-
|
| 254 |
-
for k, v in audios.items():
|
| 255 |
-
writer.add_audio(k, v, global_step, audio_sample_rate)
|
| 256 |
-
|
| 257 |
-
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
|
| 258 |
-
assert os.path.isfile(checkpoint_path), translations["not_found_checkpoint"].format(checkpoint_path=checkpoint_path)
|
| 259 |
-
checkpoint_dict = replace_keys_in_dict(replace_keys_in_dict(torch.load(checkpoint_path, map_location="cpu"), ".weight_v", ".parametrizations.weight.original1"), ".weight_g", ".parametrizations.weight.original0")
|
| 260 |
-
new_state_dict = {k: checkpoint_dict["model"].get(k, v) for k, v in (model.module.state_dict() if hasattr(model, "module") else model.state_dict()).items()}
|
| 261 |
-
|
| 262 |
-
if hasattr(model, "module"): model.module.load_state_dict(new_state_dict, strict=False)
|
| 263 |
-
else: model.load_state_dict(new_state_dict, strict=False)
|
| 264 |
-
|
| 265 |
-
if optimizer and load_opt == 1: optimizer.load_state_dict(checkpoint_dict.get("optimizer", {}))
|
| 266 |
-
logger.debug(translations["save_checkpoint"].format(checkpoint_path=checkpoint_path, checkpoint_dict=checkpoint_dict['iteration']))
|
| 267 |
-
return (model, optimizer, checkpoint_dict.get("learning_rate", 0), checkpoint_dict["iteration"])
|
| 268 |
-
|
| 269 |
-
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
|
| 270 |
-
state_dict = (model.module.state_dict() if hasattr(model, "module") else model.state_dict())
|
| 271 |
-
torch.save(replace_keys_in_dict(replace_keys_in_dict({"model": state_dict, "iteration": iteration, "optimizer": optimizer.state_dict(), "learning_rate": learning_rate}, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), checkpoint_path)
|
| 272 |
-
logger.info(translations["save_model"].format(checkpoint_path=checkpoint_path, iteration=iteration))
|
| 273 |
-
|
| 274 |
-
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
|
| 275 |
-
checkpoints = sorted(glob.glob(os.path.join(dir_path, regex)), key=lambda f: int("".join(filter(str.isdigit, f))))
|
| 276 |
-
return checkpoints[-1] if checkpoints else None
|
| 277 |
-
|
| 278 |
-
def load_wav_to_torch(full_path):
|
| 279 |
-
data, sample_rate = sf.read(full_path, dtype=np.float32)
|
| 280 |
-
return torch.FloatTensor(data.astype(np.float32)), sample_rate
|
| 281 |
-
|
| 282 |
-
def load_filepaths_and_text(filename, split="|"):
|
| 283 |
-
with open(filename, encoding="utf-8") as f:
|
| 284 |
-
return [line.strip().split(split) for line in f]
|
| 285 |
-
|
| 286 |
-
def feature_loss(fmap_r, fmap_g):
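# Feature-matching loss: mean absolute difference between discriminator feature maps of real
# (fmap_r) and generated (fmap_g) audio, summed over all layers and scaled by 2.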
|
| 287 |
-
loss = 0
|
| 288 |
-
for dr, dg in zip(fmap_r, fmap_g):
|
| 289 |
-
for rl, gl in zip(dr, dg):
|
| 290 |
-
loss += torch.mean(torch.abs(rl.float().detach() - gl.float()))
|
| 291 |
-
return loss * 2
|
| 292 |
-
|
| 293 |
-
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
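# Least-squares GAN discriminator loss: real outputs are pushed towards 1 and generated outputs
# towards 0; the per-discriminator components are also returned for logging.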
|
| 294 |
-
loss = 0
|
| 295 |
-
r_losses, g_losses = [], []
|
| 296 |
-
|
| 297 |
-
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
| 298 |
-
dr = dr.float()
|
| 299 |
-
dg = dg.float()
|
| 300 |
-
r_loss = torch.mean((1 - dr) ** 2)
|
| 301 |
-
g_loss = torch.mean(dg**2)
|
| 302 |
-
loss += r_loss + g_loss
|
| 303 |
-
r_losses.append(r_loss.item())
|
| 304 |
-
g_losses.append(g_loss.item())
|
| 305 |
-
return loss, r_losses, g_losses
|
| 306 |
-
|
| 307 |
-
def generator_loss(disc_outputs):
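# Least-squares GAN generator loss: generated outputs are pushed towards 1.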
|
| 308 |
-
loss = 0
|
| 309 |
-
gen_losses = []
|
| 310 |
-
|
| 311 |
-
for dg in disc_outputs:
|
| 312 |
-
l = torch.mean((1 - dg.float()) ** 2)
|
| 313 |
-
gen_losses.append(l)
|
| 314 |
-
loss += l
|
| 315 |
-
return loss, gen_losses
|
| 316 |
-
|
| 317 |
-
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
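# KL divergence between the posterior q (z_p, logs_q) and the prior p (m_p, logs_p), both diagonal
# Gaussians: log(sigma_p) - log(sigma_q) - 0.5 + 0.5 * (z - mu_p)^2 / sigma_p^2, averaged over the
# unmasked time steps.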
|
| 318 |
-
z_p = z_p.float()
|
| 319 |
-
logs_q = logs_q.float()
|
| 320 |
-
m_p = m_p.float()
|
| 321 |
-
logs_p = logs_p.float()
|
| 322 |
-
z_mask = z_mask.float()
|
| 323 |
-
kl = logs_p - logs_q - 0.5
|
| 324 |
-
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
|
| 325 |
-
return torch.sum(kl * z_mask) / torch.sum(z_mask)
|
| 326 |
-
|
| 327 |
-
class TextAudioLoaderMultiNSFsid(tdata.Dataset):
|
| 328 |
-
def __init__(self, hparams):
|
| 329 |
-
self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
|
| 330 |
-
self.max_wav_value = hparams.max_wav_value
|
| 331 |
-
self.sample_rate = hparams.sample_rate
|
| 332 |
-
self.filter_length = hparams.filter_length
|
| 333 |
-
self.hop_length = hparams.hop_length
|
| 334 |
-
self.win_length = hparams.win_length
|
| 335 |
-
self.sample_rate = hparams.sample_rate
|
| 336 |
-
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
| 337 |
-
self.max_text_len = getattr(hparams, "max_text_len", 5000)
|
| 338 |
-
self._filter()
|
| 339 |
-
|
| 340 |
-
def _filter(self):
|
| 341 |
-
audiopaths_and_text_new, lengths = [], []
|
| 342 |
-
for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
|
| 343 |
-
if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
|
| 344 |
-
audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
|
| 345 |
-
lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
|
| 346 |
-
|
| 347 |
-
self.audiopaths_and_text = audiopaths_and_text_new
|
| 348 |
-
self.lengths = lengths
|
| 349 |
-
|
| 350 |
-
def get_sid(self, sid):
|
| 351 |
-
try:
|
| 352 |
-
sid = torch.LongTensor([int(sid)])
|
| 353 |
-
except ValueError as e:
|
| 354 |
-
logger.error(translations["sid_error"].format(sid=sid, e=e))
|
| 355 |
-
sid = torch.LongTensor([0])
|
| 356 |
-
return sid
|
| 357 |
-
|
| 358 |
-
def get_audio_text_pair(self, audiopath_and_text):
|
| 359 |
-
phone, pitch, pitchf = self.get_labels(audiopath_and_text[1], audiopath_and_text[2], audiopath_and_text[3])
|
| 360 |
-
spec, wav = self.get_audio(audiopath_and_text[0])
|
| 361 |
-
dv = self.get_sid(audiopath_and_text[4])
|
| 362 |
-
len_phone = phone.size()[0]
|
| 363 |
-
len_spec = spec.size()[-1]
|
| 364 |
-
|
| 365 |
-
if len_phone != len_spec:
|
| 366 |
-
len_min = min(len_phone, len_spec)
|
| 367 |
-
len_wav = len_min * self.hop_length
|
| 368 |
-
spec, wav, phone = spec[:, :len_min], wav[:, :len_wav], phone[:len_min, :]
|
| 369 |
-
pitch, pitchf = pitch[:len_min], pitchf[:len_min]
|
| 370 |
-
return (spec, wav, phone, pitch, pitchf, dv)
|
| 371 |
-
|
| 372 |
-
def get_labels(self, phone, pitch, pitchf):
|
| 373 |
-
phone = np.repeat(np.load(phone), 2, axis=0)
|
| 374 |
-
n_num = min(phone.shape[0], 900)
|
| 375 |
-
return torch.FloatTensor(phone[:n_num, :]), torch.LongTensor(np.load(pitch)[:n_num]), torch.FloatTensor(np.load(pitchf)[:n_num])
|
| 376 |
-
|
| 377 |
-
def get_audio(self, filename):
|
| 378 |
-
audio, sample_rate = load_wav_to_torch(filename)
|
| 379 |
-
if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
|
| 380 |
-
audio_norm = audio.unsqueeze(0)
|
| 381 |
-
spec_filename = filename.replace(".wav", ".spec.pt")
|
| 382 |
-
|
| 383 |
-
if os.path.exists(spec_filename):
|
| 384 |
-
try:
|
| 385 |
-
spec = torch.load(spec_filename)
|
| 386 |
-
except Exception as e:
|
| 387 |
-
logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
|
| 388 |
-
spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
|
| 389 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
| 390 |
-
else:
|
| 391 |
-
spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
|
| 392 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
| 393 |
-
return spec, audio_norm
|
| 394 |
-
|
| 395 |
-
def __getitem__(self, index):
|
| 396 |
-
return self.get_audio_text_pair(self.audiopaths_and_text[index])
|
| 397 |
-
|
| 398 |
-
def __len__(self):
|
| 399 |
-
return len(self.audiopaths_and_text)
|
| 400 |
-
|
| 401 |
-
class TextAudioCollateMultiNSFsid:
|
| 402 |
-
def __init__(self, return_ids=False):
|
| 403 |
-
self.return_ids = return_ids
|
| 404 |
-
|
| 405 |
-
def __call__(self, batch):
|
| 406 |
-
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
|
| 407 |
-
spec_lengths, wave_lengths = torch.LongTensor(len(batch)), torch.LongTensor(len(batch))
|
| 408 |
-
spec_padded, wave_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max([x[0].size(1) for x in batch])), torch.FloatTensor(len(batch), 1, max([x[1].size(1) for x in batch]))
|
| 409 |
-
spec_padded.zero_()
|
| 410 |
-
wave_padded.zero_()
|
| 411 |
-
max_phone_len = max([x[2].size(0) for x in batch])
|
| 412 |
-
phone_lengths, phone_padded = torch.LongTensor(len(batch)), torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
|
| 413 |
-
pitch_padded, pitchf_padded = torch.LongTensor(len(batch), max_phone_len), torch.FloatTensor(len(batch), max_phone_len)
|
| 414 |
-
phone_padded.zero_()
|
| 415 |
-
pitch_padded.zero_()
|
| 416 |
-
pitchf_padded.zero_()
|
| 417 |
-
sid = torch.LongTensor(len(batch))
|
| 418 |
-
|
| 419 |
-
for i in range(len(ids_sorted_decreasing)):
|
| 420 |
-
row = batch[ids_sorted_decreasing[i]]
|
| 421 |
-
spec = row[0]
|
| 422 |
-
spec_padded[i, :, : spec.size(1)] = spec
|
| 423 |
-
spec_lengths[i] = spec.size(1)
|
| 424 |
-
wave = row[1]
|
| 425 |
-
wave_padded[i, :, : wave.size(1)] = wave
|
| 426 |
-
wave_lengths[i] = wave.size(1)
|
| 427 |
-
phone = row[2]
|
| 428 |
-
phone_padded[i, : phone.size(0), :] = phone
|
| 429 |
-
phone_lengths[i] = phone.size(0)
|
| 430 |
-
pitch = row[3]
|
| 431 |
-
pitch_padded[i, : pitch.size(0)] = pitch
|
| 432 |
-
pitchf = row[4]
|
| 433 |
-
pitchf_padded[i, : pitchf.size(0)] = pitchf
|
| 434 |
-
sid[i] = row[5]
|
| 435 |
-
return (phone_padded, phone_lengths, pitch_padded, pitchf_padded, spec_padded, spec_lengths, wave_padded, wave_lengths, sid)
|
| 436 |
-
|
| 437 |
-
class TextAudioLoader(tdata.Dataset):
|
| 438 |
-
def __init__(self, hparams):
|
| 439 |
-
self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
|
| 440 |
-
self.max_wav_value = hparams.max_wav_value
|
| 441 |
-
self.sample_rate = hparams.sample_rate
|
| 442 |
-
self.filter_length = hparams.filter_length
|
| 443 |
-
self.hop_length = hparams.hop_length
|
| 444 |
-
self.win_length = hparams.win_length
|
| 445 |
-
self.sample_rate = hparams.sample_rate
|
| 446 |
-
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
| 447 |
-
self.max_text_len = getattr(hparams, "max_text_len", 5000)
|
| 448 |
-
self._filter()
|
| 449 |
-
|
| 450 |
-
def _filter(self):
|
| 451 |
-
audiopaths_and_text_new, lengths = [], []
|
| 452 |
-
for entry in self.audiopaths_and_text:
|
| 453 |
-
if len(entry) >= 3:
|
| 454 |
-
audiopath, text, dv = entry[:3]
|
| 455 |
-
if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
|
| 456 |
-
audiopaths_and_text_new.append([audiopath, text, dv])
|
| 457 |
-
lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
|
| 458 |
-
|
| 459 |
-
self.audiopaths_and_text = audiopaths_and_text_new
|
| 460 |
-
self.lengths = lengths
|
| 461 |
-
|
| 462 |
-
def get_sid(self, sid):
|
| 463 |
-
try:
|
| 464 |
-
sid = torch.LongTensor([int(sid)])
|
| 465 |
-
except ValueError as e:
|
| 466 |
-
logger.error(translations["sid_error"].format(sid=sid, e=e))
|
| 467 |
-
sid = torch.LongTensor([0])
|
| 468 |
-
return sid
|
| 469 |
-
|
| 470 |
-
def get_audio_text_pair(self, audiopath_and_text):
|
| 471 |
-
phone = self.get_labels(audiopath_and_text[1])
|
| 472 |
-
spec, wav = self.get_audio(audiopath_and_text[0])
|
| 473 |
-
dv = self.get_sid(audiopath_and_text[2])
|
| 474 |
-
len_phone = phone.size()[0]
|
| 475 |
-
len_spec = spec.size()[-1]
|
| 476 |
-
|
| 477 |
-
if len_phone != len_spec:
|
| 478 |
-
len_min = min(len_phone, len_spec)
|
| 479 |
-
len_wav = len_min * self.hop_length
|
| 480 |
-
spec = spec[:, :len_min]
|
| 481 |
-
wav = wav[:, :len_wav]
|
| 482 |
-
phone = phone[:len_min, :]
|
| 483 |
-
return (spec, wav, phone, dv)
|
| 484 |
-
|
| 485 |
-
def get_labels(self, phone):
|
| 486 |
-
phone = np.repeat(np.load(phone), 2, axis=0)
|
| 487 |
-
return torch.FloatTensor(phone[:min(phone.shape[0], 900), :])
|
| 488 |
-
|
| 489 |
-
def get_audio(self, filename):
|
| 490 |
-
audio, sample_rate = load_wav_to_torch(filename)
|
| 491 |
-
if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
|
| 492 |
-
audio_norm = audio.unsqueeze(0)
|
| 493 |
-
spec_filename = filename.replace(".wav", ".spec.pt")
|
| 494 |
-
|
| 495 |
-
if os.path.exists(spec_filename):
|
| 496 |
-
try:
|
| 497 |
-
spec = torch.load(spec_filename)
|
| 498 |
-
except Exception as e:
|
| 499 |
-
logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
|
| 500 |
-
spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
|
| 501 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
| 502 |
-
else:
|
| 503 |
-
spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
|
| 504 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
| 505 |
-
return spec, audio_norm
|
| 506 |
-
|
| 507 |
-
def __getitem__(self, index):
|
| 508 |
-
return self.get_audio_text_pair(self.audiopaths_and_text[index])
|
| 509 |
-
|
| 510 |
-
def __len__(self):
|
| 511 |
-
return len(self.audiopaths_and_text)
|
| 512 |
-
|
| 513 |
-
class TextAudioCollate:
|
| 514 |
-
def __init__(self, return_ids=False):
|
| 515 |
-
self.return_ids = return_ids
|
| 516 |
-
|
| 517 |
-
def __call__(self, batch):
|
| 518 |
-
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
|
| 519 |
-
spec_lengths, wave_lengths = torch.LongTensor(len(batch)), torch.LongTensor(len(batch))
|
| 520 |
-
spec_padded, wave_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max([x[0].size(1) for x in batch])), torch.FloatTensor(len(batch), 1, max([x[1].size(1) for x in batch]))
|
| 521 |
-
spec_padded.zero_()
|
| 522 |
-
wave_padded.zero_()
|
| 523 |
-
max_phone_len = max([x[2].size(0) for x in batch])
|
| 524 |
-
phone_lengths, phone_padded = torch.LongTensor(len(batch)), torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
|
| 525 |
-
phone_padded.zero_()
|
| 526 |
-
sid = torch.LongTensor(len(batch))
|
| 527 |
-
for i in range(len(ids_sorted_decreasing)):
|
| 528 |
-
row = batch[ids_sorted_decreasing[i]]
|
| 529 |
-
spec = row[0]
|
| 530 |
-
spec_padded[i, :, : spec.size(1)] = spec
|
| 531 |
-
spec_lengths[i] = spec.size(1)
|
| 532 |
-
wave = row[1]
|
| 533 |
-
wave_padded[i, :, : wave.size(1)] = wave
|
| 534 |
-
wave_lengths[i] = wave.size(1)
|
| 535 |
-
phone = row[2]
|
| 536 |
-
phone_padded[i, : phone.size(0), :] = phone
|
| 537 |
-
phone_lengths[i] = phone.size(0)
|
| 538 |
-
sid[i] = row[3]
|
| 539 |
-
return (phone_padded, phone_lengths, spec_padded, spec_lengths, wave_padded, wave_lengths, sid)
|
| 540 |
-
|
| 541 |
-
class DistributedBucketSampler(tdata.distributed.DistributedSampler):
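# Distributed sampler that groups utterances into length buckets (given by `boundaries`) so each
# batch holds similarly sized examples; every bucket is padded to a multiple of the global batch
# size before being split across replicas.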
|
| 542 |
-
def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
|
| 543 |
-
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
|
| 544 |
-
self.lengths = dataset.lengths
|
| 545 |
-
self.batch_size = batch_size
|
| 546 |
-
self.boundaries = boundaries
|
| 547 |
-
self.buckets, self.num_samples_per_bucket = self._create_buckets()
|
| 548 |
-
self.total_size = sum(self.num_samples_per_bucket)
|
| 549 |
-
self.num_samples = self.total_size // self.num_replicas
|
| 550 |
-
|
| 551 |
-
def _create_buckets(self):
|
| 552 |
-
buckets = [[] for _ in range(len(self.boundaries) - 1)]
|
| 553 |
-
for i in range(len(self.lengths)):
|
| 554 |
-
idx_bucket = self._bisect(self.lengths[i])
|
| 555 |
-
if idx_bucket != -1: buckets[idx_bucket].append(i)
|
| 556 |
-
|
| 557 |
-
for i in range(len(buckets) - 1, -1, -1):
|
| 558 |
-
if len(buckets[i]) == 0:
|
| 559 |
-
buckets.pop(i)
|
| 560 |
-
self.boundaries.pop(i + 1)
|
| 561 |
-
|
| 562 |
-
num_samples_per_bucket = []
|
| 563 |
-
for i in range(len(buckets)):
|
| 564 |
-
len_bucket = len(buckets[i])
|
| 565 |
-
total_batch_size = self.num_replicas * self.batch_size
|
| 566 |
-
num_samples_per_bucket.append(len_bucket + ((total_batch_size - (len_bucket % total_batch_size)) % total_batch_size))
|
| 567 |
-
return buckets, num_samples_per_bucket
|
| 568 |
-
|
| 569 |
-
def __iter__(self):
|
| 570 |
-
g = torch.Generator()
|
| 571 |
-
g.manual_seed(self.epoch)
|
| 572 |
-
indices, batches = [], []
|
| 573 |
-
if self.shuffle:
|
| 574 |
-
for bucket in self.buckets:
|
| 575 |
-
indices.append(torch.randperm(len(bucket), generator=g).tolist())
|
| 576 |
-
else:
|
| 577 |
-
for bucket in self.buckets:
|
| 578 |
-
indices.append(list(range(len(bucket))))
|
| 579 |
-
|
| 580 |
-
for i in range(len(self.buckets)):
|
| 581 |
-
bucket = self.buckets[i]
|
| 582 |
-
len_bucket = len(bucket)
|
| 583 |
-
ids_bucket = indices[i]
|
| 584 |
-
rem = self.num_samples_per_bucket[i] - len_bucket
|
| 585 |
-
ids_bucket = (ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)])[self.rank :: self.num_replicas]
|
| 586 |
-
|
| 587 |
-
for j in range(len(ids_bucket) // self.batch_size):
|
| 588 |
-
batches.append([bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]])
|
| 589 |
-
|
| 590 |
-
if self.shuffle: batches = [batches[i] for i in torch.randperm(len(batches), generator=g).tolist()]
|
| 591 |
-
self.batches = batches
|
| 592 |
-
assert len(self.batches) * self.batch_size == self.num_samples
|
| 593 |
-
return iter(self.batches)
|
| 594 |
-
|
| 595 |
-
def _bisect(self, x, lo=0, hi=None):
|
| 596 |
-
if hi is None: hi = len(self.boundaries) - 1
|
| 597 |
-
|
| 598 |
-
if hi > lo:
|
| 599 |
-
mid = (hi + lo) // 2
|
| 600 |
-
if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: return mid
|
| 601 |
-
elif x <= self.boundaries[mid]: return self._bisect(x, lo, mid)
|
| 602 |
-
else: return self._bisect(x, mid + 1, hi)
|
| 603 |
-
else: return -1
|
| 604 |
-
|
| 605 |
-
def __len__(self):
|
| 606 |
-
return self.num_samples // self.batch_size
|
| 607 |
-
|
| 608 |
-
class MultiPeriodDiscriminator(torch.nn.Module):
|
| 609 |
-
def __init__(self, version, use_spectral_norm=False, checkpointing=False):
|
| 610 |
-
super(MultiPeriodDiscriminator, self).__init__()
|
| 611 |
-
self.checkpointing = checkpointing
|
| 612 |
-
periods = ([2, 3, 5, 7, 11, 17] if version == "v1" else [2, 3, 5, 7, 11, 17, 23, 37])
|
| 613 |
-
self.discriminators = torch.nn.ModuleList([DiscriminatorS(use_spectral_norm=use_spectral_norm, checkpointing=checkpointing)] + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm, checkpointing=checkpointing) for p in periods])
|
| 614 |
-
|
| 615 |
-
def forward(self, y, y_hat):
|
| 616 |
-
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
|
| 617 |
-
for d in self.discriminators:
|
| 618 |
-
if self.training and self.checkpointing:
|
| 619 |
-
def forward_discriminator(d, y, y_hat):
|
| 620 |
-
y_d_r, fmap_r = d(y)
|
| 621 |
-
y_d_g, fmap_g = d(y_hat)
|
| 622 |
-
return y_d_r, fmap_r, y_d_g, fmap_g
|
| 623 |
-
y_d_r, fmap_r, y_d_g, fmap_g = checkpoint(forward_discriminator, d, y, y_hat, use_reentrant=False)
|
| 624 |
-
else:
|
| 625 |
-
y_d_r, fmap_r = d(y)
|
| 626 |
-
y_d_g, fmap_g = d(y_hat)
|
| 627 |
-
|
| 628 |
-
y_d_rs.append(y_d_r); fmap_rs.append(fmap_r)
|
| 629 |
-
y_d_gs.append(y_d_g); fmap_gs.append(fmap_g)
|
| 630 |
-
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 631 |
-
|
| 632 |
-
class DiscriminatorS(torch.nn.Module):
|
| 633 |
-
def __init__(self, use_spectral_norm=False, checkpointing=False):
|
| 634 |
-
super(DiscriminatorS, self).__init__()
|
| 635 |
-
self.checkpointing = checkpointing
|
| 636 |
-
norm_f = spectral_norm if use_spectral_norm else weight_norm
|
| 637 |
-
self.convs = torch.nn.ModuleList([norm_f(torch.nn.Conv1d(1, 16, 15, 1, padding=7)), norm_f(torch.nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), norm_f(torch.nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), norm_f(torch.nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 5, 1, padding=2))])
|
| 638 |
-
self.conv_post = norm_f(torch.nn.Conv1d(1024, 1, 3, 1, padding=1))
|
| 639 |
-
self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
|
| 640 |
-
|
| 641 |
-
def forward(self, x):
|
| 642 |
-
fmap = []
|
| 643 |
-
for conv in self.convs:
|
| 644 |
-
x = checkpoint(self.lrelu, checkpoint(conv, x, use_reentrant = False), use_reentrant = False) if self.training and self.checkpointing else self.lrelu(conv(x))
|
| 645 |
-
fmap.append(x)
|
| 646 |
-
|
| 647 |
-
x = self.conv_post(x)
|
| 648 |
-
fmap.append(x)
|
| 649 |
-
return torch.flatten(x, 1, -1), fmap
|
| 650 |
-
|
| 651 |
-
class DiscriminatorP(torch.nn.Module):
|
| 652 |
-
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, checkpointing=False):
|
| 653 |
-
super(DiscriminatorP, self).__init__()
|
| 654 |
-
self.period = period
|
| 655 |
-
self.checkpointing = checkpointing
|
| 656 |
-
norm_f = spectral_norm if use_spectral_norm else weight_norm
|
| 657 |
-
self.convs = torch.nn.ModuleList([norm_f(torch.nn.Conv2d(in_ch, out_ch, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))) for in_ch, out_ch in zip([1, 32, 128, 512, 1024], [32, 128, 512, 1024, 1024])])
|
| 658 |
-
self.conv_post = norm_f(torch.nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
| 659 |
-
self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
|
| 660 |
-
|
| 661 |
-
def forward(self, x):
|
| 662 |
-
fmap = []
|
| 663 |
-
b, c, t = x.shape
|
| 664 |
-
|
| 665 |
-
if t % self.period != 0: x = F.pad(x, (0, (self.period - (t % self.period))), "reflect")
|
| 666 |
-
x = x.view(b, c, -1, self.period)
|
| 667 |
-
|
| 668 |
-
for conv in self.convs:
|
| 669 |
-
x = checkpoint(self.lrelu, checkpoint(conv, x, use_reentrant = False), use_reentrant = False) if self.training and self.checkpointing else self.lrelu(conv(x))
|
| 670 |
-
fmap.append(x)
|
| 671 |
-
|
| 672 |
-
x = self.conv_post(x)
|
| 673 |
-
fmap.append(x)
|
| 674 |
-
return torch.flatten(x, 1, -1), fmap
|
| 675 |
-
|
| 676 |
-
class EpochRecorder:
|
| 677 |
-
def __init__(self):
|
| 678 |
-
self.last_time = ttime()
|
| 679 |
-
|
| 680 |
-
def record(self):
|
| 681 |
-
now_time = ttime()
|
| 682 |
-
elapsed_time = now_time - self.last_time
|
| 683 |
-
self.last_time = now_time
|
| 684 |
-
return translations["time_or_speed_training"].format(current_time=datetime.datetime.now().strftime("%H:%M:%S"), elapsed_time_str=str(datetime.timedelta(seconds=int(round(elapsed_time, 1)))))
|
| 685 |
-
|
| 686 |
-
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
| 687 |
-
return torch.log(torch.clamp(x, min=clip_val) * C)
|
| 688 |
-
|
| 689 |
-
def dynamic_range_decompression_torch(x, C=1):
|
| 690 |
-
return torch.exp(x) / C
|
| 691 |
-
|
| 692 |
-
def spectral_normalize_torch(magnitudes):
|
| 693 |
-
return dynamic_range_compression_torch(magnitudes)
|
| 694 |
-
|
| 695 |
-
def spectral_de_normalize_torch(magnitudes):
|
| 696 |
-
return dynamic_range_decompression_torch(magnitudes)
|
| 697 |
-
|
| 698 |
-
mel_basis, hann_window = {}, {}
|
| 699 |
-
|
| 700 |
-
def spectrogram_torch(y, n_fft, hop_size, win_size, center=False):
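# Linear-magnitude spectrogram via STFT; the Hann window is cached per (size, dtype, device) and
# the signal is reflection-padded so frames line up with the feature hop grid.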
|
| 701 |
-
global hann_window
|
| 702 |
-
|
| 703 |
-
wnsize_dtype_device = str(win_size) + "_" + str(y.dtype) + "_" + str(y.device)
|
| 704 |
-
if wnsize_dtype_device not in hann_window: hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
| 705 |
-
spec = torch.stft(F.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect").squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
|
| 706 |
-
return torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
|
| 707 |
-
|
| 708 |
-
def spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax):
|
| 709 |
-
global mel_basis
|
| 710 |
-
|
| 711 |
-
fmax_dtype_device = str(fmax) + "_" + str(spec.dtype) + "_" + str(spec.device)
|
| 712 |
-
if fmax_dtype_device not in mel_basis: mel_basis[fmax_dtype_device] = torch.from_numpy(librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)).to(dtype=spec.dtype, device=spec.device)
|
| 713 |
-
return spectral_normalize_torch(torch.matmul(mel_basis[fmax_dtype_device], spec))
|
| 714 |
-
|
| 715 |
-
def mel_spectrogram_torch(y, n_fft, num_mels, sample_rate, hop_size, win_size, fmin, fmax, center=False):
|
| 716 |
-
return spec_to_mel_torch(spectrogram_torch(y, n_fft, hop_size, win_size, center), n_fft, num_mels, sample_rate, fmin, fmax)
|
| 717 |
-
|
| 718 |
-
def replace_keys_in_dict(d, old_key_part, new_key_part):
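# Recursively renames state-dict keys; used to convert between the legacy weight-norm names
# (.weight_g / .weight_v) and the newer torch parametrization names when loading and saving.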
|
| 719 |
-
updated_dict = OrderedDict() if isinstance(d, OrderedDict) else {}
|
| 720 |
-
for key, value in d.items():
|
| 721 |
-
updated_dict[(key.replace(old_key_part, new_key_part) if isinstance(key, str) else key)] = (replace_keys_in_dict(value, old_key_part, new_key_part) if isinstance(value, dict) else value)
|
| 722 |
-
return updated_dict
|
| 723 |
-
|
| 724 |
-
def extract_model(ckpt, sr, pitch_guidance, name, model_path, epoch, step, version, hps, model_author, vocoder):
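# Exports an inference-only weight file: drops the enc_q (posterior encoder) weights, casts the
# rest to half precision and stores them together with the model config and training metadata.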
|
| 725 |
-
try:
|
| 726 |
-
logger.info(translations["savemodel"].format(model_dir=model_path, epoch=epoch, step=step))
|
| 727 |
-
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
| 728 |
-
|
| 729 |
-
opt = OrderedDict(weight={key: value.half() for key, value in ckpt.items() if "enc_q" not in key})
|
| 730 |
-
opt["config"] = [hps.data.filter_length // 2 + 1, 32, hps.model.inter_channels, hps.model.hidden_channels, hps.model.filter_channels, hps.model.n_heads, hps.model.n_layers, hps.model.kernel_size, hps.model.p_dropout, hps.model.resblock, hps.model.resblock_kernel_sizes, hps.model.resblock_dilation_sizes, hps.model.upsample_rates, hps.model.upsample_initial_channel, hps.model.upsample_kernel_sizes, hps.model.spk_embed_dim, hps.model.gin_channels, hps.data.sample_rate]
|
| 731 |
-
opt["epoch"] = f"{epoch}epoch"
|
| 732 |
-
opt["step"] = step
|
| 733 |
-
opt["sr"] = sr
|
| 734 |
-
opt["f0"] = int(pitch_guidance)
|
| 735 |
-
opt["version"] = version
|
| 736 |
-
opt["creation_date"] = datetime.datetime.now().isoformat()
|
| 737 |
-
opt["model_hash"] = hashlib.sha256(f"{str(ckpt)} {epoch} {step} {datetime.datetime.now().isoformat()}".encode()).hexdigest()
|
| 738 |
-
opt["model_name"] = name
|
| 739 |
-
opt["author"] = model_author
|
| 740 |
-
opt["vocoder"] = vocoder
|
| 741 |
-
|
| 742 |
-
torch.save(replace_keys_in_dict(replace_keys_in_dict(opt, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), model_path)
|
| 743 |
-
except Exception as e:
|
| 744 |
-
logger.error(f"{translations['extract_model_error']}: {e}")
|
| 745 |
-
|
| 746 |
-
def run(rank, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, custom_total_epoch, custom_save_every_weights, config, device, device_id, model_author, vocoder, checkpointing):
|
| 747 |
-
global global_step
|
| 748 |
-
|
| 749 |
-
if rank == 0: writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval"))
|
| 750 |
-
else: writer_eval = None
|
| 751 |
-
|
| 752 |
-
try:
|
| 753 |
-
dist.init_process_group(backend=("gloo" if sys.platform == "win32" or device.type != "cuda" else "nccl"), init_method="env://", world_size=n_gpus, rank=rank)
|
| 754 |
-
except:
|
| 755 |
-
dist.init_process_group(backend=("gloo" if sys.platform == "win32" or device.type != "cuda" else "nccl"), init_method="env://?use_libuv=False", world_size=n_gpus, rank=rank)
|
| 756 |
-
|
| 757 |
-
torch.manual_seed(config.train.seed)
|
| 758 |
-
if torch.cuda.is_available(): torch.cuda.set_device(device_id)
|
| 759 |
-
|
| 760 |
-
train_dataset = TextAudioLoaderMultiNSFsid(config.data) if pitch_guidance else TextAudioLoader(config.data)
|
| 761 |
-
train_loader = tdata.DataLoader(train_dataset, num_workers=4, shuffle=False, pin_memory=True, collate_fn=TextAudioCollateMultiNSFsid() if pitch_guidance else TextAudioCollate(), batch_sampler=DistributedBucketSampler(train_dataset, batch_size * n_gpus, [100, 200, 300, 400, 500, 600, 700, 800, 900], num_replicas=n_gpus, rank=rank, shuffle=True), persistent_workers=True, prefetch_factor=8)
|
| 762 |
-
|
| 763 |
-
net_g, net_d = Synthesizer(config.data.filter_length // 2 + 1, config.train.segment_size // config.data.hop_length, **config.model, use_f0=pitch_guidance, sr=sample_rate, vocoder=vocoder, checkpointing=checkpointing), MultiPeriodDiscriminator(version, config.model.use_spectral_norm, checkpointing=checkpointing)
|
| 764 |
-
net_g, net_d = (net_g.cuda(device_id), net_d.cuda(device_id)) if torch.cuda.is_available() else (net_g.to(device), net_d.to(device))
|
| 765 |
-
|
| 766 |
-
optim_g, optim_d = torch.optim.AdamW(net_g.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps), torch.optim.AdamW(net_d.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps)
|
| 767 |
-
net_g, net_d = (DDP(net_g, device_ids=[device_id]), DDP(net_d, device_ids=[device_id])) if torch.cuda.is_available() else (DDP(net_g), DDP(net_d))
|
| 768 |
-
|
| 769 |
-
try:
|
| 770 |
-
logger.info(translations["start_training"])
|
| 771 |
-
_, _, _, epoch_str = load_checkpoint((os.path.join(experiment_dir, "D_latest.pth") if save_only_latest else latest_checkpoint_path(experiment_dir, "D_*.pth")), net_d, optim_d)
|
| 772 |
-
_, _, _, epoch_str = load_checkpoint((os.path.join(experiment_dir, "G_latest.pth") if save_only_latest else latest_checkpoint_path(experiment_dir, "G_*.pth")), net_g, optim_g)
|
| 773 |
-
epoch_str += 1
|
| 774 |
-
global_step = (epoch_str - 1) * len(train_loader)
|
| 775 |
-
except:
|
| 776 |
-
epoch_str, global_step = 1, 0
|
| 777 |
-
|
| 778 |
-
if pretrainG != "" and pretrainG != "None":
|
| 779 |
-
if rank == 0:
|
| 780 |
-
verify_checkpoint_shapes(pretrainG, net_g)
|
| 781 |
-
logger.info(translations["import_pretrain"].format(dg="G", pretrain=pretrainG))
|
| 782 |
-
|
| 783 |
-
if hasattr(net_g, "module"): net_g.module.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
|
| 784 |
-
else: net_g.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
|
| 785 |
-
else: logger.warning(translations["not_using_pretrain"].format(dg="G"))
|
| 786 |
-
|
| 787 |
-
if pretrainD != "" and pretrainD != "None":
|
| 788 |
-
if rank == 0:
|
| 789 |
-
verify_checkpoint_shapes(pretrainD, net_d)
|
| 790 |
-
logger.info(translations["import_pretrain"].format(dg="D", pretrain=pretrainD))
|
| 791 |
-
|
| 792 |
-
if hasattr(net_d, "module"): net_d.module.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
|
| 793 |
-
else: net_d.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
|
| 794 |
-
else: logger.warning(translations["not_using_pretrain"].format(dg="D"))
|
| 795 |
-
|
| 796 |
-
scheduler_g, scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.train.lr_decay, last_epoch=epoch_str - 2), torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2)
|
| 797 |
-
optim_d.step(); optim_g.step()
|
| 798 |
-
|
| 799 |
-
scaler = GradScaler(enabled=main_config.is_half and device.type == "cuda")
|
| 800 |
-
cache = []
|
| 801 |
-
|
| 802 |
-
for info in train_loader:
|
| 803 |
-
phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
|
| 804 |
-
reference = (phone.cuda(device_id, non_blocking=True), phone_lengths.cuda(device_id, non_blocking=True), (pitch.cuda(device_id, non_blocking=True) if pitch_guidance else None), (pitchf.cuda(device_id, non_blocking=True) if pitch_guidance else None), sid.cuda(device_id, non_blocking=True)) if device.type == "cuda" else (phone.to(device), phone_lengths.to(device), (pitch.to(device) if pitch_guidance else None), (pitchf.to(device) if pitch_guidance else None), sid.to(device))
|
| 805 |
-
break
|
| 806 |
-
|
| 807 |
-
for epoch in range(epoch_str, total_epoch + 1):
|
| 808 |
-
train_and_evaluate(rank, epoch, config, [net_g, net_d], [optim_g, optim_d], scaler, train_loader, writer_eval, cache, custom_save_every_weights, custom_total_epoch, device, device_id, reference, model_author, vocoder)
|
| 809 |
-
scheduler_g.step(); scheduler_d.step()
|
| 810 |
-
|
| 811 |
-
def train_and_evaluate(rank, epoch, hps, nets, optims, scaler, train_loader, writer, cache, custom_save_every_weights, custom_total_epoch, device, device_id, reference, model_author, vocoder):
|
| 812 |
-
global global_step, lowest_value, loss_disc, consecutive_increases_gen, consecutive_increases_disc
|
| 813 |
-
|
| 814 |
-
if epoch == 1:
|
| 815 |
-
lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
|
| 816 |
-
last_loss_gen_all, consecutive_increases_gen, consecutive_increases_disc = 0.0, 0, 0
|
| 817 |
-
|
| 818 |
-
net_g, net_d = nets
|
| 819 |
-
optim_g, optim_d = optims
|
| 820 |
-
train_loader.batch_sampler.set_epoch(epoch)
|
| 821 |
-
|
| 822 |
-
net_g.train(); net_d.train()
|
| 823 |
-
|
| 824 |
-
if device.type == "cuda" and cache_data_in_gpu:
|
| 825 |
-
data_iterator = cache
|
| 826 |
-
if cache == []:
|
| 827 |
-
for batch_idx, info in enumerate(train_loader):
|
| 828 |
-
cache.append((batch_idx, [tensor.cuda(device_id, non_blocking=True) for tensor in info]))
|
| 829 |
-
else: shuffle(cache)
|
| 830 |
-
else: data_iterator = enumerate(train_loader)
|
| 831 |
-
|
| 832 |
-
epoch_recorder = EpochRecorder()
|
| 833 |
-
with tqdm(total=len(train_loader), leave=False) as pbar:
|
| 834 |
-
for batch_idx, info in data_iterator:
|
| 835 |
-
if device.type == "cuda" and not cache_data_in_gpu: info = [tensor.cuda(device_id, non_blocking=True) for tensor in info]
|
| 836 |
-
elif device.type != "cuda": info = [tensor.to(device) for tensor in info]
|
| 837 |
-
|
| 838 |
-
phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, _, sid = info
|
| 839 |
-
pitch = pitch if pitch_guidance else None
|
| 840 |
-
pitchf = pitchf if pitch_guidance else None
|
| 841 |
-
|
| 842 |
-
with autocast(enabled=main_config.is_half and device.type == "cuda"):
|
| 843 |
-
y_hat, ids_slice, _, z_mask, (_, z_p, m_p, logs_p, _, logs_q) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
|
| 844 |
-
mel = spec_to_mel_torch(spec, config.data.filter_length, config.data.n_mel_channels, config.data.sample_rate, config.data.mel_fmin, config.data.mel_fmax)
|
| 845 |
-
y_mel = slice_segments(mel, ids_slice, config.train.segment_size // config.data.hop_length, dim=3)
|
| 846 |
-
|
| 847 |
-
with autocast(enabled=main_config.is_half and device.type == "cuda"):
|
| 848 |
-
y_hat_mel = mel_spectrogram_torch(y_hat.float().squeeze(1), config.data.filter_length, config.data.n_mel_channels, config.data.sample_rate, config.data.hop_length, config.data.win_length, config.data.mel_fmin, config.data.mel_fmax)
|
| 849 |
-
|
| 850 |
-
wave = slice_segments(wave, ids_slice * config.data.hop_length, config.train.segment_size, dim=3)
|
| 851 |
-
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
|
| 852 |
-
|
| 853 |
-
with autocast(enabled=main_config.is_half and device.type == "cuda"):
|
| 854 |
-
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
|
| 855 |
-
|
| 856 |
-
optim_d.zero_grad()
|
| 857 |
-
scaler.scale(loss_disc).backward()
|
| 858 |
-
scaler.unscale_(optim_d)
|
| 859 |
-
grad_norm_d = clip_grad_value(net_d.parameters(), None)
|
| 860 |
-
scaler.step(optim_d)
|
| 861 |
-
|
| 862 |
-
with autocast(enabled=main_config.is_half and device.type == "cuda"):
|
| 863 |
-
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
|
| 864 |
-
with autocast(enabled=main_config.is_half and device.type == "cuda"):
|
| 865 |
-
loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel
|
| 866 |
-
loss_kl = (kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl)
|
| 867 |
-
loss_fm = feature_loss(fmap_r, fmap_g)
|
| 868 |
-
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
| 869 |
-
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
|
| 870 |
-
if loss_gen_all < lowest_value["value"]:
|
| 871 |
-
lowest_value["value"] = loss_gen_all
|
| 872 |
-
lowest_value["step"] = global_step
|
| 873 |
-
lowest_value["epoch"] = epoch
|
| 874 |
-
if epoch > lowest_value["epoch"]: logger.warning(translations["training_warning"])
|
| 875 |
-
|
| 876 |
-
optim_g.zero_grad()
|
| 877 |
-
scaler.scale(loss_gen_all).backward()
|
| 878 |
-
scaler.unscale_(optim_g)
|
| 879 |
-
grad_norm_g = clip_grad_value(net_g.parameters(), None)
|
| 880 |
-
scaler.step(optim_g)
|
| 881 |
-
scaler.update()
|
| 882 |
-
|
| 883 |
-
if rank == 0 and global_step % config.train.log_interval == 0:
|
| 884 |
-
if loss_mel > 75: loss_mel = 75
|
| 885 |
-
if loss_kl > 9: loss_kl = 9
|
| 886 |
-
|
| 887 |
-
scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc, "learning_rate": optim_g.param_groups[0]["lr"], "grad/norm_d": grad_norm_d, "grad/norm_g": grad_norm_g, "loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl}
|
| 888 |
-
scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)})
|
| 889 |
-
scalar_dict.update({f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)})
|
| 890 |
-
scalar_dict.update({f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)})
|
| 891 |
-
|
| 892 |
-
with torch.no_grad():
|
| 893 |
-
o, *_ = net_g.module.infer(*reference) if hasattr(net_g, "module") else net_g.infer(*reference)
|
| 894 |
-
|
| 895 |
-
summarize(writer=writer, global_step=global_step, images={"slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy())}, scalars=scalar_dict, audios={f"gen/audio_{global_step:07d}": o[0, :, :]}, audio_sample_rate=config.data.sample_rate)
|
| 896 |
-
|
| 897 |
-
global_step += 1
|
| 898 |
-
pbar.update(1)
|
| 899 |
-
|
| 900 |
-
def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004):
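# Overtraining heuristic: True if the smoothed loss rises at any point in the last `threshold`
# entries, or only improves by less than `epsilon` throughout (i.e. it has plateaued).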
|
| 901 |
-
if len(smoothed_loss_history) < threshold + 1: return False
|
| 902 |
-
for i in range(-threshold, -1):
|
| 903 |
-
if smoothed_loss_history[i + 1] > smoothed_loss_history[i]: return True
|
| 904 |
-
if abs(smoothed_loss_history[i + 1] - smoothed_loss_history[i]) >= epsilon: return False
|
| 905 |
-
return True
|
| 906 |
-
|
| 907 |
-
def update_exponential_moving_average(smoothed_loss_history, new_value, smoothing=0.987):
|
| 908 |
-
smoothed_value = new_value if not smoothed_loss_history else (smoothing * smoothed_loss_history[-1] + (1 - smoothing) * new_value)
|
| 909 |
-
smoothed_loss_history.append(smoothed_value)
|
| 910 |
-
return smoothed_value
|
| 911 |
-
|
| 912 |
-
def save_to_json(file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history):
|
| 913 |
-
with open(file_path, "w") as f:
|
| 914 |
-
json.dump({"loss_disc_history": loss_disc_history, "smoothed_loss_disc_history": smoothed_loss_disc_history, "loss_gen_history": loss_gen_history, "smoothed_loss_gen_history": smoothed_loss_gen_history}, f)
|
| 915 |
-
|
| 916 |
-
model_add, model_del = [], []
|
| 917 |
-
done = False
|
| 918 |
-
|
| 919 |
-
if rank == 0:
|
| 920 |
-
if epoch % save_every_epoch == 0:
|
| 921 |
-
checkpoint_suffix = f"{'latest' if save_only_latest else global_step}.pth"
|
| 922 |
-
save_checkpoint(net_g, optim_g, config.train.learning_rate, epoch, os.path.join(experiment_dir, "G_" + checkpoint_suffix))
|
| 923 |
-
save_checkpoint(net_d, optim_d, config.train.learning_rate, epoch, os.path.join(experiment_dir, "D_" + checkpoint_suffix))
|
| 924 |
-
if custom_save_every_weights: model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
|
| 925 |
-
|
| 926 |
-
if overtraining_detector and epoch > 1:
|
| 927 |
-
current_loss_disc = float(loss_disc)
|
| 928 |
-
loss_disc_history.append(current_loss_disc)
|
| 929 |
-
smoothed_value_disc = update_exponential_moving_average(smoothed_loss_disc_history, current_loss_disc)
|
| 930 |
-
is_overtraining_disc = check_overtraining(smoothed_loss_disc_history, overtraining_threshold * 2)
|
| 931 |
-
|
| 932 |
-
if is_overtraining_disc: consecutive_increases_disc += 1
|
| 933 |
-
else: consecutive_increases_disc = 0
|
| 934 |
-
|
| 935 |
-
current_loss_gen = float(lowest_value["value"])
|
| 936 |
-
loss_gen_history.append(current_loss_gen)
|
| 937 |
-
smoothed_value_gen = update_exponential_moving_average(smoothed_loss_gen_history, current_loss_gen)
|
| 938 |
-
is_overtraining_gen = check_overtraining(smoothed_loss_gen_history, overtraining_threshold, 0.01)
|
| 939 |
-
|
| 940 |
-
if is_overtraining_gen: consecutive_increases_gen += 1
|
| 941 |
-
else: consecutive_increases_gen = 0
|
| 942 |
-
|
| 943 |
-
if epoch % save_every_epoch == 0: save_to_json(training_file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history)
|
| 944 |
-
|
| 945 |
-
if (is_overtraining_gen and consecutive_increases_gen == overtraining_threshold or is_overtraining_disc and consecutive_increases_disc == (overtraining_threshold * 2)):
|
| 946 |
-
logger.info(translations["overtraining_find"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
|
| 947 |
-
done = True
|
| 948 |
-
else:
|
| 949 |
-
logger.info(translations["best_epoch"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
|
| 950 |
-
for file in glob.glob(os.path.join("assets", "weights", f"{model_name}_*e_*s_best_epoch.pth")):
|
| 951 |
-
model_del.append(file)
|
| 952 |
-
|
| 953 |
-
model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s_best_epoch.pth"))
|
| 954 |
-
|
| 955 |
-
if epoch >= custom_total_epoch:
|
| 956 |
-
logger.info(translations["success_training"].format(epoch=epoch, global_step=global_step, loss_gen_all=round(loss_gen_all.item(), 3)))
|
| 957 |
-
logger.info(translations["training_info"].format(lowest_value_rounded=round(float(lowest_value["value"]), 3), lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))
|
| 958 |
-
|
| 959 |
-
pid_file_path = os.path.join(experiment_dir, "config.json")
|
| 960 |
-
with open(pid_file_path, "r") as pid_file:
|
| 961 |
-
pid_data = json.load(pid_file)
|
| 962 |
-
|
| 963 |
-
with open(pid_file_path, "w") as pid_file:
|
| 964 |
-
pid_data.pop("process_pids", None)
|
| 965 |
-
json.dump(pid_data, pid_file, indent=4)
|
| 966 |
-
|
| 967 |
-
if os.path.exists(os.path.join(experiment_dir, "train_pid.txt")): os.remove(os.path.join(experiment_dir, "train_pid.txt"))
|
| 968 |
-
model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
|
| 969 |
-
done = True
|
| 970 |
-
|
| 971 |
-
for m in model_del:
|
| 972 |
-
os.remove(m)
|
| 973 |
-
|
| 974 |
-
if model_add:
|
| 975 |
-
ckpt = (net_g.module.state_dict() if hasattr(net_g, "module") else net_g.state_dict())
|
| 976 |
-
for m in model_add:
|
| 977 |
-
extract_model(ckpt=ckpt, sr=sample_rate, pitch_guidance=pitch_guidance == True, name=model_name, model_path=m, epoch=epoch, step=global_step, version=version, hps=hps, model_author=model_author, vocoder=vocoder)
|
| 978 |
-
|
| 979 |
-
lowest_value_rounded = round(float(lowest_value["value"]), 3)
|
| 980 |
-
|
| 981 |
-
if epoch > 1 and overtraining_detector: logger.info(translations["model_training_info"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step'], remaining_epochs_gen=(overtraining_threshold - consecutive_increases_gen), remaining_epochs_disc=((overtraining_threshold * 2) - consecutive_increases_disc), smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
|
| 982 |
-
elif epoch > 1 and overtraining_detector == False: logger.info(translations["model_training_info_2"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))
|
| 983 |
-
else: logger.info(translations["model_training_info_3"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record()))
|
| 984 |
-
|
| 985 |
-
last_loss_gen_all = loss_gen_all
|
| 986 |
-
if done: os._exit(0)
|
| 987 |
-
|
| 988 |
-
if __name__ == "__main__":
|
| 989 |
-
torch.multiprocessing.set_start_method("spawn")
|
| 990 |
-
main()
|
main/library/algorithm/commons.py
DELETED
@@ -1,60 +0,0 @@
import torch


def init_weights(m, mean=0.0, std=0.01):
    if m.__class__.__name__.find("Conv") != -1: m.weight.data.normal_(mean, std)

def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)

def convert_pad_shape(pad_shape):
    return [item for sublist in pad_shape[::-1] for item in sublist]

def slice_segments(x, ids_str, segment_size = 4, dim = 2):
    if dim == 2: ret = torch.zeros_like(x[:, :segment_size])
    elif dim == 3: ret = torch.zeros_like(x[:, :, :segment_size])

    for i in range(x.size(0)):
        idx_str = ids_str[i].item()
        idx_end = idx_str + segment_size

        if dim == 2: ret[i] = x[i, idx_str:idx_end]
        else: ret[i] = x[i, :, idx_str:idx_end]

    return ret

def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, _, t = x.size()
    if x_lengths is None: x_lengths = t

    ids_str = (torch.rand([b]).to(device=x.device) * (x_lengths - segment_size + 1)).to(dtype=torch.long)

    return slice_segments(x, ids_str, segment_size, dim=3), ids_str

@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]

    in_act = input_a + input_b

    return torch.tanh(in_act[:, :n_channels_int, :]) * torch.sigmoid(in_act[:, n_channels_int:, :])

def sequence_mask(length, max_length = None):
    if max_length is None: max_length = length.max()

    return torch.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0) < length.unsqueeze(1)

def clip_grad_value(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor): parameters = [parameters]
    norm_type = float(norm_type)

    if clip_value is not None: clip_value = float(clip_value)
    total_norm = 0

    for p in list(filter(lambda p: p.grad is not None, parameters)):
        total_norm += (p.grad.data.norm(norm_type)).item() ** norm_type

        if clip_value is not None: p.grad.data.clamp_(min=-clip_value, max=clip_value)

    return total_norm ** (1.0 / norm_type)

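Note (not part of the deleted file): a minimal sketch of how the segment helpers above fit together, assuming the module is still importable under its original path before this deletion; the tensor shapes are illustrative assumptions only.

import torch
from main.library.algorithm.commons import sequence_mask, rand_slice_segments

x = torch.randn(2, 192, 400)                   # (batch, channels, frames) of latent features
lengths = torch.tensor([400, 350])             # valid frame counts per item

mask = sequence_mask(lengths, max_length=400)  # (2, 400) boolean validity mask
segments, ids = rand_slice_segments(x, lengths, segment_size=32)
print(mask.shape, segments.shape, ids.shape)   # (2, 400), (2, 192, 32), (2,)
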
main/library/algorithm/modules.py
DELETED
@@ -1,60 +0,0 @@
import os
import sys
import torch

sys.path.append(os.getcwd())

from .commons import fused_add_tanh_sigmoid_multiply

class WaveNet(torch.nn.Module):
    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
        super(WaveNet, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = (kernel_size,)
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = torch.nn.Dropout(p_dropout)
        if gin_channels != 0: self.cond_layer = torch.nn.utils.parametrizations.weight_norm(torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1), name="weight")
        dilations = [dilation_rate ** i for i in range(n_layers)]
        paddings = [(kernel_size * d - d) // 2 for d in dilations]

        for i in range(n_layers):
            in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilations[i], padding=paddings[i])
            in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)
            res_skip_channels = (hidden_channels if i == n_layers - 1 else 2 * hidden_channels)
            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None):
        output = x.clone().zero_()
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None: g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            g_l = (g[:, i * 2 * self.hidden_channels : (i + 1) * 2 * self.hidden_channels, :] if g is not None else 0)
            res_skip_acts = self.res_skip_layers[i](self.drop(fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)))

            if i < self.n_layers - 1:
                x = (x + (res_skip_acts[:, : self.hidden_channels, :])) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else: output = output + res_skip_acts

        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0: torch.nn.utils.remove_weight_norm(self.cond_layer)

        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)

        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)

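Note (not part of the deleted file): a hedged sketch exercising the WaveNet block above with dummy tensors; the channel sizes are assumptions for illustration, not values taken from this repo's configs.

import torch
from main.library.algorithm.modules import WaveNet

net = WaveNet(hidden_channels=192, kernel_size=5, dilation_rate=1, n_layers=16, gin_channels=256)
x = torch.randn(1, 192, 100)     # hidden features (batch, channels, frames)
x_mask = torch.ones(1, 1, 100)   # all frames treated as valid
g = torch.randn(1, 256, 1)       # conditioning/speaker embedding
out = net(x, x_mask, g)          # -> (1, 192, 100)
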
main/library/algorithm/mrf_hifigan.py
DELETED
@@ -1,150 +0,0 @@
import math
import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils import remove_weight_norm
from torch.utils.checkpoint import checkpoint
from torch.nn.utils.parametrizations import weight_norm

LRELU_SLOPE = 0.1

class MRFLayer(nn.Module):
    def __init__(self, channels, kernel_size, dilation):
        super().__init__()
        self.conv1 = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size * dilation - dilation) // 2, dilation=dilation))
        self.conv2 = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2, dilation=1))

    def forward(self, x):
        return x + self.conv2(F.leaky_relu(self.conv1(F.leaky_relu(x, LRELU_SLOPE)), LRELU_SLOPE))

    def remove_weight_norm(self):
        remove_weight_norm(self.conv1)
        remove_weight_norm(self.conv2)

class MRFBlock(nn.Module):
    def __init__(self, channels, kernel_size, dilations):
        super().__init__()
        self.layers = nn.ModuleList()

        for dilation in dilations:
            self.layers.append(MRFLayer(channels, kernel_size, dilation))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)

        return x

    def remove_weight_norm(self):
        for layer in self.layers:
            layer.remove_weight_norm()

class SineGenerator(nn.Module):
    def __init__(self, samp_rate, harmonic_num = 0, sine_amp = 0.1, noise_std = 0.003, voiced_threshold = 0):
        super(SineGenerator, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        return torch.ones_like(f0) * (f0 > self.voiced_threshold)

    def _f02sine(self, f0_values):
        rad_values = (f0_values / self.sampling_rate) % 1
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], dtype=f0_values.dtype, device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
        tmp_over_one = torch.cumsum(rad_values, 1) % 1
        tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
        cumsum_shift = torch.zeros_like(rad_values)
        cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

        return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)

    def forward(self, f0):
        with torch.no_grad():
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, dtype=f0.dtype, device=f0.device)
            f0_buf[:, :, 0] = f0[:, :, 0]

            for idx in np.arange(self.harmonic_num):
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)

            sine_waves = self._f02sine(f0_buf) * self.sine_amp
            uv = self._f02uv(f0)
            sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))

        return sine_waves

class SourceModuleHnNSF(nn.Module):
    def __init__(self, sampling_rate, harmonic_num = 0, sine_amp = 0.1, add_noise_std = 0.003, voiced_threshold = 0):
        super(SourceModuleHnNSF, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        self.l_sin_gen = SineGenerator(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold)
        self.l_linear = nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = nn.Tanh()

    def forward(self, x):
        return self.l_tanh(self.l_linear(self.l_sin_gen(x).to(dtype=self.l_linear.weight.dtype)))

class HiFiGANMRFGenerator(nn.Module):
    def __init__(self, in_channel, upsample_initial_channel, upsample_rates, upsample_kernel_sizes, resblock_kernel_sizes, resblock_dilations, gin_channels, sample_rate, harmonic_num, checkpointing = False):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.checkpointing = checkpointing
        self.f0_upsample = nn.Upsample(scale_factor=np.prod(upsample_rates))
        self.m_source = SourceModuleHnNSF(sample_rate, harmonic_num)
        self.conv_pre = weight_norm(nn.Conv1d(in_channel, upsample_initial_channel, kernel_size=7, stride=1, padding=3))
        self.upsamples = nn.ModuleList()
        self.noise_convs = nn.ModuleList()
        stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.upsamples.append(weight_norm(nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), kernel_size=k, stride=u, padding=((k - u) // 2) if u % 2 == 0 else (u // 2 + u % 2), output_padding=u % 2)))
            stride = stride_f0s[i]
            kernel = 1 if stride == 1 else stride * 2 - stride % 2
            self.noise_convs.append(nn.Conv1d(1, upsample_initial_channel // (2 ** (i + 1)), kernel_size=kernel, stride=stride, padding=0 if stride == 1 else (kernel - stride) // 2))

        self.mrfs = nn.ModuleList()
        for i in range(len(self.upsamples)):
            channel = upsample_initial_channel // (2 ** (i + 1))
            self.mrfs.append(nn.ModuleList([MRFBlock(channel, kernel_size=k, dilations=d) for k, d in zip(resblock_kernel_sizes, resblock_dilations)]))

        self.conv_post = weight_norm(nn.Conv1d(channel, 1, kernel_size=7, stride=1, padding=3))
        if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, f0, g = None):
        har_source = self.m_source(self.f0_upsample(f0[:, None, :]).transpose(-1, -2)).transpose(-1, -2)
        x = self.conv_pre(x)
        if g is not None: x += self.cond(g)

        for ups, mrf, noise_conv in zip(self.upsamples, self.mrfs, self.noise_convs):
            x = F.leaky_relu(x, LRELU_SLOPE)

            if self.training and self.checkpointing:
                x = checkpoint(ups, x, use_reentrant=False) + noise_conv(har_source)
                xs = sum([checkpoint(layer, x, use_reentrant=False) for layer in mrf])
            else:
                x = ups(x) + noise_conv(har_source)
                xs = sum([layer(x) for layer in mrf])

            x = xs / self.num_kernels

        return torch.tanh(self.conv_post(F.leaky_relu(x)))

    def remove_weight_norm(self):
        remove_weight_norm(self.conv_pre)

        for up in self.upsamples:
            remove_weight_norm(up)

        for mrf in self.mrfs:
            mrf.remove_weight_norm()

        remove_weight_norm(self.conv_post)

main/library/algorithm/onnx_export.py
DELETED
@@ -1,50 +0,0 @@
import os
import io
import sys
import onnx
import json
import torch
import onnxsim
import warnings

sys.path.append(os.getcwd())

from main.library.algorithm.synthesizers import SynthesizerONNX

warnings.filterwarnings("ignore")

def onnx_exporter(input_path, output_path, is_half=False, device="cpu"):
    cpt = (torch.load(input_path, map_location="cpu") if os.path.isfile(input_path) else None)
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]

    model_name, model_author, epochs, steps, version, f0, model_hash, vocoder, creation_date = cpt.get("model_name", None), cpt.get("author", None), cpt.get("epoch", None), cpt.get("step", None), cpt.get("version", "v1"), cpt.get("f0", 1), cpt.get("model_hash", None), cpt.get("vocoder", "Default"), cpt.get("creation_date", None)
    text_enc_hidden_dim = 768 if version == "v2" else 256
    tgt_sr = cpt["config"][-1]

    net_g = SynthesizerONNX(*cpt["config"], use_f0=f0, text_enc_hidden_dim=text_enc_hidden_dim, vocoder=vocoder, checkpointing=False)
    net_g.load_state_dict(cpt["weight"], strict=False)
    net_g.eval().to(device)
    net_g = (net_g.half() if is_half else net_g.float())

    phone = torch.rand(1, 200, text_enc_hidden_dim).to(device)
    phone_length = torch.tensor([200]).long().to(device)
    ds = torch.LongTensor([0]).to(device)
    rnd = torch.rand(1, 192, 200).to(device)

    if f0:
        args = (phone, phone_length, ds, rnd, torch.randint(size=(1, 200), low=5, high=255).to(device), torch.rand(1, 200).to(device))
        input_names = ["phone", "phone_lengths", "ds", "rnd", "pitch", "pitchf"]
        dynamic_axes = {"phone": [1], "rnd": [2], "pitch": [1], "pitchf": [1]}
    else:
        args = (phone, phone_length, ds, rnd)
        input_names = ["phone", "phone_lengths", "ds", "rnd"]
        dynamic_axes = {"phone": [1], "rnd": [2]}

    with io.BytesIO() as model:
        torch.onnx.export(net_g, args, model, do_constant_folding=True, opset_version=17, verbose=False, input_names=input_names, output_names=["audio"], dynamic_axes=dynamic_axes)

        model, _ = onnxsim.simplify(onnx.load_model_from_string(model.getvalue()))
        model.metadata_props.append(onnx.StringStringEntryProto(key="model_info", value=json.dumps({"model_name": model_name, "author": model_author, "epoch": epochs, "step": steps, "version": version, "sr": tgt_sr, "f0": f0, "model_hash": model_hash, "creation_date": creation_date, "vocoder": vocoder, "text_enc_hidden_dim": text_enc_hidden_dim})))

        onnx.save(model, output_path)
        return output_path

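Note (not part of the deleted file): a hedged usage sketch for the exporter above; the checkpoint and output paths are placeholders, not files shipped with this repo.

from main.library.algorithm.onnx_export import onnx_exporter

# Convert a trained RVC-style .pth checkpoint (hypothetical path) to ONNX on CPU.
onnx_path = onnx_exporter(
    input_path="assets/weights/MyModel.pth",    # placeholder checkpoint path
    output_path="assets/weights/MyModel.onnx",  # placeholder output path
    is_half=False,
    device="cpu",
)
print("exported to", onnx_path)
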
main/library/algorithm/refinegan.py
DELETED
@@ -1,170 +0,0 @@
import os
import sys
import math
import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.checkpoint import checkpoint
from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm

sys.path.append(os.getcwd())

from main.library.algorithm.commons import init_weights, get_padding


class ResBlock(nn.Module):
    def __init__(self, channels, kernel_size = 7, dilation = (1, 3, 5), leaky_relu_slope = 0.2):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope
        self.convs1 = nn.ModuleList([weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=d, padding=get_padding(kernel_size, d))) for d in dilation])
        self.convs1.apply(init_weights)
        self.convs2 = nn.ModuleList([weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=1, padding=get_padding(kernel_size, 1))) for _ in dilation])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            x = c2(F.leaky_relu(c1(F.leaky_relu(x, self.leaky_relu_slope)), self.leaky_relu_slope)) + x

        return x

    def remove_weight_norm(self):
        for c1, c2 in zip(self.convs1, self.convs2):
            remove_weight_norm(c1)
            remove_weight_norm(c2)

class AdaIN(nn.Module):
    def __init__(self, *, channels, leaky_relu_slope = 0.2):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.activation = nn.LeakyReLU(leaky_relu_slope)

    def forward(self, x):
        return self.activation(x + (torch.randn_like(x) * self.weight[None, :, None]))

class ParallelResBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels, kernel_sizes = (3, 7, 11), dilation = (1, 3, 5), leaky_relu_slope = 0.2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.input_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=1, padding=3)
        self.input_conv.apply(init_weights)
        self.blocks = nn.ModuleList([nn.Sequential(AdaIN(channels=out_channels), ResBlock(out_channels, kernel_size=kernel_size, dilation=dilation, leaky_relu_slope=leaky_relu_slope), AdaIN(channels=out_channels)) for kernel_size in kernel_sizes])

    def forward(self, x):
        x = self.input_conv(x)
        return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)

    def remove_weight_norm(self):
        remove_weight_norm(self.input_conv)
        for block in self.blocks:
            block[1].remove_weight_norm()

class SineGenerator(nn.Module):
    def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
        super(SineGenerator, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.merge = nn.Sequential(nn.Linear(self.dim, 1, bias=False), nn.Tanh())

    def _f02uv(self, f0):
        return torch.ones_like(f0) * (f0 > self.voiced_threshold)

    def _f02sine(self, f0_values):
        rad_values = (f0_values / self.sampling_rate) % 1
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], dtype=f0_values.dtype, device=f0_values.device)

        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        tmp_over_one = torch.cumsum(rad_values, 1) % 1
        tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0

        cumsum_shift = torch.zeros_like(rad_values)
        cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

        return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)

    def forward(self, f0):
        with torch.no_grad():
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, dtype=f0.dtype, device=f0.device)
            f0_buf[:, :, 0] = f0[:, :, 0]

            for idx in np.arange(self.harmonic_num):
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)

            sine_waves = self._f02sine(f0_buf) * self.sine_amp
            uv = self._f02uv(f0)
            sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))

        return self.merge(sine_waves)

class RefineGANGenerator(nn.Module):
    def __init__(self, *, sample_rate = 44100, upsample_rates = (8, 8, 2, 2), leaky_relu_slope = 0.2, num_mels = 128, gin_channels = 256, checkpointing = False, upsample_initial_channel = 512):
        super().__init__()
        self.upsample_rates = upsample_rates
        self.checkpointing = checkpointing
        self.leaky_relu_slope = leaky_relu_slope
        self.upp = np.prod(upsample_rates)
        self.m_source = SineGenerator(sample_rate)
        self.pre_conv = weight_norm(nn.Conv1d(1, upsample_initial_channel // 2, 7, 1, padding=3))
        stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]

        channels = upsample_initial_channel
        self.downsample_blocks = nn.ModuleList([])

        for i, _ in enumerate(upsample_rates):
            stride = stride_f0s[i]
            kernel = 1 if stride == 1 else stride * 2 - stride % 2

            self.downsample_blocks.append(weight_norm(nn.Conv1d(1, channels // 2 ** (i + 2), kernel, stride, padding=0 if stride == 1 else (kernel - stride) // 2)))

        self.mel_conv = weight_norm(nn.Conv1d(num_mels, channels // 2, 7, 1, padding=3))
        self.mel_conv.apply(init_weights)

        if gin_channels != 0: self.cond = nn.Conv1d(256, channels // 2, 1)

        self.upsample_blocks = nn.ModuleList([])
        self.upsample_conv_blocks = nn.ModuleList([])

        for rate in upsample_rates:
            new_channels = channels // 2
            self.upsample_blocks.append(nn.Upsample(scale_factor=rate, mode="linear"))
            self.upsample_conv_blocks.append(ParallelResBlock(in_channels=channels + channels // 4, out_channels=new_channels, kernel_sizes=(3, 7, 11), dilation=(1, 3, 5), leaky_relu_slope=leaky_relu_slope))
            channels = new_channels

        self.conv_post = weight_norm(nn.Conv1d(channels, 1, 7, 1, padding=3, bias=False))
        self.conv_post.apply(init_weights)

    def forward(self, mel, f0, g = None):
        har_source = self.m_source(F.interpolate(f0.unsqueeze(1), size=mel.shape[-1] * self.upp, mode="linear").transpose(1, 2)).transpose(1, 2)
        x = F.interpolate(self.pre_conv(har_source), size=mel.shape[-1], mode="linear")

        mel = self.mel_conv(mel)
        if g is not None: mel += self.cond(g)

        x = torch.cat([mel, x], dim=1)

        for ups, res, down in zip(self.upsample_blocks, self.upsample_conv_blocks, self.downsample_blocks):
            x = F.leaky_relu(x, self.leaky_relu_slope)
            x = checkpoint(res, torch.cat([checkpoint(ups, x, use_reentrant=False), down(har_source)], dim=1), use_reentrant=False) if self.training and self.checkpointing else res(torch.cat([ups(x), down(har_source)], dim=1))

        return torch.tanh(self.conv_post(F.leaky_relu(x, self.leaky_relu_slope)))

    def remove_weight_norm(self):
        remove_weight_norm(self.pre_conv)
        remove_weight_norm(self.mel_conv)
        remove_weight_norm(self.conv_post)

        for block in self.downsample_blocks:
            block.remove_weight_norm()

        for block in self.upsample_conv_blocks:
            block.remove_weight_norm()

main/library/algorithm/residuals.py
DELETED
@@ -1,140 +0,0 @@
import os
import sys
import torch

from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm

sys.path.append(os.getcwd())

from .modules import WaveNet
from .commons import get_padding, init_weights


LRELU_SLOPE = 0.1

def create_conv1d_layer(channels, kernel_size, dilation):
    return weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation, padding=get_padding(kernel_size, dilation)))

def apply_mask(tensor, mask):
    return tensor * mask if mask is not None else tensor

class ResBlockBase(torch.nn.Module):
    def __init__(self, channels, kernel_size, dilations):
        super(ResBlockBase, self).__init__()

        self.convs1 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, d) for d in dilations])
        self.convs1.apply(init_weights)

        self.convs2 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, 1) for _ in dilations])
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c1, c2 in zip(self.convs1, self.convs2):
            x = c2(apply_mask(torch.nn.functional.leaky_relu(c1(apply_mask(torch.nn.functional.leaky_relu(x, LRELU_SLOPE), x_mask)), LRELU_SLOPE), x_mask)) + x

        return apply_mask(x, x_mask)

    def remove_weight_norm(self):
        for conv in self.convs1 + self.convs2:
            remove_weight_norm(conv)

class ResBlock(ResBlockBase):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__(channels, kernel_size, dilation)

class Log(torch.nn.Module):
    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
            return y, torch.sum(-y, [1, 2])
        else: return torch.exp(x) * x_mask

class Flip(torch.nn.Module):
    def forward(self, x, *args, reverse=False, **kwargs):
        x = torch.flip(x, [1])

        if not reverse: return x, torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
        else: return x

class ElementwiseAffine(torch.nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = torch.nn.Parameter(torch.zeros(channels, 1))
        self.logs = torch.nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse: return ((self.m + torch.exp(self.logs) * x) * x_mask), torch.sum(self.logs * x_mask, [1, 2])
        else: return (x - self.m) * torch.exp(-self.logs) * x_mask

class ResidualCouplingBlock(torch.nn.Module):
    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
        super(ResidualCouplingBlock, self).__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels
        self.flows = torch.nn.ModuleList()

        for _ in range(n_flows):
            self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
            self.flows.append(Flip())

    def forward(self, x, x_mask, g = None, reverse = False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow.forward(x, x_mask, g=g, reverse=reverse)

        return x

    def remove_weight_norm(self):
        for i in range(self.n_flows):
            self.flows[i * 2].remove_weight_norm()

    def __prepare_scriptable__(self):
        for i in range(self.n_flows):
            for hook in self.flows[i * 2]._forward_pre_hooks.values():
                if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.flows[i * 2])

        return self

class ResidualCouplingLayer(torch.nn.Module):
    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False):
        assert channels % 2 == 0, "Channels/2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = torch.nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
        self.post = torch.nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)

        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        stats = self.post(self.enc((self.pre(x0) * x_mask), x_mask, g=g)) * x_mask

        if not self.mean_only: m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse: return torch.cat([x0, (m + x1 * torch.exp(logs) * x_mask)], 1), torch.sum(logs, [1, 2])
        else: return torch.cat([x0, ((x1 - m) * torch.exp(-logs) * x_mask)], 1)

    def remove_weight_norm(self):
        self.enc.remove_weight_norm()

main/library/algorithm/separator.py
DELETED
|
@@ -1,320 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import time
|
| 4 |
-
import yaml
|
| 5 |
-
import torch
|
| 6 |
-
import codecs
|
| 7 |
-
import hashlib
|
| 8 |
-
import logging
|
| 9 |
-
import platform
|
| 10 |
-
import warnings
|
| 11 |
-
import requests
|
| 12 |
-
import onnxruntime
|
| 13 |
-
|
| 14 |
-
from importlib import metadata, import_module
|
| 15 |
-
|
| 16 |
-
now_dir = os.getcwd()
|
| 17 |
-
sys.path.append(now_dir)
|
| 18 |
-
|
| 19 |
-
from main.configs.config import Config
|
| 20 |
-
from main.tools.huggingface import HF_download_file
|
| 21 |
-
|
| 22 |
-
translations = Config().translations
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class Separator:
|
| 26 |
-
def __init__(self, logger=logging.getLogger(__name__), log_level=logging.INFO, log_formatter=None, model_file_dir="assets/models/uvr5", output_dir=None, output_format="wav", output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}, demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}):
|
| 27 |
-
self.logger = logger
|
| 28 |
-
self.log_level = log_level
|
| 29 |
-
self.log_formatter = log_formatter
|
| 30 |
-
self.log_handler = logging.StreamHandler()
|
| 31 |
-
|
| 32 |
-
if self.log_formatter is None: self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")
|
| 33 |
-
self.log_handler.setFormatter(self.log_formatter)
|
| 34 |
-
|
| 35 |
-
if not self.logger.hasHandlers(): self.logger.addHandler(self.log_handler)
|
| 36 |
-
if log_level > logging.DEBUG: warnings.filterwarnings("ignore")
|
| 37 |
-
|
| 38 |
-
self.logger.info(translations["separator_info"].format(output_dir=output_dir, output_format=output_format))
|
| 39 |
-
self.model_file_dir = model_file_dir
|
| 40 |
-
|
| 41 |
-
if output_dir is None:
|
| 42 |
-
output_dir = now_dir
|
| 43 |
-
self.logger.info(translations["output_dir_is_none"])
|
| 44 |
-
|
| 45 |
-
self.output_dir = output_dir
|
| 46 |
-
|
| 47 |
-
os.makedirs(self.model_file_dir, exist_ok=True)
|
| 48 |
-
os.makedirs(self.output_dir, exist_ok=True)
|
| 49 |
-
|
| 50 |
-
self.output_format = output_format
|
| 51 |
-
self.output_bitrate = output_bitrate
|
| 52 |
-
|
| 53 |
-
if self.output_format is None: self.output_format = "wav"
|
| 54 |
-
self.normalization_threshold = normalization_threshold
|
| 55 |
-
if normalization_threshold <= 0 or normalization_threshold > 1: raise ValueError(translations[">0or=1"])
|
| 56 |
-
|
| 57 |
-
self.output_single_stem = output_single_stem
|
| 58 |
-
if output_single_stem is not None: self.logger.debug(translations["output_single"].format(output_single_stem=output_single_stem))
|
| 59 |
-
|
| 60 |
-
self.invert_using_spec = invert_using_spec
|
| 61 |
-
if self.invert_using_spec: self.logger.debug(translations["step2"])
|
| 62 |
-
|
| 63 |
-
self.sample_rate = int(sample_rate)
|
| 64 |
-
self.arch_specific_params = {"MDX": mdx_params, "Demucs": demucs_params}
|
| 65 |
-
self.torch_device = None
|
| 66 |
-
self.torch_device_cpu = None
|
| 67 |
-
self.torch_device_mps = None
|
| 68 |
-
self.onnx_execution_provider = None
|
| 69 |
-
self.model_instance = None
|
| 70 |
-
self.model_is_uvr_vip = False
|
| 71 |
-
self.model_friendly_name = None
|
| 72 |
-
self.setup_accelerated_inferencing_device()
|
| 73 |
-
|
| 74 |
-
def setup_accelerated_inferencing_device(self):
|
| 75 |
-
system_info = self.get_system_info()
|
| 76 |
-
self.log_onnxruntime_packages()
|
| 77 |
-
self.setup_torch_device(system_info)
|
| 78 |
-
|
| 79 |
-
def get_system_info(self):
|
| 80 |
-
os_name = platform.system()
|
| 81 |
-
os_version = platform.version()
|
| 82 |
-
self.logger.info(f"{translations['os']}: {os_name} {os_version}")
|
| 83 |
-
system_info = platform.uname()
|
| 84 |
-
self.logger.info(translations["platform_info"].format(system_info=system_info, node=system_info.node, release=system_info.release, machine=system_info.machine, processor=system_info.processor))
|
| 85 |
-
python_version = platform.python_version()
|
| 86 |
-
self.logger.info(f"{translations['name_ver'].format(name='python')}: {python_version}")
|
| 87 |
-
pytorch_version = torch.__version__
|
| 88 |
-
self.logger.info(f"{translations['name_ver'].format(name='pytorch')}: {pytorch_version}")
|
| 89 |
-
|
| 90 |
-
return system_info
|
| 91 |
-
|
| 92 |
-
def log_onnxruntime_packages(self):
|
| 93 |
-
onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
|
| 94 |
-
onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")
|
| 95 |
-
|
| 96 |
-
if onnxruntime_gpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='GPU')}: {onnxruntime_gpu_package.version}")
|
| 97 |
-
if onnxruntime_cpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='CPU')}: {onnxruntime_cpu_package.version}")
|
| 98 |
-
|
| 99 |
-
def setup_torch_device(self, system_info):
|
| 100 |
-
hardware_acceleration_enabled = False
|
| 101 |
-
ort_providers = onnxruntime.get_available_providers()
|
| 102 |
-
self.torch_device_cpu = torch.device("cpu")
|
| 103 |
-
|
| 104 |
-
if torch.cuda.is_available():
|
| 105 |
-
self.configure_cuda(ort_providers)
|
| 106 |
-
hardware_acceleration_enabled = True
|
| 107 |
-
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
|
| 108 |
-
self.configure_mps(ort_providers)
|
| 109 |
-
hardware_acceleration_enabled = True
|
| 110 |
-
|
| 111 |
-
if not hardware_acceleration_enabled:
|
| 112 |
-
self.logger.info(translations["running_in_cpu"])
|
| 113 |
-
self.torch_device = self.torch_device_cpu
|
| 114 |
-
self.onnx_execution_provider = ["CPUExecutionProvider"]
|
| 115 |
-
|
| 116 |
-
def configure_cuda(self, ort_providers):
|
| 117 |
-
self.logger.info(translations["running_in_cuda"])
|
| 118 |
-
self.torch_device = torch.device("cuda")
|
| 119 |
-
|
| 120 |
-
if "CUDAExecutionProvider" in ort_providers:
|
| 121 |
-
self.logger.info(translations["onnx_have"].format(have='CUDAExecutionProvider'))
|
| 122 |
-
self.onnx_execution_provider = ["CUDAExecutionProvider"]
|
| 123 |
-
else: self.logger.warning(translations["onnx_not_have"].format(have='CUDAExecutionProvider'))
|
| 124 |
-
|
| 125 |
-
def configure_mps(self, ort_providers):
|
| 126 |
-
self.logger.info(translations["set_torch_mps"])
|
| 127 |
-
self.torch_device_mps = torch.device("mps")
|
| 128 |
-
self.torch_device = self.torch_device_mps
|
| 129 |
-
|
| 130 |
-
if "CoreMLExecutionProvider" in ort_providers:
|
| 131 |
-
self.logger.info(translations["onnx_have"].format(have='CoreMLExecutionProvider'))
|
| 132 |
-
self.onnx_execution_provider = ["CoreMLExecutionProvider"]
|
| 133 |
-
else: self.logger.warning(translations["onnx_not_have"].format(have='CoreMLExecutionProvider'))
|
| 134 |
-
|
| 135 |
-
def get_package_distribution(self, package_name):
|
| 136 |
-
try:
|
| 137 |
-
return metadata.distribution(package_name)
|
| 138 |
-
except metadata.PackageNotFoundError:
|
| 139 |
-
self.logger.debug(translations["python_not_install"].format(package_name=package_name))
|
| 140 |
-
return None
|
| 141 |
-
|
| 142 |
-
def get_model_hash(self, model_path):
|
| 143 |
-
self.logger.debug(translations["hash"].format(model_path=model_path))
|
| 144 |
-
|
| 145 |
-
try:
|
| 146 |
-
with open(model_path, "rb") as f:
|
| 147 |
-
f.seek(-10000 * 1024, 2)
|
| 148 |
-
return hashlib.md5(f.read()).hexdigest()
|
| 149 |
-
except IOError as e:
|
| 150 |
-
self.logger.error(translations["ioerror"].format(e=e))
|
| 151 |
-
return hashlib.md5(open(model_path, "rb").read()).hexdigest()
|
| 152 |
-
|
| 153 |
-
def download_file_if_not_exists(self, url, output_path):
|
| 154 |
-
if os.path.isfile(output_path):
|
| 155 |
-
self.logger.debug(translations["cancel_download"].format(output_path=output_path))
|
| 156 |
-
return
|
| 157 |
-
|
| 158 |
-
self.logger.debug(translations["download_model"].format(url=url, output_path=output_path))
|
| 159 |
-
HF_download_file(url, output_path)
|
| 160 |
-
|
| 161 |
-
def print_uvr_vip_message(self):
|
| 162 |
-
if self.model_is_uvr_vip:
|
| 163 |
-
self.logger.warning(translations["vip_model"].format(model_friendly_name=self.model_friendly_name))
|
| 164 |
-
self.logger.warning(translations["vip_print"])
|
| 165 |
-
|
| 166 |
-
def list_supported_model_files(self):
|
| 167 |
-
response = requests.get(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/enj/znva/wfba/hie_zbqryf.wfba", "rot13"))
|
| 168 |
-
response.raise_for_status()
|
| 169 |
-
model_downloads_list = response.json()
|
| 170 |
-
self.logger.debug(translations["load_download_json"])
|
| 171 |
-
|
| 172 |
-
return {"MDX": {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"]}, "Demucs": {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}}
|
| 173 |
-
|
| 174 |
-
def download_model_files(self, model_filename):
|
| 175 |
-
model_path = os.path.join(self.model_file_dir, model_filename)
|
| 176 |
-
supported_model_files_grouped = self.list_supported_model_files()
|
| 177 |
-
|
| 178 |
-
yaml_config_filename = None
|
| 179 |
-
self.logger.debug(translations["search_model"].format(model_filename=model_filename))
|
| 180 |
-
|
| 181 |
-
for model_type, model_list in supported_model_files_grouped.items():
|
| 182 |
-
for model_friendly_name, model_download_list in model_list.items():
|
| 183 |
-
self.model_is_uvr_vip = "VIP" in model_friendly_name
|
| 184 |
-
model_repo_url_prefix = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/hie5_zbqryf", "rot13")
|
| 185 |
-
|
| 186 |
-
if isinstance(model_download_list, str) and model_download_list == model_filename:
|
| 187 |
-
self.logger.debug(translations["single_model"].format(model_friendly_name=model_friendly_name))
|
| 188 |
-
self.model_friendly_name = model_friendly_name
|
| 189 |
-
|
| 190 |
-
try:
|
| 191 |
-
self.download_file_if_not_exists(f"{model_repo_url_prefix}/MDX/{model_filename}", model_path)
|
| 192 |
-
except RuntimeError:
|
| 193 |
-
self.logger.warning(translations["not_found_model"])
|
| 194 |
-
self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{model_filename}", model_path)
|
| 195 |
-
|
| 196 |
-
self.print_uvr_vip_message()
|
| 197 |
-
self.logger.debug(translations["single_model_path"].format(model_path=model_path))
|
| 198 |
-
|
| 199 |
-
return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
|
| 200 |
-
elif isinstance(model_download_list, dict):
|
| 201 |
-
this_model_matches_input_filename = False
|
| 202 |
-
|
| 203 |
-
for file_name, file_url in model_download_list.items():
|
| 204 |
-
if file_name == model_filename or file_url == model_filename:
|
| 205 |
-
self.logger.debug(translations["find_model"].format(model_filename=model_filename, model_friendly_name=model_friendly_name))
|
| 206 |
-
this_model_matches_input_filename = True
|
| 207 |
-
|
| 208 |
-
if this_model_matches_input_filename:
|
| 209 |
-
self.logger.debug(translations["find_models"].format(model_friendly_name=model_friendly_name))
|
| 210 |
-
self.model_friendly_name = model_friendly_name
|
| 211 |
-
self.print_uvr_vip_message()
|
| 212 |
-
|
| 213 |
-
for config_key, config_value in model_download_list.items():
|
| 214 |
-
self.logger.debug(f"{translations['find_path']}: {config_key} -> {config_value}")
|
| 215 |
-
|
| 216 |
-
if config_value.startswith("http"): self.download_file_if_not_exists(config_value, os.path.join(self.model_file_dir, config_key))
|
| 217 |
-
elif config_key.endswith(".ckpt"):
|
| 218 |
-
try:
|
| 219 |
-
self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{config_key}", os.path.join(self.model_file_dir, config_key))
|
| 220 |
-
except RuntimeError:
|
| 221 |
-
self.logger.warning(translations["not_found_model_warehouse"])
|
| 222 |
-
|
| 223 |
-
if model_filename.endswith(".yaml"):
|
| 224 |
-
self.logger.warning(translations["yaml_warning"].format(model_filename=model_filename))
|
| 225 |
-
self.logger.warning(translations["yaml_warning_2"].format(config_key=config_key))
|
| 226 |
-
self.logger.warning(translations["yaml_warning_3"])
|
| 227 |
-
|
| 228 |
-
model_filename = config_key
|
| 229 |
-
model_path = os.path.join(self.model_file_dir, f"{model_filename}")
|
| 230 |
-
|
| 231 |
-
yaml_config_filename = config_value
|
| 232 |
-
yaml_config_filepath = os.path.join(self.model_file_dir, yaml_config_filename)
|
| 233 |
-
|
| 234 |
-
try:
|
| 235 |
-
self.download_file_if_not_exists(f"{model_repo_url_prefix}/mdx_c_configs/{yaml_config_filename}", yaml_config_filepath)
|
| 236 |
-
except RuntimeError:
|
| 237 |
-
self.logger.debug(translations["yaml_debug"])
|
| 238 |
-
else: self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{config_value}", os.path.join(self.model_file_dir, config_value))
|
| 239 |
-
|
| 240 |
-
self.logger.debug(translations["download_model_friendly"].format(model_friendly_name=model_friendly_name, model_path=model_path))
|
| 241 |
-
return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
|
| 242 |
-
|
| 243 |
-
raise ValueError(translations["not_found_model_2"].format(model_filename=model_filename))
|
| 244 |
-
|
| 245 |
-
def load_model_data_from_yaml(self, yaml_config_filename):
|
| 246 |
-
model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename) if not os.path.exists(yaml_config_filename) else yaml_config_filename
|
| 247 |
-
self.logger.debug(translations["load_yaml"].format(model_data_yaml_filepath=model_data_yaml_filepath))
|
| 248 |
-
|
| 249 |
-
model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
|
| 250 |
-
self.logger.debug(translations["load_yaml_2"].format(model_data=model_data))
|
| 251 |
-
|
| 252 |
-
if "roformer" in model_data_yaml_filepath: model_data["is_roformer"] = True
|
| 253 |
-
return model_data
|
| 254 |
-
|
| 255 |
-
def load_model_data_using_hash(self, model_path):
|
| 256 |
-
self.logger.debug(translations["hash_md5"])
|
| 257 |
-
model_hash = self.get_model_hash(model_path)
|
| 258 |
-
|
| 259 |
-
self.logger.debug(translations["model_hash"].format(model_path=model_path, model_hash=model_hash))
|
| 260 |
-
mdx_model_data_path = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/enj/znva/wfba/zbqry_qngn.wfba", "rot13")
|
| 261 |
-
self.logger.debug(translations["mdx_data"].format(mdx_model_data_path=mdx_model_data_path))
|
| 262 |
-
|
| 263 |
-
response = requests.get(mdx_model_data_path)
|
| 264 |
-
response.raise_for_status()
|
| 265 |
-
|
| 266 |
-
mdx_model_data_object = response.json()
|
| 267 |
-
self.logger.debug(translations["load_mdx"])
|
| 268 |
-
|
| 269 |
-
if model_hash in mdx_model_data_object: model_data = mdx_model_data_object[model_hash]
|
| 270 |
-
else: raise ValueError(translations["model_not_support"].format(model_hash=model_hash))
|
| 271 |
-
|
| 272 |
-
self.logger.debug(translations["uvr_json"].format(model_hash=model_hash, model_data=model_data))
|
| 273 |
-
return model_data
|
| 274 |
-
|
| 275 |
-
def load_model(self, model_filename):
|
| 276 |
-
self.logger.info(translations["loading_model"].format(model_filename=model_filename))
|
| 277 |
-
load_model_start_time = time.perf_counter()
|
| 278 |
-
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
|
| 279 |
-
self.logger.debug(translations["download_model_friendly_2"].format(model_friendly_name=model_friendly_name, model_path=model_path))
|
| 280 |
-
|
| 281 |
-
if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path
|
| 282 |
-
|
| 283 |
-
common_params = {"logger": self.logger, "log_level": self.log_level, "torch_device": self.torch_device, "torch_device_cpu": self.torch_device_cpu, "torch_device_mps": self.torch_device_mps, "onnx_execution_provider": self.onnx_execution_provider, "model_name": model_filename.split(".")[0], "model_path": model_path, "model_data": self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path), "output_format": self.output_format, "output_bitrate": self.output_bitrate, "output_dir": self.output_dir, "normalization_threshold": self.normalization_threshold, "output_single_stem": self.output_single_stem, "invert_using_spec": self.invert_using_spec, "sample_rate": self.sample_rate}
|
| 284 |
-
separator_classes = {"MDX": "mdx_separator.MDXSeparator", "Demucs": "demucs_separator.DemucsSeparator"}
|
| 285 |
-
|
| 286 |
-
if model_type not in self.arch_specific_params or model_type not in separator_classes: raise ValueError(translations["model_type_not_support"].format(model_type=model_type))
|
| 287 |
-
if model_type == "Demucs" and sys.version_info < (3, 10): raise Exception(translations["demucs_not_support_python<3.10"])
|
| 288 |
-
|
| 289 |
-
self.logger.debug(f"{translations['import_module']} {model_type}: {separator_classes[model_type]}")
|
| 290 |
-
module_name, class_name = separator_classes[model_type].split(".")
|
| 291 |
-
separator_class = getattr(import_module(f"main.library.architectures.{module_name}"), class_name)
|
| 292 |
-
|
| 293 |
-
self.logger.debug(f"{translations['initialization']} {model_type}: {separator_class}")
|
| 294 |
-
self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])
|
| 295 |
-
|
| 296 |
-
self.logger.debug(translations["loading_model_success"])
|
| 297 |
-
self.logger.info(f"{translations['loading_model_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - load_model_start_time)))}")
|
| 298 |
-
|
| 299 |
-
def separate(self, audio_file_path):
|
| 300 |
-
self.logger.info(f"{translations['starting_separator']}: {audio_file_path}")
|
| 301 |
-
separate_start_time = time.perf_counter()
|
| 302 |
-
|
| 303 |
-
self.logger.debug(translations["normalization"].format(normalization_threshold=self.normalization_threshold))
|
| 304 |
-
output_files = self.model_instance.separate(audio_file_path)
|
| 305 |
-
|
| 306 |
-
self.model_instance.clear_gpu_cache()
|
| 307 |
-
self.model_instance.clear_file_specific_paths()
|
| 308 |
-
|
| 309 |
-
self.print_uvr_vip_message()
|
| 310 |
-
|
| 311 |
-
self.logger.debug(translations["separator_success_3"])
|
| 312 |
-
self.logger.info(f"{translations['separator_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - separate_start_time)))}")
|
| 313 |
-
return output_files
|
| 314 |
-
|
| 315 |
-
def download_model_and_data(self, model_filename):
|
| 316 |
-
self.logger.info(translations["loading_separator_model"].format(model_filename=model_filename))
|
| 317 |
-
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
|
| 318 |
-
|
| 319 |
-
if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path
|
| 320 |
-
self.logger.info(translations["downloading_model"].format(model_type=model_type, model_friendly_name=model_friendly_name, model_path=model_path, model_data_dict_size=len(self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path))))
main/library/algorithm/stftpitchshift.py DELETED
@@ -1,250 +0,0 @@
import numpy as np

from numpy.lib.stride_tricks import sliding_window_view

def istft(frames, framesize, hopsize):
    frames = np.atleast_2d(frames)
    assert frames.ndim == 2

    analysis_window_size = np.ravel(framesize)[0]
    synthesis_window_size = np.ravel(framesize)[-1]

    assert analysis_window_size >= synthesis_window_size

    A = asymmetric_analysis_window(analysis_window_size, synthesis_window_size) if analysis_window_size != synthesis_window_size else symmetric_window(analysis_window_size)
    S = asymmetric_synthesis_window(analysis_window_size, synthesis_window_size) if analysis_window_size != synthesis_window_size else symmetric_window(synthesis_window_size)

    W = S * hopsize / np.sum(A * S)
    N = frames.shape[0] * hopsize + analysis_window_size

    y = np.zeros((N), float)

    frames[:, 0] = 0
    frames[:, -1] = 0
    frames0 = sliding_window_view(y, analysis_window_size, writeable=True)[::hopsize]
    frames1 = np.fft.irfft(frames, axis=-1, norm='forward') * W

    for i in range(min(len(frames0), len(frames1))):
        frames0[i] += frames1[i]

    return y

def asymmetric_synthesis_window(analysis_window_size, synthesis_window_size):
    n = analysis_window_size
    m = synthesis_window_size // 2

    right = symmetric_window(2 * m)
    window = np.zeros(n)

    window[n-m-m:n-m] = np.square(right[:m]) / symmetric_window(2 * n - 2 * m)[n-m-m:n-m]
    window[-m:] = right[-m:]

    return window

def asymmetric_analysis_window(analysis_window_size, synthesis_window_size):
    n = analysis_window_size
    m = synthesis_window_size // 2

    window = np.zeros(n)
    window[:n-m] = symmetric_window(2 * n - 2 * m)[:n-m]
    window[-m:] = symmetric_window(2 * m)[-m:]

    return window

def symmetric_window(symmetric_window_size):
    n = symmetric_window_size
    window = 0.5 - 0.5 * np.cos(2 * np.pi * np.arange(n) / n)

    return window

def stft(x, framesize, hopsize):
    x = np.atleast_1d(x)
    assert x.ndim == 1

    analysis_window_size = np.ravel(framesize)[0]
    synthesis_window_size = np.ravel(framesize)[-1]

    assert analysis_window_size >= synthesis_window_size

    W = asymmetric_analysis_window(analysis_window_size, synthesis_window_size) if analysis_window_size != synthesis_window_size else symmetric_window(analysis_window_size)

    frames0 = sliding_window_view(x, analysis_window_size, writeable=False)[::hopsize]
    frames1 = np.fft.rfft(frames0 * W, axis=-1, norm='forward')

    return frames1

def normalize(frames, frames0):
    for i in range(len(frames)):
        a = np.real(frames0[i])
        b = np.real(frames[i])
        a = np.dot(a, a)
        b = np.dot(b, b)

        if b == 0: continue
        frames[i] = np.real(frames[i]) * np.sqrt(a / b) + 1j * np.imag(frames[i])

    return frames

def lowpass(cepstrum, quefrency):
    cepstrum[1:quefrency] *= 2
    cepstrum[quefrency+1:] = 0

    return cepstrum

def lifter(frames, quefrency):
    envelopes = np.zeros(frames.shape)

    for i, frame in enumerate(frames):
        with np.errstate(divide='ignore', invalid='ignore'):
            spectrum = np.log10(np.real(frame))

        envelopes[i] = np.power(10, np.real(np.fft.rfft(lowpass(np.fft.irfft(spectrum, norm='forward'), quefrency), norm='forward')))

    return envelopes

def resample(x, factor):
    if factor == 1: return x.copy()
    y = np.zeros(x.shape, dtype=x.dtype)

    n = len(x)
    m = int(n * factor)

    i = np.arange(min(n, m))
    k = i * (n / m)

    j = np.trunc(k).astype(int)
    k = k - j

    ok = (0 <= j) & (j < n - 1)
    y[i[ok]] = k[ok] * x[j[ok] + 1] + (1 - k[ok]) * x[j[ok]]

    return y

def shiftpitch(frames, factors, samplerate):
    for i in range(len(frames)):
        magnitudes = np.vstack([resample(np.real(frames[i]), factor) for factor in factors])
        frequencies = np.vstack([resample(np.imag(frames[i]), factor) * factor for factor in factors])

        magnitudes[(frequencies <= 0) | (frequencies >= samplerate / 2)] = 0
        mask = np.argmax(magnitudes, axis=0)

        magnitudes = np.take_along_axis(magnitudes, mask[None,:], axis=0)
        frequencies = np.take_along_axis(frequencies, mask[None,:], axis=0)

        frames[i] = magnitudes + 1j * frequencies

    return frames

def wrap(x):
    return (x + np.pi) % (2 * np.pi) - np.pi

def encode(frames, framesize, hopsize, samplerate):
    M, N = frames.shape
    analysis_framesize = np.ravel(framesize)[0]

    freqinc = samplerate / analysis_framesize
    phaseinc = 2 * np.pi * hopsize / analysis_framesize

    buffer = np.zeros(N)
    data = np.zeros((M, N), complex)

    for m, frame in enumerate(frames):
        arg = np.angle(frame)
        delta = arg - buffer

        buffer = arg

        i = np.arange(N)
        data[m] = np.abs(frame) + 1j * ((i + (wrap(delta - i * phaseinc) / phaseinc)) * freqinc)

    return data

def decode(frames, framesize, hopsize, samplerate):
    M, N = frames.shape
    analysis_framesize = np.ravel(framesize)[0]
    synthesis_framesize = np.ravel(framesize)[-1]

    freqinc = samplerate / analysis_framesize
    phaseinc = 2 * np.pi * hopsize / analysis_framesize
    timeshift = 2 * np.pi * synthesis_framesize * np.arange(N) / N if synthesis_framesize != analysis_framesize else 0

    buffer = np.zeros(N)
    data = np.zeros((M, N), complex)

    for m, frame in enumerate(frames):
        i = np.arange(N)
        delta = (i + ((np.imag(frame) - i * freqinc) / freqinc)) * phaseinc
        buffer += delta
        arg = buffer.copy()
        arg -= timeshift
        data[m] = np.real(frame) * np.exp(1j * arg)

    return data

class StftPitchShift:
    def __init__(self, framesize, hopsize, samplerate):
        self.framesize = framesize
        self.hopsize = hopsize
        self.samplerate = samplerate

    def shiftpitch(self, input, factors = 1, quefrency = 0, distortion = 1, normalization = False):
        input = np.atleast_1d(input)
        dtype = input.dtype
        shape = input.shape

        input = np.squeeze(input)
        if input.ndim != 1: raise ValueError('input.ndim != 1')

        if np.issubdtype(dtype, np.integer):
            a, b = np.iinfo(dtype).min, np.iinfo(dtype).max
            input = ((input.astype(float) - a) / (b - a)) * 2 - 1
        elif not np.issubdtype(dtype, np.floating): raise TypeError('not np.issubdtype(dtype, np.floating)')

        def isnotnormal(x):
            return (np.isinf(x)) | (np.isnan(x)) | (abs(x) < np.finfo(x.dtype).tiny)

        framesize = self.framesize
        hopsize = self.hopsize
        samplerate = self.samplerate

        factors = np.asarray(factors).flatten()
        quefrency = int(quefrency * samplerate)

        frames = encode(stft(input, framesize, hopsize), framesize, hopsize, samplerate)

        if normalization: frames0 = frames.copy()

        if quefrency:
            envelopes = lifter(frames, quefrency)
            mask = isnotnormal(envelopes)

            frames.real /= envelopes
            frames.real[mask] = 0

            if distortion != 1:
                envelopes[mask] = 0

                for i in range(len(envelopes)):
                    envelopes[i] = resample(envelopes[i], distortion)

                mask = isnotnormal(envelopes)

            frames = shiftpitch(frames, factors, samplerate)
            frames.real *= envelopes
            frames.real[mask] = 0
        else: frames = shiftpitch(frames, factors, samplerate)

        if normalization: frames = normalize(frames, frames0)

        output = istft(decode(frames, framesize, hopsize, samplerate), framesize, hopsize)
        output.resize(shape, refcheck=False)

        if np.issubdtype(dtype, np.integer):
            a, b = np.iinfo(dtype).min, np.iinfo(dtype).max
            output = (((output + 1) / 2) * (b - a) + a).clip(a, b).astype(dtype)
        elif output.dtype != dtype: output = output.astype(dtype)

        assert output.dtype == dtype
        assert output.shape == shape

        return output
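A minimal usage sketch for the StftPitchShift class above (not part of the deleted file): it shifts a one-second test tone up by four semitones while liftering the spectral envelope. The import path assumes this repository's layout; the frame size, hop size and quefrency values are illustrative choices, not documented defaults.

import numpy as np
from main.library.algorithm.stftpitchshift import StftPitchShift

samplerate = 16000
t = np.arange(samplerate) / samplerate
x = 0.5 * np.sin(2 * np.pi * 220 * t)  # one second of a 220 Hz test tone

pitchshifter = StftPitchShift(framesize=1024, hopsize=256, samplerate=samplerate)
y = pitchshifter.shiftpitch(
    x,
    factors=2 ** (4 / 12),   # +4 semitones as a frequency ratio
    quefrency=1e-3,          # ~1 ms cepstral lifter to keep the spectral envelope
    normalization=True,
)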
main/library/algorithm/synthesizers.py DELETED
@@ -1,490 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import math
|
| 4 |
-
import torch
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
|
| 8 |
-
from torch.nn.utils import remove_weight_norm
|
| 9 |
-
from torch.utils.checkpoint import checkpoint
|
| 10 |
-
from torch.nn.utils.parametrizations import weight_norm
|
| 11 |
-
|
| 12 |
-
sys.path.append(os.getcwd())
|
| 13 |
-
|
| 14 |
-
from .modules import WaveNet
|
| 15 |
-
from .refinegan import RefineGANGenerator
|
| 16 |
-
from .mrf_hifigan import HiFiGANMRFGenerator
|
| 17 |
-
from .residuals import ResidualCouplingBlock, ResBlock, LRELU_SLOPE
|
| 18 |
-
from .commons import init_weights, slice_segments, rand_slice_segments, sequence_mask, convert_pad_shape
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class Generator(torch.nn.Module):
|
| 22 |
-
def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
|
| 23 |
-
super(Generator, self).__init__()
|
| 24 |
-
self.num_kernels = len(resblock_kernel_sizes)
|
| 25 |
-
self.num_upsamples = len(upsample_rates)
|
| 26 |
-
self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
| 27 |
-
self.ups_and_resblocks = torch.nn.ModuleList()
|
| 28 |
-
|
| 29 |
-
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
| 30 |
-
self.ups_and_resblocks.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
|
| 31 |
-
ch = upsample_initial_channel // (2 ** (i + 1))
|
| 32 |
-
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
| 33 |
-
self.ups_and_resblocks.append(ResBlock(ch, k, d))
|
| 34 |
-
|
| 35 |
-
self.conv_post = torch.nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
| 36 |
-
self.ups_and_resblocks.apply(init_weights)
|
| 37 |
-
if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
| 38 |
-
|
| 39 |
-
def forward(self, x, g = None):
|
| 40 |
-
x = self.conv_pre(x)
|
| 41 |
-
if g is not None: x = x + self.cond(g)
|
| 42 |
-
|
| 43 |
-
resblock_idx = 0
|
| 44 |
-
|
| 45 |
-
for _ in range(self.num_upsamples):
|
| 46 |
-
x = self.ups_and_resblocks[resblock_idx](F.leaky_relu(x, LRELU_SLOPE))
|
| 47 |
-
resblock_idx += 1
|
| 48 |
-
xs = 0
|
| 49 |
-
|
| 50 |
-
for _ in range(self.num_kernels):
|
| 51 |
-
xs += self.ups_and_resblocks[resblock_idx](x)
|
| 52 |
-
resblock_idx += 1
|
| 53 |
-
|
| 54 |
-
x = xs / self.num_kernels
|
| 55 |
-
|
| 56 |
-
return torch.tanh(self.conv_post(F.leaky_relu(x)))
|
| 57 |
-
|
| 58 |
-
def __prepare_scriptable__(self):
|
| 59 |
-
for l in self.ups_and_resblocks:
|
| 60 |
-
for hook in l._forward_pre_hooks.values():
|
| 61 |
-
if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(l)
|
| 62 |
-
|
| 63 |
-
return self
|
| 64 |
-
|
| 65 |
-
def remove_weight_norm(self):
|
| 66 |
-
for l in self.ups_and_resblocks:
|
| 67 |
-
remove_weight_norm(l)
|
| 68 |
-
|
| 69 |
-
class SineGen(torch.nn.Module):
|
| 70 |
-
def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0, flag_for_pulse=False):
|
| 71 |
-
super(SineGen, self).__init__()
|
| 72 |
-
self.sine_amp = sine_amp
|
| 73 |
-
self.noise_std = noise_std
|
| 74 |
-
self.harmonic_num = harmonic_num
|
| 75 |
-
self.dim = self.harmonic_num + 1
|
| 76 |
-
self.sampling_rate = samp_rate
|
| 77 |
-
self.voiced_threshold = voiced_threshold
|
| 78 |
-
|
| 79 |
-
def _f02uv(self, f0):
|
| 80 |
-
return torch.ones_like(f0) * (f0 > self.voiced_threshold)
|
| 81 |
-
|
| 82 |
-
def _f02sine(self, f0, upp):
|
| 83 |
-
rad = f0 / self.sampling_rate * torch.arange(1, upp + 1, dtype=f0.dtype, device=f0.device)
|
| 84 |
-
rad += F.pad((torch.fmod(rad[:, :-1, -1:].float() + 0.5, 1.0) - 0.5).cumsum(dim=1).fmod(1.0).to(f0), (0, 0, 1, 0), mode='constant')
|
| 85 |
-
rad = rad.reshape(f0.shape[0], -1, 1)
|
| 86 |
-
rad *= torch.arange(1, self.dim + 1, dtype=f0.dtype, device=f0.device).reshape(1, 1, -1)
|
| 87 |
-
rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
|
| 88 |
-
rand_ini[..., 0] = 0
|
| 89 |
-
rad += rand_ini
|
| 90 |
-
|
| 91 |
-
return torch.sin(2 * np.pi * rad)
|
| 92 |
-
|
| 93 |
-
def forward(self, f0, upp):
|
| 94 |
-
with torch.no_grad():
|
| 95 |
-
f0 = f0.unsqueeze(-1)
|
| 96 |
-
sine_waves = self._f02sine(f0, upp) * self.sine_amp
|
| 97 |
-
uv = F.interpolate(self._f02uv(f0).transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
|
| 98 |
-
sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))
|
| 99 |
-
|
| 100 |
-
return sine_waves
|
| 101 |
-
|
| 102 |
-
class SourceModuleHnNSF(torch.nn.Module):
|
| 103 |
-
def __init__(self, sample_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0):
|
| 104 |
-
super(SourceModuleHnNSF, self).__init__()
|
| 105 |
-
self.sine_amp = sine_amp
|
| 106 |
-
self.noise_std = add_noise_std
|
| 107 |
-
self.l_sin_gen = SineGen(sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
|
| 108 |
-
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
| 109 |
-
self.l_tanh = torch.nn.Tanh()
|
| 110 |
-
|
| 111 |
-
def forward(self, x, upsample_factor = 1):
|
| 112 |
-
return self.l_tanh(self.l_linear(self.l_sin_gen(x, upsample_factor).to(dtype=self.l_linear.weight.dtype)))
|
| 113 |
-
|
| 114 |
-
class GeneratorNSF(torch.nn.Module):
|
| 115 |
-
def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, checkpointing = False):
|
| 116 |
-
super(GeneratorNSF, self).__init__()
|
| 117 |
-
self.num_kernels = len(resblock_kernel_sizes)
|
| 118 |
-
self.num_upsamples = len(upsample_rates)
|
| 119 |
-
self.upp = math.prod(upsample_rates)
|
| 120 |
-
self.f0_upsamp = torch.nn.Upsample(scale_factor=self.upp)
|
| 121 |
-
self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0)
|
| 122 |
-
|
| 123 |
-
self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
| 124 |
-
self.checkpointing = checkpointing
|
| 125 |
-
|
| 126 |
-
self.ups = torch.nn.ModuleList()
|
| 127 |
-
self.noise_convs = torch.nn.ModuleList()
|
| 128 |
-
|
| 129 |
-
channels = [upsample_initial_channel // (2 ** (i + 1)) for i in range(self.num_upsamples)]
|
| 130 |
-
stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < self.num_upsamples else 1 for i in range(self.num_upsamples)]
|
| 131 |
-
|
| 132 |
-
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
| 133 |
-
self.ups.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), channels[i], k, u, padding=((k - u) // 2) if u % 2 == 0 else (u // 2 + u % 2), output_padding=u % 2)))
|
| 134 |
-
stride = stride_f0s[i]
|
| 135 |
-
kernel = 1 if stride == 1 else stride * 2 - stride % 2
|
| 136 |
-
self.noise_convs.append(torch.nn.Conv1d(1, channels[i], kernel_size=kernel, stride=stride, padding=0 if stride == 1 else (kernel - stride) // 2))
|
| 137 |
-
|
| 138 |
-
self.resblocks = torch.nn.ModuleList([ResBlock(channels[i], k, d) for i in range(len(self.ups)) for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)])
|
| 139 |
-
self.conv_post = torch.nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False)
|
| 140 |
-
|
| 141 |
-
self.ups.apply(init_weights)
|
| 142 |
-
if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
| 143 |
-
|
| 144 |
-
def forward(self, x, f0, g = None):
|
| 145 |
-
har_source = self.m_source(f0, self.upp).transpose(1, 2)
|
| 146 |
-
x = self.conv_pre(x)
|
| 147 |
-
if g is not None: x += self.cond(g)
|
| 148 |
-
|
| 149 |
-
for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
|
| 150 |
-
x = F.leaky_relu(x, LRELU_SLOPE)
|
| 151 |
-
|
| 152 |
-
if self.training and self.checkpointing:
|
| 153 |
-
x = checkpoint(ups, x, use_reentrant=False) + noise_convs(har_source)
|
| 154 |
-
xs = sum([checkpoint(resblock, x, use_reentrant=False) for j, resblock in enumerate(self.resblocks) if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])
|
| 155 |
-
else:
|
| 156 |
-
x = ups(x) + noise_convs(har_source)
|
| 157 |
-
xs = sum([resblock(x) for j, resblock in enumerate(self.resblocks) if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])
|
| 158 |
-
|
| 159 |
-
x = xs / self.num_kernels
|
| 160 |
-
|
| 161 |
-
return torch.tanh(self.conv_post(F.leaky_relu(x)))
|
| 162 |
-
|
| 163 |
-
def remove_weight_norm(self):
|
| 164 |
-
for l in self.ups:
|
| 165 |
-
remove_weight_norm(l)
|
| 166 |
-
|
| 167 |
-
for l in self.resblocks:
|
| 168 |
-
l.remove_weight_norm()
|
| 169 |
-
|
| 170 |
-
class LayerNorm(torch.nn.Module):
|
| 171 |
-
def __init__(self, channels, eps=1e-5, onnx=False):
|
| 172 |
-
super().__init__()
|
| 173 |
-
self.channels = channels
|
| 174 |
-
self.eps = eps
|
| 175 |
-
self.onnx = onnx
|
| 176 |
-
self.gamma = torch.nn.Parameter(torch.ones(channels))
|
| 177 |
-
self.beta = torch.nn.Parameter(torch.zeros(channels))
|
| 178 |
-
|
| 179 |
-
def forward(self, x):
|
| 180 |
-
x = x.transpose(1, -1)
|
| 181 |
-
return (F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) if self.onnx else F.layer_norm(x, (x.size(-1),), self.gamma, self.beta, self.eps)).transpose(1, -1)
|
| 182 |
-
|
| 183 |
-
class MultiHeadAttention(torch.nn.Module):
|
| 184 |
-
def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False, onnx=False):
|
| 185 |
-
super().__init__()
|
| 186 |
-
assert channels % n_heads == 0
|
| 187 |
-
self.channels = channels
|
| 188 |
-
self.out_channels = out_channels
|
| 189 |
-
self.n_heads = n_heads
|
| 190 |
-
self.p_dropout = p_dropout
|
| 191 |
-
self.window_size = window_size
|
| 192 |
-
self.heads_share = heads_share
|
| 193 |
-
self.block_length = block_length
|
| 194 |
-
self.proximal_bias = proximal_bias
|
| 195 |
-
self.proximal_init = proximal_init
|
| 196 |
-
self.onnx = onnx
|
| 197 |
-
self.attn = None
|
| 198 |
-
self.k_channels = channels // n_heads
|
| 199 |
-
self.conv_q = torch.nn.Conv1d(channels, channels, 1)
|
| 200 |
-
self.conv_k = torch.nn.Conv1d(channels, channels, 1)
|
| 201 |
-
self.conv_v = torch.nn.Conv1d(channels, channels, 1)
|
| 202 |
-
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
|
| 203 |
-
self.drop = torch.nn.Dropout(p_dropout)
|
| 204 |
-
|
| 205 |
-
if window_size is not None:
|
| 206 |
-
n_heads_rel = 1 if heads_share else n_heads
|
| 207 |
-
rel_stddev = self.k_channels**-0.5
|
| 208 |
-
|
| 209 |
-
self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
| 210 |
-
self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
| 211 |
-
|
| 212 |
-
torch.nn.init.xavier_uniform_(self.conv_q.weight)
|
| 213 |
-
torch.nn.init.xavier_uniform_(self.conv_k.weight)
|
| 214 |
-
torch.nn.init.xavier_uniform_(self.conv_v.weight)
|
| 215 |
-
|
| 216 |
-
if proximal_init:
|
| 217 |
-
with torch.no_grad():
|
| 218 |
-
self.conv_k.weight.copy_(self.conv_q.weight)
|
| 219 |
-
self.conv_k.bias.copy_(self.conv_q.bias)
|
| 220 |
-
|
| 221 |
-
def forward(self, x, c, attn_mask=None):
|
| 222 |
-
q, k, v = self.conv_q(x), self.conv_k(c), self.conv_v(c)
|
| 223 |
-
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
| 224 |
-
|
| 225 |
-
return self.conv_o(x)
|
| 226 |
-
|
| 227 |
-
def attention(self, query, key, value, mask=None):
|
| 228 |
-
b, d, t_s, t_t = (*key.size(), query.size(2))
|
| 229 |
-
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
| 230 |
-
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
| 231 |
-
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
| 232 |
-
|
| 233 |
-
if self.window_size is not None:
|
| 234 |
-
assert (t_s == t_t), "(t_s == t_t)"
|
| 235 |
-
scores = scores + self._relative_position_to_absolute_position(self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), self._get_relative_embeddings(self.emb_rel_k, t_s, onnx=self.onnx)), onnx=self.onnx)
|
| 236 |
-
|
| 237 |
-
if self.proximal_bias:
|
| 238 |
-
assert t_s == t_t, "t_s == t_t"
|
| 239 |
-
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
|
| 240 |
-
|
| 241 |
-
if mask is not None:
|
| 242 |
-
scores = scores.masked_fill(mask == 0, -1e4)
|
| 243 |
-
if self.block_length is not None:
|
| 244 |
-
assert (t_s == t_t), "(t_s == t_t)"
|
| 245 |
-
scores = scores.masked_fill((torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)) == 0, -1e4)
|
| 246 |
-
|
| 247 |
-
p_attn = self.drop(F.softmax(scores, dim=-1))
|
| 248 |
-
output = torch.matmul(p_attn, value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3))
|
| 249 |
-
|
| 250 |
-
if self.window_size is not None: output = output + self._matmul_with_relative_values(self._absolute_position_to_relative_position(p_attn, onnx=self.onnx), self._get_relative_embeddings(self.emb_rel_v, t_s, onnx=self.onnx))
|
| 251 |
-
return (output.transpose(2, 3).contiguous().view(b, d, t_t)), p_attn
|
| 252 |
-
|
| 253 |
-
def _matmul_with_relative_values(self, x, y):
|
| 254 |
-
return torch.matmul(x, y.unsqueeze(0))
|
| 255 |
-
|
| 256 |
-
def _matmul_with_relative_keys(self, x, y):
|
| 257 |
-
return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
| 258 |
-
|
| 259 |
-
def _get_relative_embeddings(self, relative_embeddings, length, onnx=False):
|
| 260 |
-
if onnx:
|
| 261 |
-
pad_length = torch.clamp(length - (self.window_size + 1), min=0)
|
| 262 |
-
slice_start_position = torch.clamp((self.window_size + 1) - length, min=0)
|
| 263 |
-
|
| 264 |
-
return (F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) if pad_length > 0 else relative_embeddings)[:, slice_start_position:(slice_start_position + 2 * length - 1)]
|
| 265 |
-
else:
|
| 266 |
-
pad_length = max(length - (self.window_size + 1), 0)
|
| 267 |
-
slice_start_position = max((self.window_size + 1) - length, 0)
|
| 268 |
-
|
| 269 |
-
return (F.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) if pad_length > 0 else relative_embeddings)[:, slice_start_position:(slice_start_position + 2 * length - 1)]
|
| 270 |
-
|
| 271 |
-
def _relative_position_to_absolute_position(self, x, onnx=False):
|
| 272 |
-
batch, heads, length, _ = x.size()
|
| 273 |
-
|
| 274 |
-
return (F.pad(F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0]).view([batch, heads, length * 2 * length]), [0, length - 1, 0, 0, 0, 0]).view([batch, heads, length + 1, 2 * length - 1]) if onnx else F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])).view([batch, heads, length * 2 * length]), convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length + 1, 2 * length - 1]))[:, :, :length, length - 1 :]
|
| 275 |
-
|
| 276 |
-
def _absolute_position_to_relative_position(self, x, onnx=False):
|
| 277 |
-
batch, heads, length, _ = x.size()
|
| 278 |
-
|
| 279 |
-
return (F.pad(F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]).view([batch, heads, length*length + length * (length - 1)]), [length, 0, 0, 0, 0, 0]).view([batch, heads, length, 2 * length]) if onnx else F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length**2 + length * (length - 1)]), convert_pad_shape([[0, 0], [0, 0], [length, 0]])).view([batch, heads, length, 2 * length]))[:, :, :, 1:]
|
| 280 |
-
|
| 281 |
-
def _attention_bias_proximal(self, length):
|
| 282 |
-
r = torch.arange(length, dtype=torch.float32)
|
| 283 |
-
|
| 284 |
-
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs((torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)))), 0), 0)
|
| 285 |
-
|
| 286 |
-
class FFN(torch.nn.Module):
|
| 287 |
-
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False, onnx=False):
|
| 288 |
-
super().__init__()
|
| 289 |
-
self.in_channels = in_channels
|
| 290 |
-
self.out_channels = out_channels
|
| 291 |
-
self.filter_channels = filter_channels
|
| 292 |
-
self.kernel_size = kernel_size
|
| 293 |
-
self.p_dropout = p_dropout
|
| 294 |
-
self.activation = activation
|
| 295 |
-
self.causal = causal
|
| 296 |
-
self.onnx = onnx
|
| 297 |
-
self.padding = self._causal_padding if causal else self._same_padding
|
| 298 |
-
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size)
|
| 299 |
-
self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size)
|
| 300 |
-
self.drop = torch.nn.Dropout(p_dropout)
|
| 301 |
-
|
| 302 |
-
def forward(self, x, x_mask):
|
| 303 |
-
x = self.conv_1(self.padding(x * x_mask))
|
| 304 |
-
|
| 305 |
-
return self.conv_2(self.padding(self.drop(((x * torch.sigmoid(1.702 * x)) if self.activation == "gelu" else torch.relu(x))) * x_mask)) * x_mask
|
| 306 |
-
|
| 307 |
-
def _causal_padding(self, x):
|
| 308 |
-
if self.kernel_size == 1: return x
|
| 309 |
-
|
| 310 |
-
return F.pad(x, [self.kernel_size - 1, 0, 0, 0, 0, 0]) if self.onnx else F.pad(x, convert_pad_shape([[0, 0], [0, 0], [(self.kernel_size - 1), 0]]))
|
| 311 |
-
|
| 312 |
-
def _same_padding(self, x):
|
| 313 |
-
if self.kernel_size == 1: return x
|
| 314 |
-
|
| 315 |
-
return F.pad(x, [(self.kernel_size - 1) // 2, self.kernel_size // 2, 0, 0, 0, 0]) if self.onnx else F.pad(x, convert_pad_shape([[0, 0], [0, 0], [((self.kernel_size - 1) // 2), (self.kernel_size // 2)]]))
|
| 316 |
-
|
| 317 |
-
class Encoder(torch.nn.Module):
|
| 318 |
-
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, onnx=False, **kwargs):
|
| 319 |
-
super().__init__()
|
| 320 |
-
self.hidden_channels = hidden_channels
|
| 321 |
-
self.filter_channels = filter_channels
|
| 322 |
-
self.n_heads = n_heads
|
| 323 |
-
self.n_layers = n_layers
|
| 324 |
-
self.kernel_size = kernel_size
|
| 325 |
-
self.p_dropout = p_dropout
|
| 326 |
-
self.window_size = window_size
|
| 327 |
-
self.drop = torch.nn.Dropout(p_dropout)
|
| 328 |
-
self.attn_layers = torch.nn.ModuleList()
|
| 329 |
-
self.norm_layers_1 = torch.nn.ModuleList()
|
| 330 |
-
self.ffn_layers = torch.nn.ModuleList()
|
| 331 |
-
self.norm_layers_2 = torch.nn.ModuleList()
|
| 332 |
-
|
| 333 |
-
for _ in range(self.n_layers):
|
| 334 |
-
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size, onnx=onnx))
|
| 335 |
-
self.norm_layers_1.append(LayerNorm(hidden_channels, onnx=onnx))
|
| 336 |
-
|
| 337 |
-
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, onnx=onnx))
|
| 338 |
-
self.norm_layers_2.append(LayerNorm(hidden_channels, onnx=onnx))
|
| 339 |
-
|
| 340 |
-
def forward(self, x, x_mask):
|
| 341 |
-
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
| 342 |
-
x = x * x_mask
|
| 343 |
-
|
| 344 |
-
for i in range(self.n_layers):
|
| 345 |
-
x = self.norm_layers_1[i](x + self.drop(self.attn_layers[i](x, x, attn_mask)))
|
| 346 |
-
x = self.norm_layers_2[i](x + self.drop(self.ffn_layers[i](x, x_mask)))
|
| 347 |
-
|
| 348 |
-
return x * x_mask
|
| 349 |
-
|
| 350 |
-
class TextEncoder(torch.nn.Module):
|
| 351 |
-
def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, embedding_dim, f0=True, onnx=False):
|
| 352 |
-
super(TextEncoder, self).__init__()
|
| 353 |
-
self.out_channels = out_channels
|
| 354 |
-
self.hidden_channels = hidden_channels
|
| 355 |
-
self.filter_channels = filter_channels
|
| 356 |
-
self.n_heads = n_heads
|
| 357 |
-
self.n_layers = n_layers
|
| 358 |
-
self.kernel_size = kernel_size
|
| 359 |
-
self.p_dropout = float(p_dropout)
|
| 360 |
-
self.emb_phone = torch.nn.Linear(embedding_dim, hidden_channels)
|
| 361 |
-
self.lrelu = torch.nn.LeakyReLU(0.1, inplace=True)
|
| 362 |
-
if f0: self.emb_pitch = torch.nn.Embedding(256, hidden_channels)
|
| 363 |
-
self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), onnx=onnx)
|
| 364 |
-
self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
| 365 |
-
|
| 366 |
-
def forward(self, phone, pitch, lengths):
|
| 367 |
-
x = torch.transpose(self.lrelu(((self.emb_phone(phone) if pitch is None else (self.emb_phone(phone) + self.emb_pitch(pitch))) * math.sqrt(self.hidden_channels))), 1, -1)
|
| 368 |
-
x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
|
| 369 |
-
m, logs = torch.split((self.proj(self.encoder(x * x_mask, x_mask)) * x_mask), self.out_channels, dim=1)
|
| 370 |
-
|
| 371 |
-
return m, logs, x_mask
|
| 372 |
-
|
| 373 |
-
class PosteriorEncoder(torch.nn.Module):
|
| 374 |
-
def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0):
|
| 375 |
-
super(PosteriorEncoder, self).__init__()
|
| 376 |
-
self.in_channels = in_channels
|
| 377 |
-
self.out_channels = out_channels
|
| 378 |
-
self.hidden_channels = hidden_channels
|
| 379 |
-
self.kernel_size = kernel_size
|
| 380 |
-
self.dilation_rate = dilation_rate
|
| 381 |
-
self.n_layers = n_layers
|
| 382 |
-
self.gin_channels = gin_channels
|
| 383 |
-
self.pre = torch.nn.Conv1d(in_channels, hidden_channels, 1)
|
| 384 |
-
self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
|
| 385 |
-
self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
| 386 |
-
|
| 387 |
-
def forward(self, x, x_lengths, g = None):
|
| 388 |
-
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
| 389 |
-
m, logs = torch.split((self.proj(self.enc((self.pre(x) * x_mask), x_mask, g=g)) * x_mask), self.out_channels, dim=1)
|
| 390 |
-
|
| 391 |
-
return ((m + torch.randn_like(m) * torch.exp(logs)) * x_mask), m, logs, x_mask
|
| 392 |
-
|
| 393 |
-
def remove_weight_norm(self):
|
| 394 |
-
self.enc.remove_weight_norm()
|
| 395 |
-
|
| 396 |
-
class Synthesizer(torch.nn.Module):
|
| 397 |
-
def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim=768, vocoder="Default", checkpointing=False, onnx=False, **kwargs):
|
| 398 |
-
super(Synthesizer, self).__init__()
|
| 399 |
-
self.spec_channels = spec_channels
|
| 400 |
-
self.inter_channels = inter_channels
|
| 401 |
-
self.hidden_channels = hidden_channels
|
| 402 |
-
self.filter_channels = filter_channels
|
| 403 |
-
self.n_heads = n_heads
|
| 404 |
-
self.n_layers = n_layers
|
| 405 |
-
self.kernel_size = kernel_size
|
| 406 |
-
self.p_dropout = float(p_dropout)
|
| 407 |
-
self.resblock_kernel_sizes = resblock_kernel_sizes
|
| 408 |
-
self.resblock_dilation_sizes = resblock_dilation_sizes
|
| 409 |
-
self.upsample_rates = upsample_rates
|
| 410 |
-
self.upsample_initial_channel = upsample_initial_channel
|
| 411 |
-
self.upsample_kernel_sizes = upsample_kernel_sizes
|
| 412 |
-
self.segment_size = segment_size
|
| 413 |
-
self.gin_channels = gin_channels
|
| 414 |
-
self.spk_embed_dim = spk_embed_dim
|
| 415 |
-
self.use_f0 = use_f0
|
| 416 |
-
self.enc_p = TextEncoder(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), text_enc_hidden_dim, f0=use_f0, onnx=onnx)
|
| 417 |
-
|
| 418 |
-
if use_f0:
|
| 419 |
-
if vocoder == "RefineGAN": self.dec = RefineGANGenerator(sample_rate=sr, upsample_rates=upsample_rates, num_mels=inter_channels, checkpointing=checkpointing)
|
| 420 |
-
elif vocoder in ["MRF-HiFi-GAN", "MRF HiFi-GAN"]: self.dec = HiFiGANMRFGenerator(in_channel=inter_channels, upsample_initial_channel=upsample_initial_channel, upsample_rates=upsample_rates, upsample_kernel_sizes=upsample_kernel_sizes, resblock_kernel_sizes=resblock_kernel_sizes, resblock_dilations=resblock_dilation_sizes, gin_channels=gin_channels, sample_rate=sr, harmonic_num=8, checkpointing=checkpointing)
|
| 421 |
-
else: self.dec = GeneratorNSF(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, checkpointing=checkpointing)
|
| 422 |
-
else: self.dec = Generator(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
|
| 423 |
-
|
| 424 |
-
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
|
| 425 |
-
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
|
| 426 |
-
self.emb_g = torch.nn.Embedding(self.spk_embed_dim, gin_channels)
|
| 427 |
-
|
| 428 |
-
def remove_weight_norm(self):
|
| 429 |
-
self.dec.remove_weight_norm()
|
| 430 |
-
self.flow.remove_weight_norm()
|
| 431 |
-
self.enc_q.remove_weight_norm()
|
| 432 |
-
|
| 433 |
-
@torch.jit.ignore
|
| 434 |
-
def forward(self, phone, phone_lengths, pitch = None, pitchf = None, y = None, y_lengths = None, ds = None):
|
| 435 |
-
g = self.emb_g(ds).unsqueeze(-1)
|
| 436 |
-
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
| 437 |
-
|
| 438 |
-
if y is not None:
|
| 439 |
-
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
| 440 |
-
z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
|
| 441 |
-
|
| 442 |
-
return (self.dec(z_slice, slice_segments(pitchf, ids_slice, self.segment_size, 2), g=g) if self.use_f0 else self.dec(z_slice, g=g)), ids_slice, x_mask, y_mask, (z, self.flow(z, y_mask, g=g), m_p, logs_p, m_q, logs_q)
|
| 443 |
-
else: return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)
|
| 444 |
-
|
| 445 |
-
@torch.jit.export
|
| 446 |
-
def infer(self, phone, phone_lengths, pitch = None, nsff0 = None, sid = None, rate = None):
|
| 447 |
-
g = self.emb_g(sid).unsqueeze(-1)
|
| 448 |
-
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
| 449 |
-
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
| 450 |
-
|
| 451 |
-
if rate is not None:
|
| 452 |
-
assert isinstance(rate, torch.Tensor)
|
| 453 |
-
head = int(z_p.shape[2] * (1.0 - rate.item()))
|
| 454 |
-
z_p = z_p[:, :, head:]
|
| 455 |
-
x_mask = x_mask[:, :, head:]
|
| 456 |
-
if self.use_f0: nsff0 = nsff0[:, head:]
|
| 457 |
-
|
| 458 |
-
if self.use_f0:
|
| 459 |
-
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
| 460 |
-
o = self.dec(z * x_mask, nsff0, g=g)
|
| 461 |
-
else:
|
| 462 |
-
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
| 463 |
-
o = self.dec(z * x_mask, g=g)
|
| 464 |
-
|
| 465 |
-
return o, x_mask, (z, z_p, m_p, logs_p)
|
| 466 |
-
|
| 467 |
-
class SynthesizerONNX(Synthesizer):
|
| 468 |
-
def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim=768, vocoder="Default", checkpointing=False, **kwargs):
|
| 469 |
-
super().__init__(spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim, vocoder, checkpointing, True)
|
| 470 |
-
self.speaker_map = None
|
| 471 |
-
|
| 472 |
-
def remove_weight_norm(self):
|
| 473 |
-
self.dec.remove_weight_norm()
|
| 474 |
-
self.flow.remove_weight_norm()
|
| 475 |
-
self.enc_q.remove_weight_norm()
|
| 476 |
-
|
| 477 |
-
def construct_spkmixmap(self, n_speaker):
|
| 478 |
-
self.speaker_map = torch.zeros((n_speaker, 1, 1, self.gin_channels))
|
| 479 |
-
|
| 480 |
-
for i in range(n_speaker):
|
| 481 |
-
self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]))
|
| 482 |
-
|
| 483 |
-
self.speaker_map = self.speaker_map.unsqueeze(0)
|
| 484 |
-
|
| 485 |
-
def forward(self, phone, phone_lengths, g=None, rnd=None, pitch=None, nsff0=None, max_len=None):
|
| 486 |
-
g = self.emb_g(g).unsqueeze(-1)
|
| 487 |
-
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
| 488 |
-
z_p = (m_p + torch.exp(logs_p) * rnd) * x_mask
|
| 489 |
-
|
| 490 |
-
return self.dec((self.flow(z_p, x_mask, g=g, reverse=True) * x_mask)[:, :, :max_len], nsff0, g=g) if self.use_f0 else self.dec((self.flow(z_p, x_mask, g=g, reverse=True) * x_mask)[:, :, :max_len], g=g)
|
main/library/architectures/demucs_separator.py DELETED
@@ -1,180 +0,0 @@
import os
import sys
import yaml
import torch

import numpy as np
from hashlib import sha256

sys.path.append(os.getcwd())

from main.configs.config import Config
from main.library.uvr5_separator import spec_utils, common_separator
from main.library.uvr5_separator.demucs import hdemucs, states, apply

translations = Config().translations
sys.path.insert(0, os.path.join(os.getcwd(), "main", "library", "uvr5_separator"))

DEMUCS_4_SOURCE_MAPPER = {common_separator.CommonSeparator.BASS_STEM: 0, common_separator.CommonSeparator.DRUM_STEM: 1, common_separator.CommonSeparator.OTHER_STEM: 2, common_separator.CommonSeparator.VOCAL_STEM: 3}


class DemucsSeparator(common_separator.CommonSeparator):
    def __init__(self, common_config, arch_config):
        super().__init__(config=common_config)
        self.segment_size = arch_config.get("segment_size", "Default")
        self.shifts = arch_config.get("shifts", 2)
        self.overlap = arch_config.get("overlap", 0.25)
        self.segments_enabled = arch_config.get("segments_enabled", True)
        self.logger.debug(translations["demucs_info"].format(segment_size=self.segment_size, segments_enabled=self.segments_enabled))
        self.logger.debug(translations["demucs_info_2"].format(shifts=self.shifts, overlap=self.overlap))
        self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
        self.audio_file_path = None
        self.audio_file_base = None
        self.demucs_model_instance = None
        self.logger.info(translations["start_demucs"])

    def separate(self, audio_file_path):
        self.logger.debug(translations["start_separator"])
        source = None
        inst_source = {}
        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
        self.logger.debug(translations["prepare_mix"])
        mix = self.prepare_mix(self.audio_file_path)
        self.logger.debug(translations["demix"].format(shape=mix.shape))
        self.logger.debug(translations["cancel_mix"])
        self.demucs_model_instance = hdemucs.HDemucs(sources=["drums", "bass", "other", "vocals"])
        self.demucs_model_instance = get_demucs_model(name=os.path.splitext(os.path.basename(self.model_path))[0], repo=os.path.dirname(self.model_path))
        self.demucs_model_instance = apply.demucs_segments(self.segment_size, self.demucs_model_instance)
        self.demucs_model_instance.to(self.torch_device)
        self.demucs_model_instance.eval()
        self.logger.debug(translations["model_review"])
        source = self.demix_demucs(mix)
        del self.demucs_model_instance
        self.clear_gpu_cache()
        self.logger.debug(translations["del_gpu_cache_after_demix"])
        output_files = []
        self.logger.debug(translations["process_output_file"])

        if isinstance(inst_source, np.ndarray):
            self.logger.debug(translations["process_ver"])
            inst_source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]] = spec_utils.reshape_sources(inst_source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]], source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]])
            source = inst_source

        if isinstance(source, np.ndarray):
            source_length = len(source)
            self.logger.debug(translations["source_length"].format(source_length=source_length))
            self.logger.debug(translations["set_map"].format(part=source_length))

            match source_length:
                case 2: self.demucs_source_map = {common_separator.CommonSeparator.INST_STEM: 0, common_separator.CommonSeparator.VOCAL_STEM: 1}
                case 6: self.demucs_source_map = {common_separator.CommonSeparator.BASS_STEM: 0, common_separator.CommonSeparator.DRUM_STEM: 1, common_separator.CommonSeparator.OTHER_STEM: 2, common_separator.CommonSeparator.VOCAL_STEM: 3, common_separator.CommonSeparator.GUITAR_STEM: 4, common_separator.CommonSeparator.PIANO_STEM: 5}
                case _: self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER

            self.logger.debug(translations["process_all_part"])

        for stem_name, stem_value in self.demucs_source_map.items():
            if self.output_single_stem is not None:
                if stem_name.lower() != self.output_single_stem.lower():
                    self.logger.debug(translations["skip_part"].format(stem_name=stem_name, output_single_stem=self.output_single_stem))
                    continue

            stem_path = os.path.join(f"{self.audio_file_base}_({stem_name})_{self.model_name}.{self.output_format.lower()}")
            self.final_process(stem_path, source[stem_value].T, stem_name)
            output_files.append(stem_path)

        return output_files

    def demix_demucs(self, mix):
        self.logger.debug(translations["starting_demix_demucs"])
        processed = {}
        mix = torch.tensor(mix, dtype=torch.float32)
        ref = mix.mean(0)
        mix = (mix - ref.mean()) / ref.std()
        mix_infer = mix

        with torch.no_grad():
            self.logger.debug(translations["model_infer"])
            sources = apply.apply_model(model=self.demucs_model_instance, mix=mix_infer[None], shifts=self.shifts, split=self.segments_enabled, overlap=self.overlap, static_shifts=1 if self.shifts == 0 else self.shifts, set_progress_bar=None, device=self.torch_device, progress=True)[0]

        sources = (sources * ref.std() + ref.mean()).cpu().numpy()
        sources[[0, 1]] = sources[[1, 0]]

        processed[mix] = sources[:, :, 0:None].copy()
        return np.concatenate([s[:, :, 0:None] for s in list(processed.values())], axis=-1)

class LocalRepo:
    def __init__(self, root):
        self.root = root
        self.scan()

    def scan(self):
        self._models, self._checksums = {}, {}
        for filename in os.listdir(self.root):
            filepath = os.path.join(self.root, filename)
            if not os.path.isfile(filepath): continue

            if os.path.splitext(filename)[1] == ".th":
                stem = os.path.splitext(filename)[0]

                if "-" in stem:
                    xp_sig, checksum = stem.split("-", 1)
                    self._checksums[xp_sig] = checksum
                else: xp_sig = stem

                if xp_sig in self._models: raise RuntimeError(translations["del_all_but_one"].format(xp_sig=xp_sig))
                self._models[xp_sig] = filepath

    def has_model(self, sig):
        return sig in self._models

    def get_model(self, sig):
        try:
            file = self._models[sig]
        except KeyError:
            raise RuntimeError(translations["not_found_model_signature"].format(sig=sig))

        if sig in self._checksums: check_checksum(file, self._checksums[sig])
        return states.load_model(file)

class BagOnlyRepo:
    def __init__(self, root, model_repo):
        self.root = root
        self.model_repo = model_repo
        self.scan()

    def scan(self):
        self._bags = {}
        for filename in os.listdir(self.root):
            filepath = os.path.join(self.root, filename)

            if os.path.isfile(filepath) and os.path.splitext(filename)[1] == ".yaml":
                stem = os.path.splitext(filename)[0]
                self._bags[stem] = filepath

    def get_model(self, name):
        try:
            yaml_file = self._bags[name]
        except KeyError:
            raise RuntimeError(translations["name_not_pretrained"].format(name=name))

        with open(yaml_file, 'r') as f:
            bag = yaml.safe_load(f)

        return apply.BagOfModels([self.model_repo.get_model(sig) for sig in bag["models"]], bag.get("weights"), bag.get("segment"))

def check_checksum(path, checksum):
    sha = sha256()

    with open(path, "rb") as file:
        while 1:
            buf = file.read(2**20)
            if not buf: break
            sha.update(buf)

    actual_checksum = sha.hexdigest()[:len(checksum)]
    if actual_checksum != checksum: raise RuntimeError(translations["invalid_checksum"].format(path=path, checksum=checksum, actual_checksum=actual_checksum))

def get_demucs_model(name, repo = None):
    model_repo = LocalRepo(repo)
    return (model_repo.get_model(name) if model_repo.has_model(name) else BagOnlyRepo(repo, model_repo).get_model(name)).eval()
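A short sketch of how the checkpoint-resolution helpers above fit together (not part of the deleted file): LocalRepo picks up "<signature>-<checksum>.th" weights, BagOnlyRepo picks up "<name>.yaml" bag definitions, and get_demucs_model tries the former before falling back to the latter. The model name and directory below are placeholders.

from main.library.architectures.demucs_separator import get_demucs_model

# Assumes the folder contains either a htdemucs.yaml bag or a matching .th checkpoint.
model = get_demucs_model(name="htdemucs", repo="assets/models/demucs")
print(type(model))  # a BagOfModels or a single Demucs model, already in eval() mode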
main/library/architectures/fairseq.py
DELETED
|
@@ -1,1480 +0,0 @@
-import re
-import sys
-import math
-import uuid
-import torch
-import types
-import contextlib
-
-import numpy as np
-import torch.nn.functional as F
-
-from torch import nn
-from omegaconf import DictConfig, open_dict
-
-class Dictionary:
-    def __init__(self, *args, **kwargs):
-        pass
-
-fairseq = types.ModuleType("fairseq")
-fairseq_data = types.ModuleType("fairseq.data")
-fairseq_data_dictionary = types.ModuleType("fairseq.data.dictionary")
-fairseq_data_dictionary.Dictionary = Dictionary
-fairseq.data = fairseq_data
-fairseq_data.dictionary = fairseq_data_dictionary
-
-sys.modules["fairseq"] = fairseq
-sys.modules["fairseq.data"] = fairseq_data
-sys.modules["fairseq.data.dictionary"] = fairseq_data_dictionary
-
-def load_model(filename):
-    state = torch.load(filename, map_location="cpu")
-
-    model = HubertModel(HubertConfig(**state['cfg']['model']))
-    model.load_state_dict(state['model'], strict=False)
-
-    return [model], Model_Config(state["cfg"]), Model_Config(state["cfg"]["task"])
-
-def softmax(x, dim, onnx_trace = False):
-    return F.softmax(x.float(), dim=dim) if onnx_trace else F.softmax(x, dim=dim, dtype=torch.float32)
-
-def log_softmax(x, dim, onnx_trace = False):
-    return F.log_softmax(x.float(), dim=dim) if onnx_trace else F.log_softmax(x, dim=dim, dtype=torch.float32)
-
-def eval_str_dict(x, type=dict):
-    if x is None: return None
-    if isinstance(x, str): x = eval(x)
-    return x
-
-def with_incremental_state(cls):
-    cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
-    return cls
-
-def quant_noise(module, p, block_size):
-    if p <= 0: return module
-    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
-
-    is_conv = module.weight.ndim == 4
-    if not is_conv: assert (module.weight.size(1) % block_size == 0)
-    else:
-        if module.kernel_size == (1, 1): assert (module.in_channels % block_size == 0)
-        else:
-            k = module.kernel_size[0] * module.kernel_size[1]
-            assert k % block_size == 0
-
-    def _forward_pre_hook(mod, input):
-        if mod.training:
-            if not is_conv:
-                weight = mod.weight
-                in_features = weight.size(1)
-                out_features = weight.size(0)
-
-                mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
-                mask.bernoulli_(p)
-                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
-            else:
-                weight = mod.weight
-                in_channels = mod.in_channels
-                out_channels = mod.out_channels
-
-                if mod.kernel_size == (1, 1):
-                    mask = torch.zeros(int(in_channels // block_size * out_channels), device=weight.device)
-                    mask.bernoulli_(p)
-                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
-                else:
-                    mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
-                    mask.bernoulli_(p)
-                    mask = (mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]))
-
-            mask = mask.to(torch.bool)
-            s = 1 / (1 - p)
-            mod.weight.data = s * weight.masked_fill(mask, 0)
-
-    module.register_forward_pre_hook(_forward_pre_hook)
-    return module
-
-class FairseqDropout(nn.Module):
-    def __init__(self, p, module_name=None):
-        super().__init__()
-        self.p = p
-        self.module_name = module_name
-        self.apply_during_inference = False
-
-    def forward(self, x, inplace = False):
-        return F.dropout(x, p=self.p, training=True, inplace=inplace) if self.p > 0 and (self.training or self.apply_during_inference) else x
-
-    def make_generation_fast_(self, name, retain_dropout = False, retain_dropout_modules = None, **kwargs):
-        if retain_dropout:
-            if (retain_dropout_modules is None or self.module_name in retain_dropout_modules): self.apply_during_inference = True
-
-class FairseqIncrementalState(object):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.init_incremental_state()
-
-    def init_incremental_state(self):
-        self._incremental_state_id = str(uuid.uuid4())
-
-    def _get_full_incremental_state_key(self, key):
-        return "{}.{}".format(self._incremental_state_id, key)
-
-    def get_incremental_state(self, incremental_state, key):
-        full_key = self._get_full_incremental_state_key(key)
-        if incremental_state is None or full_key not in incremental_state: return None
-        return incremental_state[full_key]
-
-    def set_incremental_state(self, incremental_state, key, value):
-        if incremental_state is not None: incremental_state[self._get_full_incremental_state_key(key)] = value
-        return incremental_state
-
-class FairseqDecoder(nn.Module):
-    def __init__(self, dictionary):
-        super().__init__()
-        self.dictionary = dictionary
-        self.onnx_trace = False
-        self.adaptive_softmax = None
-
-    def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
-        x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
-        return self.output_layer(x), extra
-
-    def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
-        pass
-
-    def output_layer(self, features, **kwargs):
-        pass
-
-    def get_normalized_probs(self, net_output, log_probs, sample = None):
-        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
-
-    def get_normalized_probs_scriptable(self, net_output, log_probs, sample = None):
-        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
-            if sample is not None:
-                assert "target" in sample
-                target = sample["target"]
-            else: target = None
-
-            out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
-            return out.exp_() if not log_probs else out
-
-        logits = net_output[0]
-        return log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) if log_probs else softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
-
-    def max_positions(self):
-        return 1e6
-
-    def upgrade_state_dict_named(self, state_dict, name):
-        return state_dict
-
-    def prepare_for_onnx_export_(self):
-        self.onnx_trace = True
-
-@with_incremental_state
-class FairseqIncrementalDecoder(FairseqDecoder):
-    def __init__(self, dictionary):
-        super().__init__(dictionary)
-
-    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
-        pass
-
-    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
-        pass
-
-    def reorder_incremental_state(self, incremental_state, new_order):
-        pass
-
-    def reorder_incremental_state_scripting(self, incremental_state, new_order):
-        for module in self.modules():
-            if hasattr(module, "reorder_incremental_state"):
-                result = module.reorder_incremental_state(incremental_state, new_order)
-                if result is not None: incremental_state = result
-
-    def set_beam_size(self, beam_size):
-        if getattr(self, "_beam_size", -1) != beam_size:
-            seen = set()
-
-            def apply_set_beam_size(module):
-                if (module != self and hasattr(module, "set_beam_size") and module not in seen):
-                    seen.add(module)
-                    module.set_beam_size(beam_size)
-
-            self.apply(apply_set_beam_size)
-            self._beam_size = beam_size
-
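Note: quant_noise above applies block-wise, Quant-Noise-style weight masking through a forward pre-hook. A minimal sketch of how it could be exercised (not part of the deleted file; the layer sizes are arbitrary, block_size only needs to divide the input dimension):

    import torch
    from torch import nn

    layer = quant_noise(nn.Linear(256, 512), p=0.1, block_size=8)  # registers the pre-hook on the layer
    layer.train()
    out = layer(torch.randn(4, 256))   # roughly 10% of the 8-wide weight blocks are zeroed, the rest rescaled by 1/(1-p)
    print(out.shape)                   # torch.Size([4, 512])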
|
| 204 |
-
class MultiheadAttention(FairseqIncrementalDecoder):
|
| 205 |
-
def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, dictionary=None, q_noise=0.0, qn_block_size=8, xformers_att_config=None, xformers_blocksparse_layout=None, xformers_blocksparse_blocksize=16):
|
| 206 |
-
super().__init__(dictionary)
|
| 207 |
-
xformers_att_config = eval_str_dict(xformers_att_config)
|
| 208 |
-
self.use_xformers = xformers_att_config is not None
|
| 209 |
-
if self.use_xformers: raise ImportError
|
| 210 |
-
self.embed_dim = embed_dim
|
| 211 |
-
self.kdim = kdim if kdim is not None else embed_dim
|
| 212 |
-
self.vdim = vdim if vdim is not None else embed_dim
|
| 213 |
-
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
| 214 |
-
self.num_heads = num_heads
|
| 215 |
-
self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__)
|
| 216 |
-
self.head_dim = embed_dim // num_heads
|
| 217 |
-
assert (self.head_dim * num_heads == self.embed_dim)
|
| 218 |
-
self.scaling = self.head_dim**-0.5
|
| 219 |
-
self.self_attention = self_attention
|
| 220 |
-
self.encoder_decoder_attention = encoder_decoder_attention
|
| 221 |
-
assert not self.self_attention or self.qkv_same_dim
|
| 222 |
-
self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
|
| 223 |
-
self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
|
| 224 |
-
self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
|
| 225 |
-
self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
|
| 226 |
-
if add_bias_kv:
|
| 227 |
-
self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
|
| 228 |
-
self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
|
| 229 |
-
else: self.bias_k = self.bias_v = None
|
| 230 |
-
self.add_zero_attn = add_zero_attn
|
| 231 |
-
self.beam_size = 1
|
| 232 |
-
self.reset_parameters()
|
| 233 |
-
self.onnx_trace = False
|
| 234 |
-
self.skip_embed_dim_check = False
|
| 235 |
-
self.init_incremental_state()
|
| 236 |
-
|
| 237 |
-
def prepare_for_onnx_export_(self):
|
| 238 |
-
self.onnx_trace = True
|
| 239 |
-
|
| 240 |
-
def reset_parameters(self):
|
| 241 |
-
if self.qkv_same_dim:
|
| 242 |
-
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
| 243 |
-
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
| 244 |
-
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
| 245 |
-
else:
|
| 246 |
-
nn.init.xavier_uniform_(self.k_proj.weight)
|
| 247 |
-
nn.init.xavier_uniform_(self.v_proj.weight)
|
| 248 |
-
nn.init.xavier_uniform_(self.q_proj.weight)
|
| 249 |
-
|
| 250 |
-
nn.init.xavier_uniform_(self.out_proj.weight)
|
| 251 |
-
|
| 252 |
-
if self.out_proj.bias is not None: nn.init.constant_(self.out_proj.bias, 0.0)
|
| 253 |
-
if self.bias_k is not None: nn.init.xavier_normal_(self.bias_k)
|
| 254 |
-
if self.bias_v is not None: nn.init.xavier_normal_(self.bias_v)
|
| 255 |
-
|
| 256 |
-
def _get_reserve_head_index(self, num_heads_to_keep: int):
|
| 257 |
-
k_proj_heads_norm, q_proj_heads_norm, v_proj_heads_norm = [], [], []
|
| 258 |
-
for i in range(self.num_heads):
|
| 259 |
-
start_idx = i * self.head_dim
|
| 260 |
-
end_idx = (i + 1) * self.head_dim
|
| 261 |
-
k_proj_heads_norm.append(torch.sum(torch.abs(self.k_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist())
|
| 262 |
-
q_proj_heads_norm.append(torch.sum(torch.abs(self.q_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist())
|
| 263 |
-
v_proj_heads_norm.append(torch.sum(torch.abs(self.v_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist())
|
| 264 |
-
|
| 265 |
-
heads_norm = []
|
| 266 |
-
for i in range(self.num_heads):
|
| 267 |
-
heads_norm.append(k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i])
|
| 268 |
-
|
| 269 |
-
sorted_head_index = sorted(range(self.num_heads), key=lambda k: heads_norm[k], reverse=True)
|
| 270 |
-
reserve_head_index = []
|
| 271 |
-
for i in range(num_heads_to_keep):
|
| 272 |
-
reserve_head_index.append((sorted_head_index[i] * self.head_dim, (sorted_head_index[i] + 1) * self.head_dim))
|
| 273 |
-
return reserve_head_index
|
| 274 |
-
|
| 275 |
-
def _adaptive_prune_heads(self, reserve_head_index):
|
| 276 |
-
new_q_weight, new_q_bias, new_k_weight, new_k_bias, new_v_weight, new_v_bias, new_out_proj_weight = [], [], [], [], [], [], []
|
| 277 |
-
|
| 278 |
-
for ele in reserve_head_index:
|
| 279 |
-
start_idx, end_idx = ele
|
| 280 |
-
new_q_weight.append(self.q_proj.weight[start_idx:end_idx])
|
| 281 |
-
new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
|
| 282 |
-
new_k_weight.append(self.k_proj.weight[start_idx:end_idx])
|
| 283 |
-
new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
|
| 284 |
-
new_v_weight.append(self.v_proj.weight[start_idx:end_idx])
|
| 285 |
-
new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
|
| 286 |
-
new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
|
| 287 |
-
|
| 288 |
-
new_q_weight = torch.cat(new_q_weight).detach()
|
| 289 |
-
new_k_weight = torch.cat(new_k_weight).detach()
|
| 290 |
-
new_v_weight = torch.cat(new_v_weight).detach()
|
| 291 |
-
new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach()
|
| 292 |
-
new_q_weight.requires_grad = True
|
| 293 |
-
new_k_weight.requires_grad = True
|
| 294 |
-
new_v_weight.requires_grad = True
|
| 295 |
-
new_out_proj_weight.requires_grad = True
|
| 296 |
-
new_q_bias = torch.cat(new_q_bias).detach()
|
| 297 |
-
new_q_bias.requires_grad = True
|
| 298 |
-
new_k_bias = torch.cat(new_k_bias).detach()
|
| 299 |
-
new_k_bias.requires_grad = True
|
| 300 |
-
new_v_bias = torch.cat(new_v_bias).detach()
|
| 301 |
-
new_v_bias.requires_grad = True
|
| 302 |
-
|
| 303 |
-
self.q_proj.weight = nn.Parameter(new_q_weight)
|
| 304 |
-
self.q_proj.bias = nn.Parameter(new_q_bias)
|
| 305 |
-
self.k_proj.weight = nn.Parameter(new_k_weight)
|
| 306 |
-
self.k_proj.bias = nn.Parameter(new_k_bias)
|
| 307 |
-
self.v_proj.weight = nn.Parameter(new_v_weight)
|
| 308 |
-
self.v_proj.bias = nn.Parameter(new_v_bias)
|
| 309 |
-
self.out_proj.weight = nn.Parameter(new_out_proj_weight)
|
| 310 |
-
self.num_heads = len(reserve_head_index)
|
| 311 |
-
self.embed_dim = self.head_dim * self.num_heads
|
| 312 |
-
self.q_proj.out_features = self.embed_dim
|
| 313 |
-
self.k_proj.out_features = self.embed_dim
|
| 314 |
-
self.v_proj.out_features = self.embed_dim
|
| 315 |
-
|
| 316 |
-
def _set_skip_embed_dim_check(self):
|
| 317 |
-
self.skip_embed_dim_check = True
|
| 318 |
-
|
| 319 |
-
def _pad_masks(self, key_padding_mask, attn_mask):
|
| 320 |
-
if attn_mask is not None:
|
| 321 |
-
shape = attn_mask.size()[:-1] + torch.Size([1])
|
| 322 |
-
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
|
| 323 |
-
|
| 324 |
-
if key_padding_mask is not None:
|
| 325 |
-
shape = key_padding_mask.size()[:-1] + torch.Size([1])
|
| 326 |
-
key_padding_mask = torch.cat([key_padding_mask, key_padding_mask.new_zeros(shape)], dim=-1)
|
| 327 |
-
|
| 328 |
-
return key_padding_mask, attn_mask
|
| 329 |
-
|
| 330 |
-
def _add_bias(self, k, v, key_padding_mask, attn_mask, bsz):
|
| 331 |
-
assert self.bias_k is not None or self.bias_v is not None
|
| 332 |
-
key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
|
| 333 |
-
return torch.cat([k, self.bias_k.repeat(1, bsz, 1)]), torch.cat([v, self.bias_v.repeat(1, bsz, 1)]), key_padding_mask, attn_mask
|
| 334 |
-
|
| 335 |
-
def _append_zero_attn(self, k, v, key_padding_mask, attn_mask):
|
| 336 |
-
zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
|
| 337 |
-
key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
|
| 338 |
-
return torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2), torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2), key_padding_mask, attn_mask
|
| 339 |
-
|
| 340 |
-
def forward(self, query, key, value, key_padding_mask = None, incremental_state = None, need_weights = True, static_kv = False, attn_mask = None, before_softmax = False, need_head_weights = False):
|
| 341 |
-
if need_head_weights: need_weights = True
|
| 342 |
-
is_tpu = query.device.type == "xla"
|
| 343 |
-
tgt_len, bsz, embed_dim = query.size()
|
| 344 |
-
src_len = tgt_len
|
| 345 |
-
|
| 346 |
-
if not self.skip_embed_dim_check: assert (embed_dim == self.embed_dim)
|
| 347 |
-
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
| 348 |
-
|
| 349 |
-
if key is not None:
|
| 350 |
-
src_len, key_bsz, _ = key.size()
|
| 351 |
-
if not torch.jit.is_scripting():
|
| 352 |
-
assert value is not None
|
| 353 |
-
assert src_len, key_bsz == value.shape[:2]
|
| 354 |
-
|
| 355 |
-
if (not self.onnx_trace and not is_tpu and incremental_state is None and not static_kv and not torch.jit.is_scripting() and not self.skip_embed_dim_check):
|
| 356 |
-
assert key is not None and value is not None
|
| 357 |
-
return F.multi_head_attention_forward(query, key, value, self.embed_dim, self.num_heads, torch.empty([0]), torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), self.bias_k, self.bias_v, self.add_zero_attn, self.dropout_module.p, self.out_proj.weight, self.out_proj.bias, self.training or self.dropout_module.apply_during_inference, key_padding_mask.bool() if key_padding_mask is not None else None, need_weights, attn_mask, use_separate_proj_weight=True, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight)
|
| 358 |
-
|
| 359 |
-
if incremental_state is not None:
|
| 360 |
-
saved_state = self._get_input_buffer(incremental_state)
|
| 361 |
-
if saved_state is not None and "prev_key" in saved_state:
|
| 362 |
-
if static_kv:
|
| 363 |
-
assert self.encoder_decoder_attention and not self.self_attention
|
| 364 |
-
key = value = None
|
| 365 |
-
else: saved_state = None
|
| 366 |
-
|
| 367 |
-
if self.self_attention:
|
| 368 |
-
q = self.q_proj(query)
|
| 369 |
-
k = self.k_proj(query)
|
| 370 |
-
v = self.v_proj(query)
|
| 371 |
-
elif self.encoder_decoder_attention:
|
| 372 |
-
q = self.q_proj(query)
|
| 373 |
-
if key is None:
|
| 374 |
-
assert value is None
|
| 375 |
-
k = v = None
|
| 376 |
-
else:
|
| 377 |
-
if self.beam_size > 1 and bsz == key.size(1):
|
| 378 |
-
key = key.view(key.size(0), -1, self.beam_size, key.size(2))[:, :, 0, :]
|
| 379 |
-
if key_padding_mask is not None: key_padding_mask = key_padding_mask.view(-1, self.beam_size, key_padding_mask.size(1))[:, 0, :]
|
| 380 |
-
k = self.k_proj(key)
|
| 381 |
-
v = self.v_proj(key)
|
| 382 |
-
else:
|
| 383 |
-
assert key is not None and value is not None
|
| 384 |
-
q = self.q_proj(query)
|
| 385 |
-
k = self.k_proj(key)
|
| 386 |
-
v = self.v_proj(value)
|
| 387 |
-
|
| 388 |
-
q *= self.scaling
|
| 389 |
-
|
| 390 |
-
if self.bias_k is not None:
|
| 391 |
-
assert self.bias_v is not None
|
| 392 |
-
k, v, attn_mask, key_padding_mask = self._add_bias(k, v, attn_mask, key_padding_mask, bsz)
|
| 393 |
-
|
| 394 |
-
q = (q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1))
|
| 395 |
-
kv_bsz = bsz
|
| 396 |
-
|
| 397 |
-
if k is not None:
|
| 398 |
-
kv_bsz = k.size(1)
|
| 399 |
-
k = (k.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1))
|
| 400 |
-
|
| 401 |
-
if v is not None: v = (v.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1))
|
| 402 |
-
|
| 403 |
-
if saved_state is not None:
|
| 404 |
-
if "prev_key" in saved_state:
|
| 405 |
-
_prev_key = saved_state["prev_key"]
|
| 406 |
-
assert _prev_key is not None
|
| 407 |
-
|
| 408 |
-
kv_bsz = _prev_key.size(0)
|
| 409 |
-
prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
|
| 410 |
-
|
| 411 |
-
if static_kv: k = prev_key
|
| 412 |
-
else:
|
| 413 |
-
assert k is not None
|
| 414 |
-
k = torch.cat([prev_key, k], dim=1)
|
| 415 |
-
src_len = k.size(1)
|
| 416 |
-
|
| 417 |
-
if "prev_value" in saved_state:
|
| 418 |
-
_prev_value = saved_state["prev_value"]
|
| 419 |
-
assert _prev_value is not None or kv_bsz == _prev_value.size(0)
|
| 420 |
-
prev_value = _prev_value.view(kv_bsz * self.num_heads, -1, self.head_dim)
|
| 421 |
-
|
| 422 |
-
if static_kv: v = prev_value
|
| 423 |
-
else:
|
| 424 |
-
assert v is not None
|
| 425 |
-
v = torch.cat([prev_value, v], dim=1)
|
| 426 |
-
|
| 427 |
-
prev_key_padding_mask = None
|
| 428 |
-
if "prev_key_padding_mask" in saved_state: prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
| 429 |
-
|
| 430 |
-
assert k is not None and v is not None
|
| 431 |
-
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(key_padding_mask=key_padding_mask, prev_key_padding_mask=prev_key_padding_mask, batch_size=kv_bsz, src_len=k.size(1), static_kv=static_kv)
|
| 432 |
-
|
| 433 |
-
saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
|
| 434 |
-
saved_state["prev_value"] = v.view(kv_bsz, self.num_heads, -1, self.head_dim)
|
| 435 |
-
saved_state["prev_key_padding_mask"] = key_padding_mask
|
| 436 |
-
|
| 437 |
-
assert incremental_state is not None
|
| 438 |
-
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
| 439 |
-
|
| 440 |
-
assert k is not None
|
| 441 |
-
assert k.size(1) == src_len
|
| 442 |
-
|
| 443 |
-
if key_padding_mask is not None and key_padding_mask.dim() == 0: key_padding_mask = None
|
| 444 |
-
|
| 445 |
-
if key_padding_mask is not None:
|
| 446 |
-
assert key_padding_mask.size(0) == kv_bsz
|
| 447 |
-
assert key_padding_mask.size(1) == src_len
|
| 448 |
-
|
| 449 |
-
if self.add_zero_attn:
|
| 450 |
-
assert v is not None
|
| 451 |
-
src_len += 1
|
| 452 |
-
k, v, key_padding_mask, attn_mask = self._append_zero_attn(k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
|
| 453 |
-
|
| 454 |
-
if self.encoder_decoder_attention and bsz != kv_bsz:
|
| 455 |
-
attn_weights = torch.einsum("bxhtd,bhsd->bxhts", q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), k.view((kv_bsz, self.num_heads) + k.size()[1:]))
|
| 456 |
-
attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
|
| 457 |
-
else: attn_weights = torch.bmm(q, k.transpose(1, 2))
|
| 458 |
-
|
| 459 |
-
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
| 460 |
-
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
| 461 |
-
|
| 462 |
-
if attn_mask is not None:
|
| 463 |
-
attn_mask = attn_mask.unsqueeze(0)
|
| 464 |
-
if self.onnx_trace: attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
|
| 465 |
-
attn_weights += attn_mask
|
| 466 |
-
|
| 467 |
-
if key_padding_mask is not None:
|
| 468 |
-
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 469 |
-
attn_weights = attn_weights.view(kv_bsz, -1, self.num_heads, tgt_len, src_len).masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(torch.bool), float("-inf")) if not is_tpu else attn_weights.transpose(0, 2).masked_fill(key_padding_mask, float("-inf")).transpose(0, 2)
|
| 470 |
-
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 471 |
-
|
| 472 |
-
if before_softmax: return attn_weights, v
|
| 473 |
-
|
| 474 |
-
attn_weights_float = softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
|
| 475 |
-
attn_weights = attn_weights_float.type_as(attn_weights)
|
| 476 |
-
attn_probs = self.dropout_module(attn_weights)
|
| 477 |
-
|
| 478 |
-
assert v is not None
|
| 479 |
-
attn = None
|
| 480 |
-
|
| 481 |
-
if self.encoder_decoder_attention and bsz != kv_bsz:
|
| 482 |
-
attn = torch.einsum("bxhts,bhsd->bxhtd", attn_probs.view((kv_bsz, -1, self.num_heads) + attn_probs.size()[1:]), v.view((kv_bsz, self.num_heads) + v.size()[1:]))
|
| 483 |
-
attn = attn.reshape((-1,) + attn.size()[-2:])
|
| 484 |
-
else: attn = torch.bmm(attn_probs, v)
|
| 485 |
-
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
| 486 |
-
|
| 487 |
-
if self.onnx_trace and attn.size(1) == 1: attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
|
| 488 |
-
else: attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
|
| 489 |
-
|
| 490 |
-
attn = self.out_proj(attn)
|
| 491 |
-
attn_weights = None
|
| 492 |
-
|
| 493 |
-
if need_weights:
|
| 494 |
-
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
|
| 495 |
-
if not need_head_weights: attn_weights = attn_weights.mean(dim=0)
|
| 496 |
-
|
| 497 |
-
return attn, attn_weights
|
| 498 |
-
|
| 499 |
-
@staticmethod
|
| 500 |
-
def _append_prev_key_padding_mask(key_padding_mask, prev_key_padding_mask, batch_size, src_len, static_kv):
|
| 501 |
-
if prev_key_padding_mask is not None and static_kv: new_key_padding_mask = prev_key_padding_mask
|
| 502 |
-
elif prev_key_padding_mask is not None and key_padding_mask is not None: new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
|
| 503 |
-
elif prev_key_padding_mask is not None:
|
| 504 |
-
if src_len > prev_key_padding_mask.size(1):
|
| 505 |
-
filler = torch.zeros((batch_size, src_len - prev_key_padding_mask.size(1)), device=prev_key_padding_mask.device)
|
| 506 |
-
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
|
| 507 |
-
else: new_key_padding_mask = prev_key_padding_mask.float()
|
| 508 |
-
elif key_padding_mask is not None:
|
| 509 |
-
if src_len > key_padding_mask.size(1):
|
| 510 |
-
filler = torch.zeros((batch_size, src_len - key_padding_mask.size(1)), device=key_padding_mask.device)
|
| 511 |
-
new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
|
| 512 |
-
else: new_key_padding_mask = key_padding_mask.float()
|
| 513 |
-
else: new_key_padding_mask = prev_key_padding_mask
|
| 514 |
-
return new_key_padding_mask
|
| 515 |
-
|
| 516 |
-
@torch.jit.export
|
| 517 |
-
def reorder_incremental_state(self, incremental_state, new_order):
|
| 518 |
-
input_buffer = self._get_input_buffer(incremental_state)
|
| 519 |
-
if input_buffer is not None:
|
| 520 |
-
for k in input_buffer.keys():
|
| 521 |
-
input_buffer_k = input_buffer[k]
|
| 522 |
-
if input_buffer_k is not None:
|
| 523 |
-
if self.encoder_decoder_attention:
|
| 524 |
-
if input_buffer_k.size(0) * self.beam_size == new_order.size(0): return incremental_state
|
| 525 |
-
elif self.beam_size > 1: input_buffer[k] = input_buffer_k.index_select(0, new_order.reshape(-1, self.beam_size)[:, 0] // self.beam_size)
|
| 526 |
-
else: input_buffer[k] = input_buffer_k.index_select(0, new_order)
|
| 527 |
-
else: input_buffer[k] = input_buffer_k.index_select(0, new_order)
|
| 528 |
-
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
|
| 529 |
-
return incremental_state
|
| 530 |
-
|
| 531 |
-
def set_beam_size(self, beam_size):
|
| 532 |
-
self.beam_size = beam_size
|
| 533 |
-
|
| 534 |
-
def _get_input_buffer(self, incremental_state):
|
| 535 |
-
result = self.get_incremental_state(incremental_state, "attn_state")
|
| 536 |
-
if result is not None: return result
|
| 537 |
-
else: return {}
|
| 538 |
-
|
| 539 |
-
def _set_input_buffer(self, incremental_state, buffer):
|
| 540 |
-
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
| 541 |
-
|
| 542 |
-
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
| 543 |
-
return attn_weights
|
| 544 |
-
|
| 545 |
-
def upgrade_state_dict_named(self, state_dict, name):
|
| 546 |
-
prefix = name + "." if name != "" else ""
|
| 547 |
-
items_to_add = {}
|
| 548 |
-
keys_to_remove = []
|
| 549 |
-
for k in state_dict.keys():
|
| 550 |
-
if k.endswith(prefix + "in_proj_weight"):
|
| 551 |
-
dim = int(state_dict[k].shape[0] / 3)
|
| 552 |
-
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
|
| 553 |
-
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
|
| 554 |
-
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
|
| 555 |
-
keys_to_remove.append(k)
|
| 556 |
-
k_bias = prefix + "in_proj_bias"
|
| 557 |
-
if k_bias in state_dict.keys():
|
| 558 |
-
dim = int(state_dict[k].shape[0] / 3)
|
| 559 |
-
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
|
| 560 |
-
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
|
| 561 |
-
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
|
| 562 |
-
keys_to_remove.append(prefix + "in_proj_bias")
|
| 563 |
-
|
| 564 |
-
for k in keys_to_remove:
|
| 565 |
-
del state_dict[k]
|
| 566 |
-
|
| 567 |
-
for key, value in items_to_add.items():
|
| 568 |
-
state_dict[key] = value
|
| 569 |
-
|
| 570 |
-
def init_bert_params(module):
|
| 571 |
-
def normal_(data):
|
| 572 |
-
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
|
| 573 |
-
|
| 574 |
-
if isinstance(module, nn.Linear):
|
| 575 |
-
normal_(module.weight.data)
|
| 576 |
-
if module.bias is not None: module.bias.data.zero_()
|
| 577 |
-
if isinstance(module, nn.Embedding):
|
| 578 |
-
normal_(module.weight.data)
|
| 579 |
-
if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_()
|
| 580 |
-
if isinstance(module, MultiheadAttention):
|
| 581 |
-
normal_(module.q_proj.weight.data)
|
| 582 |
-
normal_(module.k_proj.weight.data)
|
| 583 |
-
normal_(module.v_proj.weight.data)
|
| 584 |
-
|
| 585 |
-
def make_conv_pos(e, k, g):
|
| 586 |
-
pos_conv = nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g)
|
| 587 |
-
dropout = 0
|
| 588 |
-
|
| 589 |
-
nn.init.normal_(pos_conv.weight, mean=0, std=math.sqrt((4 * (1.0 - dropout)) / (k * e)))
|
| 590 |
-
nn.init.constant_(pos_conv.bias, 0)
|
| 591 |
-
|
| 592 |
-
return nn.Sequential(nn.utils.parametrizations.weight_norm(pos_conv, name="weight", dim=2), SamePad(k), nn.GELU())
|
| 593 |
-
|
| 594 |
-
def is_xla_tensor(tensor):
|
| 595 |
-
return torch.is_tensor(tensor) and tensor.device.type == "xla"
|
| 596 |
-
|
| 597 |
-
def index_put(tensor, indices, value):
|
| 598 |
-
if is_xla_tensor(tensor):
|
| 599 |
-
for _ in range(indices.dim(), tensor.dim()):
|
| 600 |
-
indices = indices.unsqueeze(-1)
|
| 601 |
-
|
| 602 |
-
if indices.size(-1) < tensor.size(-1): indices = indices.expand_as(tensor)
|
| 603 |
-
tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
|
| 604 |
-
else: tensor[indices] = value
|
| 605 |
-
|
| 606 |
-
return tensor
|
| 607 |
-
|
| 608 |
-
def pad_to_multiple(x, multiple, dim=-1, value=0):
|
| 609 |
-
if x is None: return None, 0
|
| 610 |
-
tsz = x.size(dim)
|
| 611 |
-
m = tsz / multiple
|
| 612 |
-
remainder = math.ceil(m) * multiple - tsz
|
| 613 |
-
if m.is_integer(): return x, 0
|
| 614 |
-
return F.pad(x, (*((0,) * (-1 - dim) * 2), 0, remainder), value=value), remainder
|
| 615 |
-
|
| 616 |
-
def compute_mask_indices(shape, padding_mask, mask_prob, mask_length, mask_type = "static", mask_other = 0.0, min_masks = 0, no_overlap = False, min_space = 0, require_same_masks = True, mask_dropout = 0.0, add_masks = False, seed = None, epoch = None, indices = None, idc_select_ver = 1, num_mask_ver = 2):
|
| 617 |
-
bsz, all_sz = shape
|
| 618 |
-
mask = np.full((bsz, all_sz), False)
|
| 619 |
-
|
| 620 |
-
if num_mask_ver == 1: all_num_mask = max(min_masks, int(mask_prob * all_sz / float(mask_length) + np.random.rand()))
|
| 621 |
-
mask_idcs = []
|
| 622 |
-
|
| 623 |
-
for i in range(bsz):
|
| 624 |
-
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) if seed is not None and epoch is not None and indices is not None else None
|
| 625 |
-
rng = np.random.default_rng(seed_i)
|
| 626 |
-
|
| 627 |
-
if padding_mask is not None:
|
| 628 |
-
sz = all_sz - padding_mask[i].long().sum().item()
|
| 629 |
-
assert sz >= 0, sz
|
| 630 |
-
else: sz = all_sz
|
| 631 |
-
|
| 632 |
-
if num_mask_ver == 1: num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + np.random.rand())) if padding_mask is not None else all_num_mask
|
| 633 |
-
elif num_mask_ver == 2: num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + rng.random()))
|
| 634 |
-
else: raise ValueError
|
| 635 |
-
|
| 636 |
-
if mask_type == "static": lengths = np.full(num_mask, mask_length)
|
| 637 |
-
elif mask_type == "uniform": lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
| 638 |
-
elif mask_type == "normal": lengths = [max(1, int(round(x))) for x in rng.normal(mask_length, mask_other, size=num_mask)]
|
| 639 |
-
elif mask_type == "poisson": lengths = [int(round(x)) for x in rng.poisson(mask_length, size=num_mask)]
|
| 640 |
-
else: raise Exception
|
| 641 |
-
|
| 642 |
-
if sum(lengths) == 0:
|
| 643 |
-
if mask_type == "static": raise ValueError
|
| 644 |
-
else: lengths = [min(mask_length, sz - 1)]
|
| 645 |
-
|
| 646 |
-
if no_overlap:
|
| 647 |
-
mask_idc = []
|
| 648 |
-
|
| 649 |
-
def arrange(s, e, length, keep_length):
|
| 650 |
-
span_start = rng.randint(s, e - length)
|
| 651 |
-
mask_idc.extend(span_start + i for i in range(length))
|
| 652 |
-
new_parts = []
|
| 653 |
-
|
| 654 |
-
if span_start - s - min_space >= keep_length: new_parts.append((s, span_start - min_space + 1))
|
| 655 |
-
if e - span_start - length - min_space > keep_length: new_parts.append((span_start + length + min_space, e))
|
| 656 |
-
|
| 657 |
-
return new_parts
|
| 658 |
-
|
| 659 |
-
parts = [(0, sz)]
|
| 660 |
-
min_length = min(lengths)
|
| 661 |
-
|
| 662 |
-
for length in sorted(lengths, reverse=True):
|
| 663 |
-
lens = np.fromiter((e - s if e - s >= length + min_space else 0 for s, e in parts), np.int32)
|
| 664 |
-
l_sum = np.sum(lens)
|
| 665 |
-
if l_sum == 0: break
|
| 666 |
-
s, e = parts.pop(rng.choice(len(parts), p=lens / np.sum(lens)))
|
| 667 |
-
parts.extend(arrange(s, e, length, min_length))
|
| 668 |
-
mask_idc = np.asarray(mask_idc)
|
| 669 |
-
else:
|
| 670 |
-
if idc_select_ver == 1:
|
| 671 |
-
min_len = min(lengths)
|
| 672 |
-
if sz - min_len <= num_mask: min_len = sz - num_mask - 1
|
| 673 |
-
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
|
| 674 |
-
elif idc_select_ver == 2: mask_idc = rng.choice(sz, num_mask, replace=False)
|
| 675 |
-
else: raise ValueError
|
| 676 |
-
|
| 677 |
-
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
|
| 678 |
-
|
| 679 |
-
mask_idc = np.unique(mask_idc[mask_idc < sz])
|
| 680 |
-
if len(mask_idc) >= sz: raise ValueError
|
| 681 |
-
mask_idcs.append(mask_idc)
|
| 682 |
-
|
| 683 |
-
target_len = None
|
| 684 |
-
if require_same_masks: target_len = max([len(m) for m in mask_idcs]) if add_masks else min([len(m) for m in mask_idcs])
|
| 685 |
-
|
| 686 |
-
for i, mask_idc in enumerate(mask_idcs):
|
| 687 |
-
if target_len is not None and len(mask_idc) > target_len: mask_idc = rng.choice(mask_idc, target_len, replace=False)
|
| 688 |
-
mask[i, mask_idc] = True
|
| 689 |
-
|
| 690 |
-
if target_len is not None and len(mask_idc) < target_len:
|
| 691 |
-
to_mask = rng.choice(np.flatnonzero(~mask[i]), target_len - len(mask_idc), replace=False)
|
| 692 |
-
mask[i, to_mask] = True
|
| 693 |
-
|
| 694 |
-
if mask_dropout > 0:
|
| 695 |
-
masked = np.flatnonzero(mask[i])
|
| 696 |
-
mask[i, rng.choice(masked, np.rint(len(masked) * mask_dropout).astype(int), replace=False)] = False
|
| 697 |
-
|
| 698 |
-
return mask
|
| 699 |
-
|
| 700 |
-
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
|
| 701 |
-
return nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
| 702 |
-
|
| 703 |
-
def prune_state_dict(state_dict, model_cfg):
|
| 704 |
-
arch = None
|
| 705 |
-
if model_cfg is not None: arch = (model_cfg._name if isinstance(model_cfg, DictConfig) else getattr(model_cfg, "arch", None))
|
| 706 |
-
|
| 707 |
-
if not model_cfg or arch is None or arch == "ptt_transformer": return state_dict
|
| 708 |
-
|
| 709 |
-
encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
|
| 710 |
-
decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
|
| 711 |
-
|
| 712 |
-
if not encoder_layers_to_keep and not decoder_layers_to_keep: return state_dict
|
| 713 |
-
|
| 714 |
-
def create_pruning_pass(layers_to_keep, layer_name):
|
| 715 |
-
keep_layers = sorted(int(layer_string) for layer_string in layers_to_keep.split(","))
|
| 716 |
-
mapping_dict = {}
|
| 717 |
-
for i in range(len(keep_layers)):
|
| 718 |
-
mapping_dict[str(keep_layers[i])] = str(i)
|
| 719 |
-
|
| 720 |
-
return {"substitution_regex": re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name)), "mapping_dict": mapping_dict}
|
| 721 |
-
|
| 722 |
-
pruning_passes = []
|
| 723 |
-
new_state_dict = {}
|
| 724 |
-
|
| 725 |
-
if encoder_layers_to_keep: pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
|
| 726 |
-
if decoder_layers_to_keep: pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
|
| 727 |
-
|
| 728 |
-
for layer_name in state_dict.keys():
|
| 729 |
-
match = re.search(r"\.layers\.(\d+)\.", layer_name)
|
| 730 |
-
if not match:
|
| 731 |
-
new_state_dict[layer_name] = state_dict[layer_name]
|
| 732 |
-
continue
|
| 733 |
-
|
| 734 |
-
original_layer_number = match.group(1)
|
| 735 |
-
for pruning_pass in pruning_passes:
|
| 736 |
-
if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass["substitution_regex"].search(layer_name):
|
| 737 |
-
substitution_match = pruning_pass["substitution_regex"].search(layer_name)
|
| 738 |
-
new_state_dict[(layer_name[: substitution_match.start(1)] + pruning_pass["mapping_dict"][original_layer_number] + layer_name[substitution_match.end(1) :])] = state_dict[layer_name]
|
| 739 |
-
|
| 740 |
-
with open_dict(model_cfg) if isinstance(model_cfg, DictConfig) else contextlib.ExitStack():
|
| 741 |
-
if hasattr(model_cfg, "encoder_layers_to_keep"): model_cfg.encoder_layers_to_keep = None
|
| 742 |
-
if hasattr(model_cfg, "decoder_layers_to_keep"): model_cfg.decoder_layers_to_keep = None
|
| 743 |
-
|
| 744 |
-
return new_state_dict
|
| 745 |
-
|
| 746 |
-
def relu_squared(x):
|
| 747 |
-
return F.relu(x).pow(2)
|
| 748 |
-
|
| 749 |
-
def get_activation_fn(activation):
|
| 750 |
-
def gelu(x):
|
| 751 |
-
return nn.functional.gelu(x.float()).type_as(x)
|
| 752 |
-
|
| 753 |
-
def gelu_accurate(x):
|
| 754 |
-
if not hasattr(gelu_accurate, "_a"):
|
| 755 |
-
gelu_accurate._a = math.sqrt(2 / math.pi)
|
| 756 |
-
return (0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))))
|
| 757 |
-
|
| 758 |
-
if activation == "relu": return F.relu
|
| 759 |
-
elif activation == "relu_squared": return relu_squared
|
| 760 |
-
elif activation == "gelu": return gelu
|
| 761 |
-
elif activation == "gelu_fast": return gelu_accurate
|
| 762 |
-
elif activation == "gelu_accurate": return gelu_accurate
|
| 763 |
-
elif activation == "tanh": return torch.tanh
|
| 764 |
-
elif activation == "linear": return lambda x: x
|
| 765 |
-
elif activation == "swish": return nn.SiLU
|
| 766 |
-
else: raise RuntimeError
|
| 767 |
-
|
| 768 |
-
class SamePad(nn.Module):
|
| 769 |
-
def __init__(self, kernel_size, causal=False):
|
| 770 |
-
super().__init__()
|
| 771 |
-
if causal: self.remove = kernel_size - 1
|
| 772 |
-
else: self.remove = 1 if kernel_size % 2 == 0 else 0
|
| 773 |
-
|
| 774 |
-
def forward(self, x):
|
| 775 |
-
if self.remove > 0: x = x[:, :, : -self.remove]
|
| 776 |
-
return x
|
| 777 |
-
|
| 778 |
-
class TransformerSentenceEncoderLayer(nn.Module):
|
| 779 |
-
def __init__(self, embedding_dim = 768, ffn_embedding_dim = 3072, num_attention_heads = 8, dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.1, activation_fn = "relu", layer_norm_first = False):
|
| 780 |
-
super().__init__()
|
| 781 |
-
self.embedding_dim = embedding_dim
|
| 782 |
-
self.dropout = dropout
|
| 783 |
-
self.activation_dropout = activation_dropout
|
| 784 |
-
self.activation_fn = get_activation_fn(activation_fn)
|
| 785 |
-
self.self_attn = MultiheadAttention(self.embedding_dim, num_attention_heads, dropout=attention_dropout, self_attention=True)
|
| 786 |
-
self.dropout1 = nn.Dropout(dropout)
|
| 787 |
-
self.dropout2 = nn.Dropout(self.activation_dropout)
|
| 788 |
-
self.dropout3 = nn.Dropout(dropout)
|
| 789 |
-
self.layer_norm_first = layer_norm_first
|
| 790 |
-
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
| 791 |
-
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
| 792 |
-
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
| 793 |
-
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
| 794 |
-
|
| 795 |
-
def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None):
|
| 796 |
-
residual = x
|
| 797 |
-
|
| 798 |
-
if self.layer_norm_first:
|
| 799 |
-
x = self.self_attn_layer_norm(x)
|
| 800 |
-
x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, attn_mask=self_attn_mask, need_weights=False)
|
| 801 |
-
x = residual + self.dropout1(x)
|
| 802 |
-
residual = x
|
| 803 |
-
x = self.fc2(self.dropout2(self.activation_fn(self.fc1(self.final_layer_norm(x)))))
|
| 804 |
-
layer_result = x
|
| 805 |
-
x = residual + self.dropout3(x)
|
| 806 |
-
else:
|
| 807 |
-
x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=False)
|
| 808 |
-
x = self.self_attn_layer_norm(residual + self.dropout1(x))
|
| 809 |
-
residual = x
|
| 810 |
-
x = self.fc2(self.dropout2(self.activation_fn(self.fc1(x))))
|
| 811 |
-
layer_result = x
|
| 812 |
-
x = self.final_layer_norm(residual + self.dropout3(x))
|
| 813 |
-
|
| 814 |
-
return x, (attn, layer_result)
|
| 815 |
-
|
| 816 |
-
class AdapterFast(nn.Module):
|
| 817 |
-
def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
|
| 818 |
-
super().__init__()
|
| 819 |
-
self.adapter_num = adapter_num
|
| 820 |
-
self.input_dim = input_dim
|
| 821 |
-
self.hidden_dim = hidden_dim
|
| 822 |
-
self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
|
| 823 |
-
self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
|
| 824 |
-
self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
|
| 825 |
-
self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))
|
| 826 |
-
self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
|
| 827 |
-
self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
|
| 828 |
-
self.act_fn = nn.Identity()
|
| 829 |
-
if act_fn == "relu": self.act_fn = nn.ReLU()
|
| 830 |
-
elif act_fn == "gelu": self.act_fn = nn.GELU()
|
| 831 |
-
elif act_fn == "selu": self.act_fn = nn.SELU()
|
| 832 |
-
else: raise ValueError
|
| 833 |
-
|
| 834 |
-
self.input_dim = input_dim
|
| 835 |
-
self.reset_parameters()
|
| 836 |
-
|
| 837 |
-
def reset_parameters(self):
|
| 838 |
-
for ii in range(self.adapter_num):
|
| 839 |
-
nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
|
| 840 |
-
nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
|
| 841 |
-
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
|
| 842 |
-
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
| 843 |
-
nn.init.uniform_(self.b_a[ii], -bound, bound)
|
| 844 |
-
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
|
| 845 |
-
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
| 846 |
-
nn.init.uniform_(self.b_b[ii], -bound, bound)
|
| 847 |
-
|
| 848 |
-
nn.init.ones_(self.ln_W)
|
| 849 |
-
nn.init.zeros_(self.ln_b)
|
| 850 |
-
|
| 851 |
-
def forward(self, x, adapter_id):
|
| 852 |
-
ii = adapter_id
|
| 853 |
-
return F.linear(self.act_fn(F.linear(F.layer_norm(x, (self.input_dim, ), self.ln_W[ii], self.ln_b[ii]), self.W_a[ii], self.b_a[ii])), self.W_b[ii], self.b_b[ii])
|
| 854 |
-
|
| 855 |
-
def extra_repr(self):
|
| 856 |
-
return ('adapter={}, input_dim={}, hidden_dim={}'.format(self.adapter_num, self.input_dim, self.hidden_dim))
|
| 857 |
-
|
| 858 |
-
class FeedForwardModule(nn.Module):
|
| 859 |
-
def __init__(self, input_feat, hidden_units, dropout1, dropout2, activation_fn="swish", bias=True):
|
| 860 |
-
super(FeedForwardModule, self).__init__()
|
| 861 |
-
self.layer_norm = LayerNorm(input_feat)
|
| 862 |
-
self.w_1 = nn.Linear(input_feat, hidden_units, bias=bias)
|
| 863 |
-
self.w_2 = nn.Linear(hidden_units, input_feat, bias=bias)
|
| 864 |
-
self.dropout1 = nn.Dropout(dropout1)
|
| 865 |
-
self.dropout2 = nn.Dropout(dropout2)
|
| 866 |
-
self.activation = get_activation_fn(activation_fn)(hidden_units)
|
| 867 |
-
|
| 868 |
-
def forward(self, x):
|
| 869 |
-
return self.dropout2(self.w_2(self.dropout1(self.activation(self.w_1(self.layer_norm(x))))))
|
| 870 |
-
|
| 871 |
-
class ConvolutionModule(nn.Module):
|
| 872 |
-
def __init__(self, embed_dim, channels, depthwise_kernel_size, dropout, activation_fn="swish", bias=False, export=False):
|
| 873 |
-
super(ConvolutionModule, self).__init__()
|
| 874 |
-
assert (depthwise_kernel_size - 1) % 2 == 0
|
| 875 |
-
self.layer_norm = LayerNorm(embed_dim, export=export)
|
| 876 |
-
self.pointwise_conv1 = nn.Conv1d(embed_dim, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias)
|
| 877 |
-
self.glu = nn.GLU(dim=1)
|
| 878 |
-
self.depthwise_conv = nn.Conv1d(channels, channels, depthwise_kernel_size, stride=1, padding=(depthwise_kernel_size - 1) // 2, groups=channels, bias=bias)
|
| 879 |
-
self.batch_norm = nn.BatchNorm1d(channels)
|
| 880 |
-
self.activation = get_activation_fn(activation_fn)(channels)
|
| 881 |
-
self.pointwise_conv2 = nn.Conv1d(channels, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
|
| 882 |
-
self.dropout = nn.Dropout(dropout)
|
| 883 |
-
|
| 884 |
-
def forward(self, x):
|
| 885 |
-
return self.dropout(self.pointwise_conv2(self.activation(self.batch_norm(self.depthwise_conv(self.glu(self.pointwise_conv1(self.layer_norm(x).transpose(1, 2)))))))).transpose(1, 2)
|
| 886 |
-
|
| 887 |
-
def rotate_half(x):
|
| 888 |
-
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
| 889 |
-
return torch.cat((-x2, x1), dim=x1.ndim - 1)
|
| 890 |
-
|
| 891 |
-
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
|
| 892 |
-
cos, sin = (cos[offset : q.shape[0] + offset, ...], sin[offset : q.shape[0] + offset, ...])
|
| 893 |
-
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
| 894 |
-
|
| 895 |
-
class RotaryPositionalEmbedding(nn.Module):
|
| 896 |
-
def __init__(self, dim, base=10000, precision=torch.half):
|
| 897 |
-
super().__init__()
|
| 898 |
-
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 899 |
-
self.register_buffer("inv_freq", inv_freq)
|
| 900 |
-
self.seq_len_cached = 0
|
| 901 |
-
self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
|
| 902 |
-
self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
|
| 903 |
-
self.precision = precision
|
| 904 |
-
|
| 905 |
-
def forward(self, x, seq_len = 0):
|
| 906 |
-
if seq_len > self.seq_len_cached:
|
| 907 |
-
self.seq_len_cached = seq_len
|
| 908 |
-
freqs = torch.einsum("i,j->ij", torch.arange(seq_len, device=x.device).type_as(self.inv_freq), self.inv_freq)
|
| 909 |
-
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
| 910 |
-
self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1))
|
| 911 |
-
self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))
|
| 912 |
-
return self.cos_cached, self.sin_cached
|
| 913 |
-
|
| 914 |
-
class ESPNETMultiHeadedAttention(nn.Module):
|
| 915 |
-
def __init__(self, n_feat, n_head, dropout):
|
| 916 |
-
super(ESPNETMultiHeadedAttention, self).__init__()
|
| 917 |
-
assert n_feat % n_head == 0
|
| 918 |
-
self.d_k = n_feat // n_head
|
| 919 |
-
self.h = n_head
|
| 920 |
-
self.linear_q = nn.Linear(n_feat, n_feat)
|
| 921 |
-
self.linear_k = nn.Linear(n_feat, n_feat)
|
| 922 |
-
self.linear_v = nn.Linear(n_feat, n_feat)
|
| 923 |
-
self.linear_out = nn.Linear(n_feat, n_feat)
|
| 924 |
-
self.attn = None
|
| 925 |
-
self.dropout = nn.Dropout(p=dropout)
|
| 926 |
-
|
| 927 |
-
def forward_qkv(self, query, key, value, **kwargs):
|
| 928 |
-
n_batch = query.size(0)
|
| 929 |
-
return self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2), self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2), self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
|
| 930 |
-
|
| 931 |
-
def forward_attention(self, value, scores, mask):
|
| 932 |
-
n_batch = value.size(0)
|
| 933 |
-
|
| 934 |
-
if mask is not None:
|
| 935 |
-
scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2).to(bool), float("-inf"))
|
| 936 |
-
self.attn = torch.softmax(scores, dim=-1)
|
| 937 |
-
else: self.attn = torch.softmax(scores, dim=-1)
|
| 938 |
-
|
| 939 |
-
return self.linear_out((torch.matmul(self.dropout(self.attn), value).transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)))
|
| 940 |
-
|
| 941 |
-
def forward(self, query, key, value, key_padding_mask=None, **kwargs):
|
| 942 |
-
q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
|
| 943 |
-
return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
|
| 944 |
-
|
| 945 |
-
class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
|
| 946 |
-
def __init__(self, n_feat, n_head, dropout, zero_triu=False):
|
| 947 |
-
super().__init__(n_feat, n_head, dropout)
|
| 948 |
-
self.zero_triu = zero_triu
|
| 949 |
-
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
| 950 |
-
self.pos_bias_u = nn.Parameter(torch.zeros(self.h, self.d_k))
|
| 951 |
-
self.pos_bias_v = nn.Parameter(torch.zeros(self.h, self.d_k))
|
| 952 |
-
nn.init.xavier_uniform_(self.pos_bias_u)
|
| 953 |
-
nn.init.xavier_uniform_(self.pos_bias_v)
|
| 954 |
-
|
| 955 |
-
def rel_shift(self, x):
|
| 956 |
-
x = torch.cat([torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype), x], dim=-1).view(*x.size()[:2], x.size(3) + 1, x.size(2))[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1]
|
| 957 |
-
if self.zero_triu: x = x * torch.tril(torch.ones((x.size(2), x.size(3)), device=x.device), x.size(3) - x.size(2))[None, None, :, :]
|
| 958 |
-
return x
|
| 959 |
-
|
| 960 |
-
def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
|
| 961 |
-
pos_emb = pos_emb.transpose(0, 1)
|
| 962 |
-
q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
|
| 963 |
-
q = q.transpose(1, 2)
|
| 964 |
-
|
| 965 |
-
return self.forward_attention(v, (torch.matmul((q + self.pos_bias_u).transpose(1, 2), k.transpose(-2, -1)) + self.rel_shift(torch.matmul((q + self.pos_bias_v).transpose(1, 2), self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h, self.d_k).transpose(1, 2).transpose(-2, -1)))) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
|
| 966 |
-
|
| 967 |
-
class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
|
| 968 |
-
def __init__(self, n_feat, n_head, dropout, precision, rotary_emd_base=10000):
|
| 969 |
-
super().__init__(n_feat, n_head, dropout)
|
| 970 |
-
precision = torch.float
|
| 971 |
-
self.rotary_ndims = self.d_k
|
| 972 |
-
if precision == "fp16": precision = torch.half
|
| 973 |
-
self.rotary_emb = RotaryPositionalEmbedding(self.rotary_ndims, base=rotary_emd_base, precision=precision)
|
| 974 |
-
|
| 975 |
-
def forward(self, query, key, value, key_padding_mask=None, **kwargs):
|
| 976 |
-
T, B, C = value.size()
|
| 977 |
-
query = query.view(T, B, self.h, self.d_k)
|
| 978 |
-
key = key.view(T, B, self.h, self.d_k)
|
| 979 |
-
value = value.view(T, B, self.h, self.d_k)
|
| 980 |
-
|
| 981 |
-
cos, sin = self.rotary_emb(value, seq_len=T)
|
| 982 |
-
query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)
|
| 983 |
-
|
| 984 |
-
query = query.view(T, B, self.h * self.d_k)
|
| 985 |
-
key = key.view(T, B, self.h * self.d_k)
|
| 986 |
-
value = value.view(T, B, self.h * self.d_k)
|
| 987 |
-
|
| 988 |
-
q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
|
| 989 |
-
return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
|
| 990 |
-
|
| 991 |
-
class ConformerEncoderLayer(nn.Module):
|
| 992 |
-
def __init__(self, embed_dim, ffn_embed_dim, attention_heads, dropout, use_fp16, depthwise_conv_kernel_size=31, activation_fn="swish", attn_type=None, pos_enc_type="abs"):
|
| 993 |
-
self.pos_enc_type = pos_enc_type
|
| 994 |
-
super(ConformerEncoderLayer, self).__init__()
|
| 995 |
-
self.ffn1 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout)
|
| 996 |
-
self.self_attn_layer_norm = LayerNorm(embed_dim, export=False)
|
| 997 |
-
self.self_attn_dropout = nn.Dropout(dropout)
|
| 998 |
-
|
| 999 |
-
if attn_type == "espnet":
|
| 1000 |
-
if self.pos_enc_type == "rel_pos": self.self_attn = RelPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
|
| 1001 |
-
elif self.pos_enc_type == "rope": self.self_attn = RotaryPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout, precision=use_fp16)
|
| 1002 |
-
elif self.pos_enc_type == "abs": self.self_attn = ESPNETMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
|
| 1003 |
-
else: raise Exception
|
| 1004 |
-
else: self.self_attn = MultiheadAttention(embed_dim, attention_heads, dropout=dropout)
|
| 1005 |
-
|
| 1006 |
-
self.conv_module = ConvolutionModule(embed_dim=embed_dim, channels=embed_dim, depthwise_kernel_size=depthwise_conv_kernel_size, dropout=dropout, activation_fn=activation_fn)
|
| 1007 |
-
self.ffn2 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout, activation_fn=activation_fn)
|
| 1008 |
-
self.final_layer_norm = LayerNorm(embed_dim, export=False)
|
| 1009 |
-
|
| 1010 |
-
def forward(self, x, encoder_padding_mask, position_emb = None):
|
| 1011 |
-
residual = x
|
| 1012 |
-
x = self.ffn1(x) * 0.5 + residual
|
| 1013 |
-
residual = x
|
| 1014 |
-
x = self.self_attn_layer_norm(x)
|
| 1015 |
-
|
| 1016 |
-
if self.pos_enc_type == "rel_pos": x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, pos_emb=position_emb, need_weights=False)
|
| 1017 |
-
else: x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=False)
|
| 1018 |
-
|
| 1019 |
-
x = self.self_attn_dropout(x)
|
| 1020 |
-
x = x + residual
|
| 1021 |
-
residual = x
|
| 1022 |
-
x = residual + self.conv_module(x.transpose(0, 1)).transpose(0, 1)
|
| 1023 |
-
residual = x
|
| 1024 |
-
x = self.ffn2(x)
|
| 1025 |
-
layer_result = x
|
| 1026 |
-
x = self.final_layer_norm(x * 0.5 + residual)
|
| 1027 |
-
|
| 1028 |
-
return x, (attn, layer_result)
|
| 1029 |
-
|
| 1030 |
-
class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer):
|
| 1031 |
-
def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, position_emb=None):
|
| 1032 |
-
return super().forward(x, self_attn_padding_mask, position_emb)
|
| 1033 |
-
|
| 1034 |
-
class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
|
| 1035 |
-
def __init__(self, embedding_dim = 768, ffn_embedding_dim = 3072, num_attention_heads = 8, dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.1, activation_fn = "relu", layer_norm_first = False, adapter_num=201, adapter_dim=64, adapter_act_fn="relu"):
|
| 1036 |
-
super().__init__(embedding_dim=embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, dropout=dropout, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, layer_norm_first=layer_norm_first)
|
| 1037 |
-
self.adapter_num = adapter_num
|
| 1038 |
-
self.adapter_dim = adapter_dim
|
| 1039 |
-
self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn)
|
| 1040 |
-
|
| 1041 |
-
def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, corpus_key=None):
|
| 1042 |
-
|
| 1043 |
-
x, (attn, layer_result) = super().forward(x=x, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_weights=need_weights, att_args=att_args)
|
| 1044 |
-
assert corpus_key is not None
|
| 1045 |
-
assert len(set(corpus_key)) == 1
|
| 1046 |
-
|
| 1047 |
-
return x + self.adapter_layer(x, corpus_key[0]), (attn, layer_result)
|
| 1048 |
-
|
| 1049 |
-
class TransposeLast(nn.Module):
|
| 1050 |
-
def __init__(self, deconstruct_idx=None, transpose_dim=-2):
|
| 1051 |
-
super().__init__()
|
| 1052 |
-
self.deconstruct_idx = deconstruct_idx
|
| 1053 |
-
self.transpose_dim = transpose_dim
|
| 1054 |
-
|
| 1055 |
-
def forward(self, x):
|
| 1056 |
-
if self.deconstruct_idx is not None: x = x[self.deconstruct_idx]
|
| 1057 |
-
return x.transpose(self.transpose_dim, -1)
|
| 1058 |
-
|
| 1059 |
-
class TransformerEncoder(nn.Module):
|
| 1060 |
-
def build_encoder_layer(self, args, **kwargs):
|
| 1061 |
-
if args.layer_type == "transformer": layer = TransformerSentenceEncoderLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first)
|
| 1062 |
-
elif args.layer_type == "conformer": layer = ConformerWav2Vec2EncoderLayer(embed_dim=self.embedding_dim, ffn_embed_dim=args.encoder_ffn_embed_dim, attention_heads=args.encoder_attention_heads, dropout=args.dropout, depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, activation_fn="swish", attn_type=args.attn_type, use_fp16=args.fp16, pos_enc_type="abs")
|
| 1063 |
-
elif args.layer_type == "trf_adp":
|
| 1064 |
-
use_adp = False
|
| 1065 |
-
if args.adp_trf_idx == "all": use_adp = True
|
| 1066 |
-
else:
|
| 1067 |
-
if kwargs.get("layer_idx", None) in list(range(*[int(g) for g in args.adp_trf_idx.split(":")])): use_adp = True
|
| 1068 |
-
|
| 1069 |
-
layer = TransformerSentenceEncoderWithAdapterLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first, adapter_num=args.adp_num, adapter_dim=args.adp_dim, adapter_act_fn=args.adp_act_fn) if use_adp else TransformerSentenceEncoderLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first,)
|
| 1070 |
-
|
| 1071 |
-
return layer
|
| 1072 |
-
|
| 1073 |
-
def __init__(self, args):
|
| 1074 |
-
super().__init__()
|
| 1075 |
-
self.dropout = args.dropout
|
| 1076 |
-
self.embedding_dim = args.encoder_embed_dim
|
| 1077 |
-
self.required_seq_len_multiple = args.required_seq_len_multiple
|
| 1078 |
-
pos_conv_depth = getattr(args, "pos_conv_depth", 1)
|
| 1079 |
-
|
| 1080 |
-
if pos_conv_depth > 1:
|
| 1081 |
-
num_layers = args.pos_conv_depth
|
| 1082 |
-
k = max(3, args.conv_pos // num_layers)
|
| 1083 |
-
|
| 1084 |
-
def make_conv_block(e, k, g, l):
|
| 1085 |
-
return nn.Sequential(*[nn.Sequential(nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g), SamePad(k), TransposeLast(), LayerNorm(e, elementwise_affine=False), TransposeLast(), nn.GELU()) for _ in range(l)])
|
| 1086 |
-
|
| 1087 |
-
self.pos_conv = make_conv_block(self.embedding_dim, k, args.conv_pos_groups, num_layers)
|
| 1088 |
-
else: self.pos_conv = make_conv_pos(self.embedding_dim, args.conv_pos, args.conv_pos_groups)
|
| 1089 |
-
|
| 1090 |
-
self.layers = nn.ModuleList([self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)])
|
| 1091 |
-
self.layer_norm_first = args.layer_norm_first
|
| 1092 |
-
self.layer_norm = LayerNorm(self.embedding_dim)
|
| 1093 |
-
self.layerdrop = args.encoder_layerdrop
|
| 1094 |
-
self.apply(init_bert_params)
|
| 1095 |
-
|
| 1096 |
-
def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
|
| 1097 |
-
x, layer_results = self.extract_features(x, padding_mask, layer, corpus_key=corpus_key)
|
| 1098 |
-
|
| 1099 |
-
if self.layer_norm_first and layer is None: x = self.layer_norm(x)
|
| 1100 |
-
return x, layer_results
|
| 1101 |
-
|
| 1102 |
-
def extract_features(self, x, padding_mask=None, tgt_layer=None, min_layer=0, corpus_key=None):
|
| 1103 |
-
if padding_mask is not None: x = index_put(x, padding_mask, 0)
|
| 1104 |
-
x = x + self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
|
| 1105 |
-
|
| 1106 |
-
if not self.layer_norm_first: x = self.layer_norm(x)
|
| 1107 |
-
x, pad_length = pad_to_multiple(x, self.required_seq_len_multiple, dim=-2, value=0)
|
| 1108 |
-
|
| 1109 |
-
if pad_length > 0 and padding_mask is None:
|
| 1110 |
-
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
|
| 1111 |
-
padding_mask[:, -pad_length:] = True
|
| 1112 |
-
else: padding_mask, _ = pad_to_multiple(padding_mask, self.required_seq_len_multiple, dim=-1, value=True)
|
| 1113 |
-
|
| 1114 |
-
x = F.dropout(x, p=self.dropout, training=self.training).transpose(0, 1)
|
| 1115 |
-
layer_results = []
|
| 1116 |
-
r = None
|
| 1117 |
-
|
| 1118 |
-
for i, layer in enumerate(self.layers):
|
| 1119 |
-
dropout_probability = np.random.random() if self.layerdrop > 0 else 1
|
| 1120 |
-
if not self.training or (dropout_probability > self.layerdrop):
|
| 1121 |
-
layer_check = layer
|
| 1122 |
-
|
| 1123 |
-
if (corpus_key is None) or (not isinstance(layer_check, (TransformerSentenceEncoderWithAdapterLayer))): x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
|
| 1124 |
-
else: x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, corpus_key=corpus_key)
|
| 1125 |
-
|
| 1126 |
-
if i >= min_layer: layer_results.append((x, z, lr))
|
| 1127 |
-
if i == tgt_layer:
|
| 1128 |
-
r = x
|
| 1129 |
-
break
|
| 1130 |
-
|
| 1131 |
-
if r is not None: x = r
|
| 1132 |
-
x = x.transpose(0, 1)
|
| 1133 |
-
|
| 1134 |
-
if pad_length > 0:
|
| 1135 |
-
x = x[:, :-pad_length]
|
| 1136 |
-
def undo_pad(a, b, c):
|
| 1137 |
-
return (a[:-pad_length], b[:-pad_length] if b is not None else b, c[:-pad_length])
|
| 1138 |
-
|
| 1139 |
-
layer_results = [undo_pad(*u) for u in layer_results]
|
| 1140 |
-
|
| 1141 |
-
return x, layer_results
|
| 1142 |
-
|
| 1143 |
-
def max_positions(self):
|
| 1144 |
-
return self.args.max_positions
|
| 1145 |
-
|
| 1146 |
-
def upgrade_state_dict_named(self, state_dict, name):
|
| 1147 |
-
return state_dict
|
| 1148 |
-
|
| 1149 |
-
class Fp32GroupNorm(nn.GroupNorm):
|
| 1150 |
-
def __init__(self, *args, **kwargs):
|
| 1151 |
-
super().__init__(*args, **kwargs)
|
| 1152 |
-
|
| 1153 |
-
def forward(self, input):
|
| 1154 |
-
output = F.group_norm(input.float(), self.num_groups, self.weight.float() if self.weight is not None else None, self.bias.float() if self.bias is not None else None, self.eps)
|
| 1155 |
-
return output.type_as(input)
|
| 1156 |
-
|
| 1157 |
-
class Fp32LayerNorm(nn.LayerNorm):
|
| 1158 |
-
def __init__(self, *args, **kwargs):
|
| 1159 |
-
super().__init__(*args, **kwargs)
|
| 1160 |
-
|
| 1161 |
-
def forward(self, input):
|
| 1162 |
-
output = F.layer_norm(input.float(), self.normalized_shape, self.weight.float() if self.weight is not None else None, self.bias.float() if self.bias is not None else None, self.eps)
|
| 1163 |
-
return output.type_as(input)
|
| 1164 |
-
|
| 1165 |
-
class ConvFeatureExtractionModel(nn.Module):
|
| 1166 |
-
def __init__(self, conv_layers, dropout = 0.0, mode = "default", conv_bias = False):
|
| 1167 |
-
super().__init__()
|
| 1168 |
-
assert mode in {"default", "layer_norm"}
|
| 1169 |
-
|
| 1170 |
-
def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False):
|
| 1171 |
-
def make_conv():
|
| 1172 |
-
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
| 1173 |
-
nn.init.kaiming_normal_(conv.weight)
|
| 1174 |
-
return conv
|
| 1175 |
-
|
| 1176 |
-
assert not (is_layer_norm and is_group_norm)
|
| 1177 |
-
|
| 1178 |
-
if is_layer_norm: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.Sequential(TransposeLast(), Fp32LayerNorm(dim, elementwise_affine=True), TransposeLast()), nn.GELU())
|
| 1179 |
-
elif is_group_norm: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), Fp32GroupNorm(dim, dim, affine=True), nn.GELU())
|
| 1180 |
-
else: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
| 1181 |
-
|
| 1182 |
-
in_d = 1
|
| 1183 |
-
self.conv_layers = nn.ModuleList()
|
| 1184 |
-
for i, cl in enumerate(conv_layers):
|
| 1185 |
-
assert len(cl) == 3
|
| 1186 |
-
(dim, k, stride) = cl
|
| 1187 |
-
self.conv_layers.append(block(in_d, dim, k, stride, is_layer_norm=mode == "layer_norm", is_group_norm=mode == "default" and i == 0, conv_bias=conv_bias))
|
| 1188 |
-
in_d = dim
|
| 1189 |
-
|
| 1190 |
-
def forward(self, x):
|
| 1191 |
-
x = x.unsqueeze(1)
|
| 1192 |
-
for conv in self.conv_layers:
|
| 1193 |
-
x = conv(x)
|
| 1194 |
-
|
| 1195 |
-
return x
|
| 1196 |
-
|
| 1197 |
-
class GradMultiply(torch.autograd.Function):
|
| 1198 |
-
@staticmethod
|
| 1199 |
-
def forward(ctx, x, scale):
|
| 1200 |
-
ctx.scale = scale
|
| 1201 |
-
res = x.new(x)
|
| 1202 |
-
return res
|
| 1203 |
-
|
| 1204 |
-
@staticmethod
|
| 1205 |
-
def backward(ctx, grad):
|
| 1206 |
-
return grad * ctx.scale, None
|
| 1207 |
-
|
| 1208 |
-
class BaseFairseqModel(nn.Module):
|
| 1209 |
-
def __init__(self):
|
| 1210 |
-
super().__init__()
|
| 1211 |
-
self._is_generation_fast = False
|
| 1212 |
-
|
| 1213 |
-
def get_targets(self, sample, net_output):
|
| 1214 |
-
return sample["target"]
|
| 1215 |
-
|
| 1216 |
-
def extract_features(self, *args, **kwargs):
|
| 1217 |
-
return self(*args, **kwargs)
|
| 1218 |
-
|
| 1219 |
-
def load_state_dict(self, state_dict, strict=True, model_cfg = None, args = None):
|
| 1220 |
-
self.upgrade_state_dict(state_dict)
|
| 1221 |
-
new_state_dict = prune_state_dict(state_dict, model_cfg)
|
| 1222 |
-
return super().load_state_dict(new_state_dict, strict)
|
| 1223 |
-
|
| 1224 |
-
def upgrade_state_dict(self, state_dict):
|
| 1225 |
-
self.upgrade_state_dict_named(state_dict, "")
|
| 1226 |
-
|
| 1227 |
-
def upgrade_state_dict_named(self, state_dict, name):
|
| 1228 |
-
assert state_dict is not None
|
| 1229 |
-
|
| 1230 |
-
def do_upgrade(m, prefix):
|
| 1231 |
-
if len(prefix) > 0: prefix += "."
|
| 1232 |
-
|
| 1233 |
-
for n, c in m.named_children():
|
| 1234 |
-
name = prefix + n
|
| 1235 |
-
if hasattr(c, "upgrade_state_dict_named"): c.upgrade_state_dict_named(state_dict, name)
|
| 1236 |
-
elif hasattr(c, "upgrade_state_dict"): c.upgrade_state_dict(state_dict)
|
| 1237 |
-
do_upgrade(c, name)
|
| 1238 |
-
|
| 1239 |
-
do_upgrade(self, name)
|
| 1240 |
-
|
| 1241 |
-
def make_generation_fast_(self, **kwargs):
|
| 1242 |
-
if self._is_generation_fast: return
|
| 1243 |
-
self._is_generation_fast = True
|
| 1244 |
-
|
| 1245 |
-
def apply_remove_weight_norm(module):
|
| 1246 |
-
try:
|
| 1247 |
-
nn.utils.remove_weight_norm(module)
|
| 1248 |
-
except (AttributeError, ValueError):
|
| 1249 |
-
return
|
| 1250 |
-
|
| 1251 |
-
self.apply(apply_remove_weight_norm)
|
| 1252 |
-
|
| 1253 |
-
def apply_make_generation_fast_(module, prefix):
|
| 1254 |
-
if len(prefix) > 0: prefix += "."
|
| 1255 |
-
|
| 1256 |
-
base_func = BaseFairseqModel.make_generation_fast_
|
| 1257 |
-
for n, m in module.named_modules():
|
| 1258 |
-
if (m != self and hasattr(m, "make_generation_fast_") and m.make_generation_fast_.__func__ is not base_func): m.make_generation_fast_(name=prefix + n, **kwargs)
|
| 1259 |
-
|
| 1260 |
-
apply_make_generation_fast_(self, "")
|
| 1261 |
-
self.eval()
|
| 1262 |
-
|
| 1263 |
-
class HubertConfig:
|
| 1264 |
-
def __init__(self, _name, label_rate, encoder_layers_1, logit_temp_ctr, num_negatives, cross_sample_negatives, ctr_layers, extractor_mode = "default", encoder_layers = 12, encoder_embed_dim = 768, encoder_ffn_embed_dim = 3072, encoder_attention_heads = 12, activation_fn = "gelu", layer_type = "transformer", dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.0, encoder_layerdrop = 0.0, dropout_input = 0.0, dropout_features = 0.0, final_dim = 0, untie_final_proj = False, layer_norm_first = False, conv_feature_layers = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", conv_bias = False, logit_temp = 0.1, target_glu = False, feature_grad_mult = 1.0, mask_length = 10, mask_prob = 0.65, mask_selection = "static", mask_other = 0.0, no_mask_overlap = False, mask_min_space = 1, mask_channel_length = 10, mask_channel_prob = 0.0, mask_channel_selection = "static", mask_channel_other = 0.0, no_mask_channel_overlap = False, mask_channel_min_space = 1, conv_pos = 128, conv_pos_groups = 16, conv_pos_batch_norm = False, latent_temp = (2, 0.5, 0.999995), skip_masked = False, skip_nomask = False, checkpoint_activations = False, required_seq_len_multiple = 2, depthwise_conv_kernel_size = 31, attn_type = "", pos_enc_type = "abs", fp16 = False):
|
| 1265 |
-
self._name = _name
|
| 1266 |
-
self.label_rate = label_rate
|
| 1267 |
-
self.encoder_layers_1 = encoder_layers_1
|
| 1268 |
-
self.logit_temp_ctr = logit_temp_ctr
|
| 1269 |
-
self.num_negatives = num_negatives
|
| 1270 |
-
self.cross_sample_negatives = cross_sample_negatives
|
| 1271 |
-
self.ctr_layers = ctr_layers
|
| 1272 |
-
self.extractor_mode = extractor_mode
|
| 1273 |
-
self.encoder_layers = encoder_layers
|
| 1274 |
-
self.encoder_embed_dim = encoder_embed_dim
|
| 1275 |
-
self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
|
| 1276 |
-
self.encoder_attention_heads = encoder_attention_heads
|
| 1277 |
-
self.activation_fn = activation_fn
|
| 1278 |
-
self.layer_type = layer_type
|
| 1279 |
-
self.dropout = dropout
|
| 1280 |
-
self.attention_dropout = attention_dropout
|
| 1281 |
-
self.activation_dropout = activation_dropout
|
| 1282 |
-
self.encoder_layerdrop = encoder_layerdrop
|
| 1283 |
-
self.dropout_input = dropout_input
|
| 1284 |
-
self.dropout_features = dropout_features
|
| 1285 |
-
self.final_dim = final_dim
|
| 1286 |
-
self.untie_final_proj = untie_final_proj
|
| 1287 |
-
self.layer_norm_first = layer_norm_first
|
| 1288 |
-
self.conv_feature_layers = conv_feature_layers
|
| 1289 |
-
self.conv_bias = conv_bias
|
| 1290 |
-
self.logit_temp = logit_temp
|
| 1291 |
-
self.target_glu = target_glu
|
| 1292 |
-
self.feature_grad_mult = feature_grad_mult
|
| 1293 |
-
self.mask_length = mask_length
|
| 1294 |
-
self.mask_prob = mask_prob
|
| 1295 |
-
self.mask_selection = mask_selection
|
| 1296 |
-
self.mask_other = mask_other
|
| 1297 |
-
self.no_mask_overlap = no_mask_overlap
|
| 1298 |
-
self.mask_min_space = mask_min_space
|
| 1299 |
-
self.mask_channel_length = mask_channel_length
|
| 1300 |
-
self.mask_channel_prob = mask_channel_prob
|
| 1301 |
-
self.mask_channel_selection = mask_channel_selection
|
| 1302 |
-
self.mask_channel_other = mask_channel_other
|
| 1303 |
-
self.no_mask_channel_overlap = no_mask_channel_overlap
|
| 1304 |
-
self.mask_channel_min_space = mask_channel_min_space
|
| 1305 |
-
self.conv_pos = conv_pos
|
| 1306 |
-
self.conv_pos_groups = conv_pos_groups
|
| 1307 |
-
self.conv_pos_batch_norm = conv_pos_batch_norm
|
| 1308 |
-
self.latent_temp = latent_temp
|
| 1309 |
-
self.skip_masked = skip_masked
|
| 1310 |
-
self.skip_nomask = skip_nomask
|
| 1311 |
-
self.checkpoint_activations = checkpoint_activations
|
| 1312 |
-
self.required_seq_len_multiple = required_seq_len_multiple
|
| 1313 |
-
self.depthwise_conv_kernel_size = depthwise_conv_kernel_size
|
| 1314 |
-
self.attn_type = attn_type
|
| 1315 |
-
self.pos_enc_type = pos_enc_type
|
| 1316 |
-
self.fp16 = fp16
|
| 1317 |
-
|
| 1318 |
-
class Model_Config(dict):
|
| 1319 |
-
def __getattr__(*args):
|
| 1320 |
-
val = dict.get(*args)
|
| 1321 |
-
return Model_Config(val) if type(val) is dict else val
|
| 1322 |
-
|
| 1323 |
-
__setattr__ = dict.__setitem__
|
| 1324 |
-
__delattr__ = dict.__delitem__
|
| 1325 |
-
|
| 1326 |
-
class HubertModel(BaseFairseqModel):
|
| 1327 |
-
def __init__(self, cfg):
|
| 1328 |
-
super().__init__()
|
| 1329 |
-
feature_enc_layers = eval(cfg.conv_feature_layers)
|
| 1330 |
-
self.embed = feature_enc_layers[-1][0]
|
| 1331 |
-
self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers, dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias)
|
| 1332 |
-
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
|
| 1333 |
-
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / 16000
|
| 1334 |
-
self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None)
|
| 1335 |
-
self.mask_prob = cfg.mask_prob
|
| 1336 |
-
self.mask_selection = cfg.mask_selection
|
| 1337 |
-
self.mask_other = cfg.mask_other
|
| 1338 |
-
self.mask_length = cfg.mask_length
|
| 1339 |
-
self.no_mask_overlap = cfg.no_mask_overlap
|
| 1340 |
-
self.mask_min_space = cfg.mask_min_space
|
| 1341 |
-
self.mask_channel_prob = cfg.mask_channel_prob
|
| 1342 |
-
self.mask_channel_selection = cfg.mask_channel_selection
|
| 1343 |
-
self.mask_channel_other = cfg.mask_channel_other
|
| 1344 |
-
self.mask_channel_length = cfg.mask_channel_length
|
| 1345 |
-
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
| 1346 |
-
self.mask_channel_min_space = cfg.mask_channel_min_space
|
| 1347 |
-
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
| 1348 |
-
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
| 1349 |
-
self.feature_grad_mult = cfg.feature_grad_mult
|
| 1350 |
-
self.logit_temp = cfg.logit_temp
|
| 1351 |
-
self.skip_masked = cfg.skip_masked
|
| 1352 |
-
self.skip_nomask = cfg.skip_nomask
|
| 1353 |
-
final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
|
| 1354 |
-
self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())
|
| 1355 |
-
self.encoder = TransformerEncoder(cfg)
|
| 1356 |
-
self.layer_norm = LayerNorm(self.embed)
|
| 1357 |
-
self.target_glu = None
|
| 1358 |
-
if cfg.target_glu: self.target_glu = nn.Sequential(nn.Linear(final_dim, final_dim * 2), nn.GLU())
|
| 1359 |
-
self.untie_final_proj = cfg.untie_final_proj
|
| 1360 |
-
self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
|
| 1361 |
-
self.num_classes = [504]
|
| 1362 |
-
self.label_embs_concat = nn.Parameter(torch.FloatTensor(sum(self.num_classes), final_dim))
|
| 1363 |
-
nn.init.uniform_(self.label_embs_concat)
|
| 1364 |
-
|
| 1365 |
-
def upgrade_state_dict_named(self, state_dict, name):
|
| 1366 |
-
super().upgrade_state_dict_named(state_dict, name)
|
| 1367 |
-
return state_dict
|
| 1368 |
-
|
| 1369 |
-
def apply_mask(self, x, padding_mask, target_list):
|
| 1370 |
-
B, T, C = x.shape
|
| 1371 |
-
if self.mask_prob > 0:
|
| 1372 |
-
mask_indices = torch.from_numpy(compute_mask_indices((B, T), padding_mask, self.mask_prob, self.mask_length, self.mask_selection, self.mask_other, min_masks=2, no_overlap=self.no_mask_overlap, min_space=self.mask_min_space)).to(x.device)
|
| 1373 |
-
x[mask_indices] = self.mask_emb
|
| 1374 |
-
else: mask_indices = None
|
| 1375 |
-
|
| 1376 |
-
if self.mask_channel_prob > 0: x[(torch.from_numpy(compute_mask_indices((B, C), None, self.mask_channel_prob, self.mask_channel_length, self.mask_channel_selection, self.mask_channel_other, no_overlap=self.no_mask_channel_overlap, min_space=self.mask_channel_min_space)).to(x.device).unsqueeze(1).expand(-1, T, -1))] = 0
|
| 1377 |
-
return x, mask_indices
|
| 1378 |
-
|
| 1379 |
-
def compute_nce(self, x, pos, negs):
|
| 1380 |
-
neg_is_pos = (pos == negs).all(-1)
|
| 1381 |
-
logits = torch.cosine_similarity(x.float(), torch.cat([pos.unsqueeze(0), negs], dim=0).float(), dim=-1).type_as(x)
|
| 1382 |
-
logits /= self.logit_temp
|
| 1383 |
-
|
| 1384 |
-
if neg_is_pos.any(): logits[1:][neg_is_pos] = float("-inf")
|
| 1385 |
-
return logits.transpose(0, 1)
|
| 1386 |
-
|
| 1387 |
-
def forward_features(self, source):
|
| 1388 |
-
if self.feature_grad_mult > 0:
|
| 1389 |
-
features = self.feature_extractor(source)
|
| 1390 |
-
if self.feature_grad_mult != 1.0: features = GradMultiply.apply(features, self.feature_grad_mult)
|
| 1391 |
-
else:
|
| 1392 |
-
with torch.no_grad():
|
| 1393 |
-
features = self.feature_extractor(source)
|
| 1394 |
-
return features
|
| 1395 |
-
|
| 1396 |
-
def forward_targets(self, features, target_list):
|
| 1397 |
-
feat_tsz = features.size(2)
|
| 1398 |
-
targ_tsz = min([t.size(1) for t in target_list])
|
| 1399 |
-
|
| 1400 |
-
if self.feat2tar_ratio * feat_tsz > targ_tsz:
|
| 1401 |
-
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
|
| 1402 |
-
features = features[..., :feat_tsz]
|
| 1403 |
-
|
| 1404 |
-
return features, [t[:, (torch.arange(feat_tsz).float() * self.feat2tar_ratio).long()] for t in target_list]
|
| 1405 |
-
|
| 1406 |
-
def forward_padding_mask(self, features, padding_mask):
|
| 1407 |
-
extra = padding_mask.size(1) % features.size(1)
|
| 1408 |
-
if extra > 0: padding_mask = padding_mask[:, :-extra]
|
| 1409 |
-
|
| 1410 |
-
return padding_mask.view(padding_mask.size(0), features.size(1), -1).all(-1)
|
| 1411 |
-
|
| 1412 |
-
def forward(self, source, target_list = None, padding_mask = None, mask = True, features_only = False, output_layer = None):
|
| 1413 |
-
features = self.forward_features(source)
|
| 1414 |
-
if target_list is not None: features, target_list = self.forward_targets(features, target_list)
|
| 1415 |
-
|
| 1416 |
-
features_pen = features.float().pow(2).mean()
|
| 1417 |
-
|
| 1418 |
-
features = self.layer_norm(features.transpose(1, 2))
|
| 1419 |
-
unmasked_features = features.clone()
|
| 1420 |
-
|
| 1421 |
-
if padding_mask is not None: padding_mask = self.forward_padding_mask(features, padding_mask)
|
| 1422 |
-
if self.post_extract_proj is not None: features = self.post_extract_proj(features)
|
| 1423 |
-
|
| 1424 |
-
features = self.dropout_input(features)
|
| 1425 |
-
unmasked_features = self.dropout_features(unmasked_features)
|
| 1426 |
-
|
| 1427 |
-
if mask: x, mask_indices = self.apply_mask(features, padding_mask, target_list)
|
| 1428 |
-
else: x, mask_indices = features, None
|
| 1429 |
-
|
| 1430 |
-
x, _ = self.encoder(x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1)
|
| 1431 |
-
if features_only: return {"x": x, "padding_mask": padding_mask, "features": features}
|
| 1432 |
-
|
| 1433 |
-
def compute_pred(proj_x, target, label_embs):
|
| 1434 |
-
y = torch.index_select(label_embs, 0, target.long())
|
| 1435 |
-
negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
|
| 1436 |
-
|
| 1437 |
-
if self.target_glu:
|
| 1438 |
-
y = self.target_glu(y)
|
| 1439 |
-
negs = self.target_glu(negs)
|
| 1440 |
-
|
| 1441 |
-
return self.compute_nce(proj_x, y, negs)
|
| 1442 |
-
|
| 1443 |
-
label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
|
| 1444 |
-
|
| 1445 |
-
if not self.skip_masked:
|
| 1446 |
-
masked_indices = torch.logical_and(~padding_mask, mask_indices)
|
| 1447 |
-
proj_x_m = self.final_proj(x[masked_indices])
|
| 1448 |
-
logit_m_list = [compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) for i, (proj_x_m, t) in enumerate(zip(proj_x_m.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_m for _ in range(len(target_list))], target_list))]
|
| 1449 |
-
else: logit_m_list = [None for _ in target_list]
|
| 1450 |
-
|
| 1451 |
-
if not self.skip_nomask:
|
| 1452 |
-
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
|
| 1453 |
-
proj_x_u = self.final_proj(x[nomask_indices])
|
| 1454 |
-
logit_u_list = [compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) for i, (proj_x_u, t) in enumerate(zip(proj_x_u.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_u for _ in range(len(target_list))], target_list))]
|
| 1455 |
-
else: logit_u_list = [None for _ in target_list]
|
| 1456 |
-
|
| 1457 |
-
return {"logit_m_list": logit_m_list, "logit_u_list": logit_u_list, "padding_mask": padding_mask, "features_pen": features_pen}
|
| 1458 |
-
|
| 1459 |
-
def extract_features(self, source, padding_mask = None, mask = False, ret_conv = False, output_layer = None):
|
| 1460 |
-
res = self.forward(source, padding_mask=padding_mask, mask=mask, features_only=True, output_layer=output_layer)
|
| 1461 |
-
return res["features"] if ret_conv else res["x"], res["padding_mask"]
|
| 1462 |
-
|
| 1463 |
-
def get_logits(self, net_output, is_masked=True):
|
| 1464 |
-
return [x.float() for x in (net_output["logit_m_list"] if is_masked else net_output["logit_u_list"]) if x is not None]
|
| 1465 |
-
|
| 1466 |
-
def get_targets(self, net_output, is_masked=True):
|
| 1467 |
-
return [x.new_zeros(x.size(0), dtype=torch.long) for x in self.get_logits(net_output, is_masked)]
|
| 1468 |
-
|
| 1469 |
-
def get_extra_losses(self, net_output):
|
| 1470 |
-
extra_losses, names = [], []
|
| 1471 |
-
|
| 1472 |
-
if "features_pen" in net_output:
|
| 1473 |
-
extra_losses.append(net_output["features_pen"])
|
| 1474 |
-
names.append("features_pen")
|
| 1475 |
-
|
| 1476 |
-
return extra_losses, names
|
| 1477 |
-
|
| 1478 |
-
def remove_pretraining_modules(self):
|
| 1479 |
-
self.target_glu = None
|
| 1480 |
-
self.final_proj = None
|
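For orientation, below is a minimal, hypothetical sketch of how the deleted HubertModel above could be driven for feature extraction. It is not part of the removed code: the import path, the HuBERT-Base-style config values, and the waveform shape are illustrative assumptions only.

```python
# Illustrative sketch only -- assumes the deleted module were still importable
# as main.library.architectures.fairseq and that its remaining helpers resolve.
import torch

from main.library.architectures.fairseq import HubertConfig, HubertModel  # hypothetical import

# The first seven arguments are required positionally by the deleted HubertConfig
# signature; the values here are assumed, HuBERT-Base-like settings.
cfg = HubertConfig(
    _name="hubert",
    label_rate=50,             # assumed frame-label rate for 16 kHz input
    encoder_layers_1=0,
    logit_temp_ctr=0.1,
    num_negatives=100,
    cross_sample_negatives=0,
    ctr_layers=[-6],
)

model = HubertModel(cfg).eval()

# One second of 16 kHz audio, batch size 1: shape (batch, samples).
waveform = torch.randn(1, 16000)

with torch.no_grad():
    # extract_features() runs the convolutional front end plus the Transformer
    # encoder and returns (features, padding_mask).
    features, padding_mask = model.extract_features(waveform, mask=False, output_layer=9)

print(features.shape)  # roughly torch.Size([1, 49, 768]) for one second of audio
```

In the deleted convert/extract pipelines a trained checkpoint would normally be loaded into such a model before calling extract_features; the random initialization above is only there to keep the sketch self-contained.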
main/library/architectures/mdx_separator.py
DELETED
|
@@ -1,320 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import onnx
|
| 4 |
-
import torch
|
| 5 |
-
import platform
|
| 6 |
-
import onnx2torch
|
| 7 |
-
|
| 8 |
-
import numpy as np
|
| 9 |
-
import onnxruntime as ort
|
| 10 |
-
|
| 11 |
-
from tqdm import tqdm
|
| 12 |
-
|
| 13 |
-
sys.path.append(os.getcwd())
|
| 14 |
-
|
| 15 |
-
from main.configs.config import Config
|
| 16 |
-
from main.library.uvr5_separator import spec_utils
|
| 17 |
-
from main.library.uvr5_separator.common_separator import CommonSeparator
|
| 18 |
-
|
| 19 |
-
translations = Config().translations
|
| 20 |
-
|
| 21 |
-
class MDXSeparator(CommonSeparator):
|
| 22 |
-
def __init__(self, common_config, arch_config):
|
| 23 |
-
super().__init__(config=common_config)
|
| 24 |
-
self.segment_size = arch_config.get("segment_size")
|
| 25 |
-
self.overlap = arch_config.get("overlap")
|
| 26 |
-
self.batch_size = arch_config.get("batch_size", 1)
|
| 27 |
-
self.hop_length = arch_config.get("hop_length")
|
| 28 |
-
self.enable_denoise = arch_config.get("enable_denoise")
|
| 29 |
-
self.logger.debug(translations["mdx_info"].format(batch_size=self.batch_size, segment_size=self.segment_size))
|
| 30 |
-
self.logger.debug(translations["mdx_info_2"].format(overlap=self.overlap, hop_length=self.hop_length, enable_denoise=self.enable_denoise))
|
| 31 |
-
self.compensate = self.model_data["compensate"]
|
| 32 |
-
self.dim_f = self.model_data["mdx_dim_f_set"]
|
| 33 |
-
self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
|
| 34 |
-
self.n_fft = self.model_data["mdx_n_fft_scale_set"]
|
| 35 |
-
self.config_yaml = self.model_data.get("config_yaml", None)
|
| 36 |
-
self.logger.debug(f"{translations['mdx_info_3']}: compensate = {self.compensate}, dim_f = {self.dim_f}, dim_t = {self.dim_t}, n_fft = {self.n_fft}")
|
| 37 |
-
self.logger.debug(f"{translations['mdx_info_3']}: config_yaml = {self.config_yaml}")
|
| 38 |
-
self.load_model()
|
| 39 |
-
self.n_bins = 0
|
| 40 |
-
self.trim = 0
|
| 41 |
-
self.chunk_size = 0
|
| 42 |
-
self.gen_size = 0
|
| 43 |
-
self.stft = None
|
| 44 |
-
self.primary_source = None
|
| 45 |
-
self.secondary_source = None
|
| 46 |
-
self.audio_file_path = None
|
| 47 |
-
self.audio_file_base = None
|
| 48 |
-
|
| 49 |
-
def load_model(self):
|
| 50 |
-
self.logger.debug(translations["load_model_onnx"])
|
| 51 |
-
|
| 52 |
-
if self.segment_size == self.dim_t:
|
| 53 |
-
ort_session_options = ort.SessionOptions()
|
| 54 |
-
ort_session_options.log_severity_level = 3 if self.log_level > 10 else 0
|
| 55 |
-
ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
|
| 56 |
-
self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
|
| 57 |
-
self.logger.debug(translations["load_model_onnx_success"])
|
| 58 |
-
else:
|
| 59 |
-
self.model_run = onnx2torch.convert(onnx.load(self.model_path)) if platform.system() == 'Windows' else onnx2torch.convert(self.model_path)
|
| 60 |
-
self.model_run.to(self.torch_device).eval()
|
| 61 |
-
self.logger.debug(translations["onnx_to_pytorch"])
|
| 62 |
-
|
| 63 |
-
def separate(self, audio_file_path):
|
| 64 |
-
self.audio_file_path = audio_file_path
|
| 65 |
-
self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
|
| 66 |
-
self.logger.debug(translations["mix"].format(audio_file_path=self.audio_file_path))
|
| 67 |
-
mix = self.prepare_mix(self.audio_file_path)
|
| 68 |
-
self.logger.debug(translations["normalization_demix"])
|
| 69 |
-
mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold)
|
| 70 |
-
source = self.demix(mix)
|
| 71 |
-
self.logger.debug(translations["mix_success"])
|
| 72 |
-
output_files = []
|
| 73 |
-
self.logger.debug(translations["process_output_file"])
|
| 74 |
-
|
| 75 |
-
if not isinstance(self.primary_source, np.ndarray):
|
| 76 |
-
self.logger.debug(translations["primary_source"])
|
| 77 |
-
self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold).T
|
| 78 |
-
|
| 79 |
-
if not isinstance(self.secondary_source, np.ndarray):
|
| 80 |
-
self.logger.debug(translations["secondary_source"])
|
| 81 |
-
raw_mix = self.demix(mix, is_match_mix=True)
|
| 82 |
-
|
| 83 |
-
if self.invert_using_spec:
|
| 84 |
-
self.logger.debug(translations["invert_using_spec"])
|
| 85 |
-
self.secondary_source = spec_utils.invert_stem(raw_mix, source)
|
| 86 |
-
else:
|
| 87 |
-
self.logger.debug(translations["invert_using_spec_2"])
|
| 88 |
-
self.secondary_source = mix.T - source.T
|
| 89 |
-
|
| 90 |
-
if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
|
| 91 |
-
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
|
| 92 |
-
self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.secondary_stem_name, stem_output_path=self.secondary_stem_output_path))
|
| 93 |
-
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
|
| 94 |
-
output_files.append(self.secondary_stem_output_path)
|
| 95 |
-
|
| 96 |
-
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
|
| 97 |
-
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
|
| 98 |
-
if not isinstance(self.primary_source, np.ndarray): self.primary_source = source.T
|
| 99 |
-
|
| 100 |
-
self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.primary_stem_name, stem_output_path=self.primary_stem_output_path))
|
| 101 |
-
self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
|
| 102 |
-
output_files.append(self.primary_stem_output_path)
|
| 103 |
-
|
| 104 |
-
return output_files
|
| 105 |
-
|
| 106 |
-
def initialize_model_settings(self):
|
| 107 |
-
self.logger.debug(translations["starting_model"])
|
| 108 |
-
|
| 109 |
-
self.n_bins = self.n_fft // 2 + 1
|
| 110 |
-
self.trim = self.n_fft // 2
|
| 111 |
-
|
| 112 |
-
self.chunk_size = self.hop_length * (self.segment_size - 1)
|
| 113 |
-
self.gen_size = self.chunk_size - 2 * self.trim
|
| 114 |
-
|
| 115 |
-
self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)
|
| 116 |
-
|
| 117 |
-
self.logger.debug(f"{translations['input_info']}: n_fft = {self.n_fft} hop_length = {self.hop_length} dim_f = {self.dim_f}")
|
| 118 |
-
self.logger.debug(f"{translations['model_settings']}: n_bins = {self.n_bins}, Trim = {self.trim}, chunk_size = {self.chunk_size}, gen_size = {self.gen_size}")
|
| 119 |
-
|
| 120 |
-
def initialize_mix(self, mix, is_ckpt=False):
|
| 121 |
-
self.logger.debug(translations["initialize_mix"].format(is_ckpt=is_ckpt, shape=mix.shape))
|
| 122 |
-
|
| 123 |
-
if mix.shape[0] != 2:
|
| 124 |
-
error_message = translations["!=2"].format(shape=mix.shape[0])
|
| 125 |
-
self.logger.error(error_message)
|
| 126 |
-
raise ValueError(error_message)
|
| 127 |
-
|
| 128 |
-
if is_ckpt:
|
| 129 |
-
self.logger.debug(translations["process_check"])
|
| 130 |
-
pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
|
| 131 |
-
self.logger.debug(f"{translations['cache']}: {pad}")
|
| 132 |
-
|
| 133 |
-
mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
|
| 134 |
-
|
| 135 |
-
num_chunks = mixture.shape[-1] // self.gen_size
|
| 136 |
-
self.logger.debug(translations["shape"].format(shape=mixture.shape, num_chunks=num_chunks))
|
| 137 |
-
|
| 138 |
-
mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
|
| 139 |
-
else:
|
| 140 |
-
self.logger.debug(translations["process_no_check"])
|
| 141 |
-
mix_waves = []
|
| 142 |
-
n_sample = mix.shape[1]
|
| 143 |
-
|
| 144 |
-
pad = self.gen_size - n_sample % self.gen_size
|
| 145 |
-
self.logger.debug(translations["n_sample_or_pad"].format(n_sample=n_sample, pad=pad))
|
| 146 |
-
|
| 147 |
-
mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
|
| 148 |
-
self.logger.debug(f"{translations['shape_2']}: {mix_p.shape}")
|
| 149 |
-
|
| 150 |
-
i = 0
|
| 151 |
-
while i < n_sample + pad:
|
| 152 |
-
mix_waves.append(np.array(mix_p[:, i : i + self.chunk_size]))
|
| 153 |
-
|
| 154 |
-
self.logger.debug(translations["process_part"].format(mix_waves=len(mix_waves), i=i, ii=i + self.chunk_size))
|
| 155 |
-
i += self.gen_size
|
| 156 |
-
|
| 157 |
-
mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
|
| 158 |
-
self.logger.debug(translations["mix_waves_to_tensor"].format(shape=mix_waves_tensor.shape))
|
| 159 |
-
|
| 160 |
-
return mix_waves_tensor, pad
|
| 161 |
-
|
| 162 |
-
def demix(self, mix, is_match_mix=False):
|
| 163 |
-
self.logger.debug(f"{translations['demix_is_match_mix']}: {is_match_mix}...")
|
| 164 |
-
self.initialize_model_settings()
|
| 165 |
-
self.logger.debug(f"{translations['mix_shape']}: {mix.shape}")
|
| 166 |
-
tar_waves_ = []
|
| 167 |
-
|
| 168 |
-
if is_match_mix:
|
| 169 |
-
chunk_size = self.hop_length * (self.segment_size - 1)
|
| 170 |
-
overlap = 0.02
|
| 171 |
-
self.logger.debug(translations["chunk_size_or_overlap"].format(chunk_size=chunk_size, overlap=overlap))
|
| 172 |
-
else:
|
| 173 |
-
chunk_size = self.chunk_size
|
| 174 |
-
overlap = self.overlap
|
| 175 |
-
self.logger.debug(translations["chunk_size_or_overlap_standard"].format(chunk_size=chunk_size, overlap=overlap))
|
| 176 |
-
|
| 177 |
-
gen_size = chunk_size - 2 * self.trim
|
| 178 |
-
self.logger.debug(f"{translations['calc_size']}: {gen_size}")
|
| 179 |
-
|
| 180 |
-
mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, gen_size + self.trim - ((mix.shape[-1]) % gen_size)), dtype="float32")), 1)
|
| 181 |
-
self.logger.debug(f"{translations['mix_cache']}: {mixture.shape}")
|
| 182 |
-
|
| 183 |
-
step = int((1 - overlap) * chunk_size)
|
| 184 |
-
self.logger.debug(translations["step_or_overlap"].format(step=step, overlap=overlap))
|
| 185 |
-
|
| 186 |
-
result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
|
| 187 |
-
divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
|
| 188 |
-
|
| 189 |
-
total = 0
|
| 190 |
-
total_chunks = (mixture.shape[-1] + step - 1) // step
|
| 191 |
-
self.logger.debug(f"{translations['all_process_part']}: {total_chunks}")
|
| 192 |
-
|
| 193 |
-
for i in tqdm(range(0, mixture.shape[-1], step), ncols=100, unit="f"):
|
| 194 |
-
total += 1
|
| 195 |
-
start = i
|
| 196 |
-
end = min(i + chunk_size, mixture.shape[-1])
|
| 197 |
-
self.logger.debug(translations["process_part_2"].format(total=total, total_chunks=total_chunks, start=start, end=end))
|
| 198 |
-
|
| 199 |
-
chunk_size_actual = end - start
|
| 200 |
-
window = None
|
| 201 |
-
|
| 202 |
-
if overlap != 0:
|
| 203 |
-
window = np.hanning(chunk_size_actual)
|
| 204 |
-
window = np.tile(window[None, None, :], (1, 2, 1))
|
| 205 |
-
self.logger.debug(translations["window"])
|
| 206 |
-
|
| 207 |
-
mix_part_ = mixture[:, start:end]
|
| 208 |
-
|
| 209 |
-
if end != i + chunk_size:
|
| 210 |
-
pad_size = (i + chunk_size) - end
|
| 211 |
-
mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)
|
| 212 |
-
|
| 213 |
-
mix_waves = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device).split(self.batch_size)
|
| 214 |
-
|
| 215 |
-
total_batches = len(mix_waves)
|
| 216 |
-
self.logger.debug(f"{translations['mix_or_batch']}: {total_batches}")
|
| 217 |
-
|
| 218 |
-
with torch.no_grad():
|
| 219 |
-
batches_processed = 0
|
| 220 |
-
|
| 221 |
-
for mix_wave in mix_waves:
|
| 222 |
-
batches_processed += 1
|
| 223 |
-
self.logger.debug(f"{translations['mix_wave']} {batches_processed}/{total_batches}")
|
| 224 |
-
|
| 225 |
-
tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)
|
| 226 |
-
|
| 227 |
-
if window is not None:
|
| 228 |
-
tar_waves[..., :chunk_size_actual] *= window
|
| 229 |
-
divider[..., start:end] += window
|
| 230 |
-
else: divider[..., start:end] += 1
|
| 231 |
-
|
| 232 |
-
result[..., start:end] += tar_waves[..., : end - start]
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
self.logger.debug(translations["normalization_2"])
|
| 236 |
-
tar_waves = result / divider
|
| 237 |
-
tar_waves_.append(tar_waves)
|
| 238 |
-
|
| 239 |
-
tar_waves = np.concatenate(np.vstack(tar_waves_)[:, :, self.trim : -self.trim], axis=-1)[:, : mix.shape[-1]]
|
| 240 |
-
|
| 241 |
-
source = tar_waves[:, 0:None]
|
| 242 |
-
self.logger.debug(f"{translations['tar_waves']}: {tar_waves.shape}")
|
| 243 |
-
|
| 244 |
-
if not is_match_mix:
|
| 245 |
-
source *= self.compensate
|
| 246 |
-
self.logger.debug(translations["mix_match"])
|
| 247 |
-
|
| 248 |
-
self.logger.debug(translations["mix_success"])
|
| 249 |
-
return source
|
| 250 |
-
|
| 251 |
-
def run_model(self, mix, is_match_mix=False):
|
| 252 |
-
spek = self.stft(mix.to(self.torch_device))
|
| 253 |
-
self.logger.debug(translations["stft_2"].format(shape=spek.shape))
|
| 254 |
-
|
| 255 |
-
spek[:, :, :3, :] *= 0
|
| 256 |
-
|
| 257 |
-
if is_match_mix:
|
| 258 |
-
spec_pred = spek.cpu().numpy()
|
| 259 |
-
self.logger.debug(translations["is_match_mix"])
|
| 260 |
-
else:
|
| 261 |
-
if self.enable_denoise:
|
| 262 |
-
spec_pred_neg = self.model_run(-spek)
|
| 263 |
-
spec_pred_pos = self.model_run(spek)
|
| 264 |
-
spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)
|
| 265 |
-
self.logger.debug(translations["enable_denoise"])
|
| 266 |
-
else:
|
| 267 |
-
spec_pred = self.model_run(spek)
|
| 268 |
-
self.logger.debug(translations["no_denoise"])
|
| 269 |
-
|
| 270 |
-
result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
|
| 271 |
-
self.logger.debug(f"{translations['stft']}: {result.shape}")
|
| 272 |
-
|
| 273 |
-
return result
|
| 274 |
-
|
| 275 |
-
class STFT:
|
| 276 |
-
def __init__(self, logger, n_fft, hop_length, dim_f, device):
|
| 277 |
-
self.logger = logger
|
| 278 |
-
self.n_fft = n_fft
|
| 279 |
-
self.hop_length = hop_length
|
| 280 |
-
self.dim_f = dim_f
|
| 281 |
-
self.device = device
|
| 282 |
-
self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)
|
| 283 |
-
|
| 284 |
-
def __call__(self, input_tensor):
|
| 285 |
-
is_non_standard_device = input_tensor.device.type not in ("cuda", "cpu")
|
| 286 |
-
|
| 287 |
-
if is_non_standard_device: input_tensor = input_tensor.cpu()
|
| 288 |
-
|
| 289 |
-
batch_dimensions = input_tensor.shape[:-2]
|
| 290 |
-
channel_dim, time_dim = input_tensor.shape[-2:]
|
| 291 |
-
|
| 292 |
-
permuted_stft_output = torch.stft(input_tensor.reshape([-1, time_dim]), n_fft=self.n_fft, hop_length=self.hop_length, window=self.hann_window.to(input_tensor.device), center=True, return_complex=False).permute([0, 3, 1, 2])
|
| 293 |
-
final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]).reshape([*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]])
|
| 294 |
-
|
| 295 |
-
if is_non_standard_device: final_output = final_output.to(self.device)
|
| 296 |
-
return final_output[..., : self.dim_f, :]
|
| 297 |
-
|
| 298 |
-
def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
|
| 299 |
-
return torch.cat([input_tensor, torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]).to(input_tensor.device)], -2)
|
| 300 |
-
|
| 301 |
-
def calculate_inverse_dimensions(self, input_tensor):
|
| 302 |
-
channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]
|
| 303 |
-
|
| 304 |
-
return input_tensor.shape[:-3], channel_dim, freq_dim, time_dim, self.n_fft // 2 + 1
|
| 305 |
-
|
| 306 |
-
def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
|
| 307 |
-
permuted_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim]).reshape([-1, 2, num_freq_bins, time_dim]).permute([0, 2, 3, 1])
|
| 308 |
-
|
| 309 |
-
return permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j
|
| 310 |
-
|
| 311 |
-
def inverse(self, input_tensor):
|
| 312 |
-
is_non_standard_device = input_tensor.device.type not in ("cuda", "cpu")
|
| 313 |
-
if is_non_standard_device: input_tensor = input_tensor.cpu()
|
| 314 |
-
|
| 315 |
-
batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)
|
| 316 |
-
final_output = torch.istft(self.prepare_for_istft(self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins), batch_dimensions, channel_dim, num_freq_bins, time_dim), n_fft=self.n_fft, hop_length=self.hop_length, window=self.hann_window.to(input_tensor.device), center=True).reshape([*batch_dimensions, 2, -1])
|
| 317 |
-
|
| 318 |
-
if is_non_standard_device: final_output = final_output.to(self.device)
|
| 319 |
-
|
| 320 |
-
return final_output
|
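Likewise, a small hypothetical sketch of the STFT helper defined at the end of the deleted mdx_separator.py. The n_fft, hop_length and dim_f values below are assumptions chosen only to illustrate the forward/inverse round trip, not the real per-model settings read from model_data.

```python
# Illustrative sketch only -- assumes the deleted STFT class were importable
# from main.library.architectures.mdx_separator.
import logging

import torch

from main.library.architectures.mdx_separator import STFT  # hypothetical import

logger = logging.getLogger(__name__)

# Assumed MDX-Net-style settings for demonstration purposes.
stft = STFT(logger, n_fft=6144, hop_length=1024, dim_f=2048, device=torch.device("cpu"))

# A stereo clip: (batch, channels=2, samples).
wave = torch.randn(1, 2, 44100)

# The forward transform packs real/imaginary parts into the channel axis and
# keeps only the first dim_f frequency bins.
spec = stft(wave)           # -> (1, 4, 2048, frames)

# The inverse pads the frequency axis back to n_fft // 2 + 1 bins and runs istft.
recon = stft.inverse(spec)  # -> (1, 2, samples rounded to a whole number of hops)

print(spec.shape, recon.shape)
```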
main/library/audioldm2/models.py
DELETED
@@ -1,330 +0,0 @@
import os
import sys
import torch
import librosa

import numpy as np
import torch.nn.functional as F

from scipy.signal import get_window
from librosa.util import pad_center
from diffusers import DDIMScheduler, AudioLDM2Pipeline
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from transformers import RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer

sys.path.append(os.getcwd())

from main.configs.config import Config
from main.library.utils import check_audioldm2

config = Config()

class Pipeline(torch.nn.Module):
    def __init__(self, model_id, device, double_precision = False, token = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_id = model_id
        self.device = device
        self.double_precision = double_precision
        self.token = token

    def load_scheduler(self):
        pass

    def get_melspectrogram(self):
        pass

    def vae_encode(self, x):
        pass

    def vae_decode(self, x):
        pass

    def decode_to_mel(self, x):
        pass

    def setup_extra_inputs(self, *args, **kwargs):
        pass

    def encode_text(self, prompts, **kwargs):
        pass

    def get_variance(self, timestep, prev_timestep):
        pass

    def get_alpha_prod_t_prev(self, prev_timestep):
        pass

    def get_noise_shape(self, x0, num_steps):
        return (num_steps, self.model.unet.config.in_channels, x0.shape[-2], x0.shape[-1])

    def sample_xts_from_x0(self, x0, num_inference_steps = 50):
        alpha_bar = self.model.scheduler.alphas_cumprod
        sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
        timesteps = self.model.scheduler.timesteps.to(self.device)
        t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
        xts = torch.zeros(self.get_noise_shape(x0, num_inference_steps + 1)).to(x0.device)
        xts[0] = x0

        for t in reversed(timesteps):
            idx = num_inference_steps - t_to_idx[int(t)]
            xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]

        return xts

    def get_zs_from_xts(self, xt, xtm1, noise_pred, t, eta = 0, numerical_fix = True, **kwargs):
        alpha_bar = self.model.scheduler.alphas_cumprod

        if self.model.scheduler.config.prediction_type == 'epsilon': pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
        elif self.model.scheduler.config.prediction_type == 'v_prediction': pred_original_sample = (alpha_bar[t] ** 0.5) * xt - ((1 - alpha_bar[t]) ** 0.5) * noise_pred

        prev_timestep = t - self.model.scheduler.config.num_train_timesteps // self.model.scheduler.num_inference_steps
        alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
        variance = self.get_variance(t, prev_timestep)

        if self.model.scheduler.config.prediction_type == 'epsilon': radom_noise_pred = noise_pred
        elif self.model.scheduler.config.prediction_type == 'v_prediction': radom_noise_pred = (alpha_bar[t] ** 0.5) * noise_pred + ((1 - alpha_bar[t]) ** 0.5) * xt

        mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + ((1 - alpha_prod_t_prev - eta * variance) ** (0.5) * radom_noise_pred)
        z = (xtm1 - mu_xt) / (eta * variance ** 0.5)

        if numerical_fix: xtm1 = mu_xt + (eta * variance ** 0.5)*z
        return z, xtm1, None

    def reverse_step_with_custom_noise(self, model_output, timestep, sample, variance_noise = None, eta = 0, **kwargs):
        prev_timestep = timestep - self.model.scheduler.config.num_train_timesteps // self.model.scheduler.num_inference_steps
        alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
        beta_prod_t = 1 - alpha_prod_t

        if self.model.scheduler.config.prediction_type == 'epsilon': pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.model.scheduler.config.prediction_type == 'v_prediction': pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output

        variance = self.get_variance(timestep, prev_timestep)

        if self.model.scheduler.config.prediction_type == 'epsilon': model_output_direction = model_output
        elif self.model.scheduler.config.prediction_type == 'v_prediction': model_output_direction = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample

        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + ((1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction)

        if eta > 0:
            if variance_noise is None: variance_noise = torch.randn(model_output.shape, device=self.device)
            prev_sample = prev_sample + (eta * variance ** (0.5) * variance_noise)

        return prev_sample

    def unet_forward(self, sample, timestep, encoder_hidden_states, class_labels = None, timestep_cond = None, attention_mask = None, cross_attention_kwargs = None, added_cond_kwargs = None, down_block_additional_residuals = None, mid_block_additional_residual = None, encoder_attention_mask = None, replace_h_space = None, replace_skip_conns = None, return_dict = True, zero_out_resconns = None):
        pass

class STFT(torch.nn.Module):
    def __init__(self, fft_size, hop_size, window_size, window_type="hann"):
        super().__init__()
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.window_size = window_size
        self.window_type = window_type

        scale = fft_size / hop_size
        fourier_basis = np.fft.fft(np.eye(fft_size))

        cutoff = fft_size // 2 + 1
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])])

        self.forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        self.inverse_basis = torch.FloatTensor(np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window_type:
            assert fft_size >= window_size

            fft_window = torch.from_numpy(pad_center(get_window(window_type, window_size, fftbins=True), size=fft_size)).float()
            self.forward_basis *= fft_window
            self.inverse_basis *= fft_window

        if not hasattr(self, "forward_basis"): self.register_buffer("forward_basis", self.forward_basis)
        if not hasattr(self, "inverse_basis"): self.register_buffer("inverse_basis", self.inverse_basis)

    def transform(self, signal):
        batch_size, num_samples = signal.shape
        transformed_signal = F.conv1d(F.pad(signal.view(batch_size, 1, num_samples).unsqueeze(1), (self.fft_size // 2, self.fft_size // 2, 0, 0), mode="reflect").squeeze(1), self.forward_basis, stride=self.hop_size, padding=0).cpu()

        cutoff = self.fft_size // 2 + 1
        real_part, imag_part = transformed_signal[:, :cutoff, :], transformed_signal[:, cutoff:, :]

        return torch.sqrt(real_part ** 2 + imag_part ** 2), torch.atan2(imag_part, real_part)

class MelSpectrogramProcessor(torch.nn.Module):
    def __init__(self, fft_size, hop_size, window_size, num_mel_bins, sample_rate, fmin, fmax):
        super().__init__()
        self.num_mel_bins = num_mel_bins
        self.sample_rate = sample_rate
        self.stft_processor = STFT(fft_size, hop_size, window_size)
        self.register_buffer("mel_filter", torch.from_numpy(librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mel_bins, fmin=fmin, fmax=fmax)).float())

    def compute_mel_spectrogram(self, waveform, normalization_fn=torch.log):
        assert torch.min(waveform) >= -1
        assert torch.max(waveform) <= 1

        magnitudes, _ = self.stft_processor.transform(waveform)
        return normalization_fn(torch.clamp(torch.matmul(self.mel_filter, magnitudes), min=1e-5))

class AudioLDM2(Pipeline):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=True, torch_dtype=torch.float16 if config.is_half else torch.float32).to(self.device)

    def load_scheduler(self):
        self.model.scheduler = DDIMScheduler.from_pretrained(self.model_id, subfolder="scheduler")

    def get_melspectrogram(self):
        return MelSpectrogramProcessor(fft_size=1024, hop_size=160, window_size=1024, num_mel_bins=64, sample_rate=16000, fmin=0, fmax=8000)

    def vae_encode(self, x):
        if x.shape[2] % 4: x = F.pad(x, (0, 0, 4 - (x.shape[2] % 4), 0))
        output = (self.model.vae.encode(x.half() if config.is_half else x.float()).latent_dist.mode() * self.model.vae.config.scaling_factor)
        return output.half() if config.is_half else output.float()

    def vae_decode(self, x):
        return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample

    def decode_to_mel(self, x):
        tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().to(torch.float16 if config.is_half else torch.float32)).detach()

        if len(tmp.shape) == 1: tmp = tmp.unsqueeze(0)
        return tmp

    def encode_text(self, prompts, negative = False, save_compute = False, cond_length = 0, **kwargs):
        tokenizers, text_encoders = [self.model.tokenizer, self.model.tokenizer_2], [self.model.text_encoder, self.model.text_encoder_2]
        prompt_embeds_list, attention_mask_list = [], []

        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(prompts, padding="max_length" if (save_compute and negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True, max_length=tokenizer.model_max_length if (not save_compute) or ((not negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))) else cond_length, truncation=True, return_tensors="pt")
            text_input_ids = text_inputs.input_ids

            attention_mask = text_inputs.attention_mask
            untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1])

            text_input_ids = text_input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)

            with torch.no_grad():
                if text_encoder.config.model_type == "clap":
                    prompt_embeds = text_encoder.get_text_features(text_input_ids, attention_mask=attention_mask)
                    prompt_embeds = prompt_embeds[:, None, :]
                    attention_mask = attention_mask.new_ones((len(prompts), 1))
                else: prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask)[0]

            prompt_embeds_list.append(prompt_embeds)
            attention_mask_list.append(attention_mask)

        projection_output = self.model.projection_model(hidden_states=prompt_embeds_list[0], hidden_states_1=prompt_embeds_list[1], attention_mask=attention_mask_list[0], attention_mask_1=attention_mask_list[1])
        generated_prompt_embeds = self.model.generate_language_model(projection_output.hidden_states, attention_mask=projection_output.attention_mask, max_new_tokens=None)
        prompt_embeds = prompt_embeds.to(dtype=self.model.text_encoder_2.dtype, device=self.device)
        return generated_prompt_embeds.to(dtype=self.model.language_model.dtype, device=self.device), prompt_embeds, (attention_mask.to(device=self.device) if attention_mask is not None else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=self.device))

    def get_variance(self, timestep, prev_timestep):
        alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
        return ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) * (1 - alpha_prod_t / alpha_prod_t_prev)

    def get_alpha_prod_t_prev(self, prev_timestep):
        return self.model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.model.scheduler.final_alpha_cumprod

    def unet_forward(self, sample, timestep, encoder_hidden_states, timestep_cond = None, class_labels = None, attention_mask = None, encoder_attention_mask = None, return_dict = True, cross_attention_kwargs = None, mid_block_additional_residual = None, replace_h_space = None, replace_skip_conns = None, zero_out_resconns = None):
        encoder_hidden_states_1 = class_labels
        class_labels = None
        encoder_attention_mask_1 = encoder_attention_mask
        encoder_attention_mask = None
        default_overall_up_factor = 2 ** self.model.unet.num_upsamplers
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): forward_upsample_size = True

        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        if encoder_attention_mask_1 is not None:
            encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0
            encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1)

        timesteps = timestep
        if not torch.is_tensor(timesteps):
            is_mps = sample.device.type == "mps"

            dtype = (torch.float16 if is_mps else torch.float32) if isinstance(timestep, float) else (torch.int16 if is_mps else torch.int32)

            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device)

        emb = self.model.unet.time_embedding(self.model.unet.time_proj(timesteps.expand(sample.shape[0])).to(dtype=sample.dtype), timestep_cond)
        aug_emb = None

        if self.model.unet.class_embedding is not None:
            if class_labels is None: raise ValueError

            if self.model.unet.config.class_embed_type == "timestep": class_labels = self.model.unet.time_proj(class_labels).to(dtype=sample.dtype)
            class_emb = self.model.unet.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.model.unet.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1)
            else: emb = emb + class_emb

        emb = emb + aug_emb if aug_emb is not None else emb
        if self.model.unet.time_embed_act is not None: emb = self.model.unet.time_embed_act(emb)

        sample = self.model.unet.conv_in(sample)
        down_block_res_samples = (sample,)

        for downsample_block in self.model.unet.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)
            else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        if self.model.unet.mid_block is not None: sample = self.model.unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)

        if replace_h_space is None: h_space = sample.clone()
        else:
            h_space = replace_h_space
            sample = replace_h_space.clone()

        if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual
        extracted_res_conns = {}

        for i, upsample_block in enumerate(self.model.unet.up_blocks):
            is_final_block = i == len(self.model.unet.up_blocks) - 1
            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if replace_skip_conns is not None and replace_skip_conns.get(i): res_samples = replace_skip_conns.get(i)

            if zero_out_resconns is not None:
                if (type(zero_out_resconns) is int and i >= (zero_out_resconns - 1)) or type(zero_out_resconns) is list and i in zero_out_resconns: res_samples = [torch.zeros_like(x) for x in res_samples]

            extracted_res_conns[i] = res_samples
            if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)
            else: sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size)

        if self.model.unet.conv_norm_out: sample = self.model.unet.conv_act(self.model.unet.conv_norm_out(sample))
        sample = self.model.unet.conv_out(sample)

        if not return_dict: return (sample,)
        return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns

def load_model(model, device):
    check_audioldm2(model)

    ldm_stable = AudioLDM2(model_id=os.path.join("assets", "models", "audioldm2", model), device=device, double_precision=False)
    ldm_stable.load_scheduler()

    if torch.cuda.is_available(): torch.cuda.empty_cache()
    elif torch.backends.mps.is_available(): torch.mps.empty_cache()

    return ldm_stable
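A minimal usage sketch for the deleted module above, assuming the repository layout it expects (a downloaded AudioLDM2 checkpoint folder under assets/models/audioldm2/ and a working main.library.utils.check_audioldm2); the folder name "audioldm2-music" is a hypothetical example:

import torch

from main.library.audioldm2.models import load_model

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = load_model("audioldm2-music", device)   # hypothetical folder under assets/models/audioldm2
mel_processor = pipe.get_melspectrogram()      # 64-bin mel front end at 16 kHz, as configured in AudioLDM2.get_melspectrogram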
main/library/audioldm2/utils.py
DELETED
@@ -1,40 +0,0 @@
import torch
import librosa
import torchaudio

import numpy as np

def compute_mel_spectrogram(audio, stft_processor):
    return stft_processor.compute_mel_spectrogram(torch.autograd.Variable(torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1), requires_grad=False)).squeeze(0).numpy().astype(np.float32)

def pad_spectrogram(spectrogram, target_length=1024):
    pad_amount = target_length - spectrogram.shape[0]
    spectrogram = torch.nn.functional.pad(spectrogram, (0, 0, 0, pad_amount)) if pad_amount > 0 else spectrogram[:target_length, :]

    if spectrogram.size(-1) % 2 != 0: spectrogram = spectrogram[..., :-1]
    return spectrogram

def pad_waveform(waveform, segment_length):
    waveform_length = waveform.shape[-1]
    assert waveform_length > 100

    if segment_length is None or waveform_length == segment_length: return waveform
    elif waveform_length > segment_length: return waveform[:, :segment_length]

    padded_waveform = np.zeros((1, segment_length))
    padded_waveform[:, :waveform_length] = waveform
    return padded_waveform

def normalize(waveform):
    waveform -= np.mean(waveform)
    return (waveform / (np.max(np.abs(waveform)) + 1e-8)) * 0.5

def process_audio(y, sr, segment_length):
    normalized_waveform = normalize(torchaudio.functional.resample(torch.from_numpy(y), orig_freq=sr, new_freq=16000).numpy())[None, ...]
    return 0.5 * (pad_waveform(normalized_waveform, segment_length) / np.max(np.abs(normalized_waveform)))

def load_audio(audio_path, stft_processor, device=None):
    y, sr = librosa.load(audio_path, sr=None)
    duration = len(y) / sr

    return pad_spectrogram(torch.FloatTensor(compute_mel_spectrogram(torch.FloatTensor(process_audio(y, sr, int(duration * 102.4) * 160)[0, ...]), stft_processor).T), int(duration * 102.4)).unsqueeze(0).unsqueeze(0).to(device), duration
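A sketch of how this helper was likely combined with the models module above: load_audio takes a mel processor and returns a padded mel batch plus the clip duration. The checkpoint folder name and "input.wav" path are placeholders, not values from the repository:

from main.library.audioldm2.models import load_model
from main.library.audioldm2.utils import load_audio

pipe = load_model("audioldm2-music", "cpu")    # hypothetical checkpoint folder
mel, duration = load_audio("input.wav", pipe.get_melspectrogram(), device="cpu")
print(mel.shape, duration)                     # roughly (1, 1, T, 64) and the clip length in seconds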
main/library/predictors/CREPE.py
DELETED
@@ -1,210 +0,0 @@
import os
import torch
import librosa
import functools
import scipy.stats

import numpy as np

CENTS_PER_BIN, MAX_FMAX, PITCH_BINS, SAMPLE_RATE, WINDOW_SIZE = 20, 2006, 360, 16000, 1024

class Crepe(torch.nn.Module):
    def __init__(self, model='full'):
        super().__init__()
        if model == 'full':
            in_channels = [1, 1024, 128, 128, 128, 256]
            out_channels = [1024, 128, 128, 128, 256, 512]
            self.in_features = 2048
        elif model == 'large':
            in_channels = [1, 768, 96, 96, 96, 192]
            out_channels = [768, 96, 96, 96, 192, 384]
            self.in_features = 1536
        elif model == 'medium':
            in_channels = [1, 512, 64, 64, 64, 128]
            out_channels = [512, 64, 64, 64, 128, 256]
            self.in_features = 1024
        elif model == 'small':
            in_channels = [1, 256, 32, 32, 32, 64]
            out_channels = [256, 32, 32, 32, 64, 128]
            self.in_features = 512
        elif model == 'tiny':
            in_channels = [1, 128, 16, 16, 16, 32]
            out_channels = [128, 16, 16, 16, 32, 64]
            self.in_features = 256

        kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
        strides = [(4, 1)] + 5 * [(1, 1)]

        batch_norm_fn = functools.partial(torch.nn.BatchNorm2d, eps=0.0010000000474974513, momentum=0.0)

        self.conv1 = torch.nn.Conv2d(in_channels=in_channels[0], out_channels=out_channels[0], kernel_size=kernel_sizes[0], stride=strides[0])
        self.conv1_BN = batch_norm_fn(num_features=out_channels[0])
        self.conv2 = torch.nn.Conv2d(in_channels=in_channels[1], out_channels=out_channels[1], kernel_size=kernel_sizes[1], stride=strides[1])
        self.conv2_BN = batch_norm_fn(num_features=out_channels[1])

        self.conv3 = torch.nn.Conv2d(in_channels=in_channels[2], out_channels=out_channels[2], kernel_size=kernel_sizes[2], stride=strides[2])
        self.conv3_BN = batch_norm_fn(num_features=out_channels[2])
        self.conv4 = torch.nn.Conv2d(in_channels=in_channels[3], out_channels=out_channels[3], kernel_size=kernel_sizes[3], stride=strides[3])
        self.conv4_BN = batch_norm_fn(num_features=out_channels[3])

        self.conv5 = torch.nn.Conv2d(in_channels=in_channels[4], out_channels=out_channels[4], kernel_size=kernel_sizes[4], stride=strides[4])
        self.conv5_BN = batch_norm_fn(num_features=out_channels[4])
        self.conv6 = torch.nn.Conv2d(in_channels=in_channels[5], out_channels=out_channels[5], kernel_size=kernel_sizes[5], stride=strides[5])
        self.conv6_BN = batch_norm_fn(num_features=out_channels[5])

        self.classifier = torch.nn.Linear(in_features=self.in_features, out_features=PITCH_BINS)

    def forward(self, x, embed=False):
        x = self.embed(x)
        if embed: return x

        return torch.sigmoid(self.classifier(self.layer(x, self.conv6, self.conv6_BN).permute(0, 2, 1, 3).reshape(-1, self.in_features)))

    def embed(self, x):
        x = x[:, None, :, None]

        return self.layer(self.layer(self.layer(self.layer(self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254)), self.conv2, self.conv2_BN), self.conv3, self.conv3_BN), self.conv4, self.conv4_BN), self.conv5, self.conv5_BN)

    def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
        return torch.nn.functional.max_pool2d(batch_norm(torch.nn.functional.relu(conv(torch.nn.functional.pad(x, padding)))), (2, 1), (2, 1))

def viterbi(logits):
    if not hasattr(viterbi, 'transition'):
        xx, yy = np.meshgrid(range(360), range(360))
        transition = np.maximum(12 - abs(xx - yy), 0)
        viterbi.transition = transition / transition.sum(axis=1, keepdims=True)

    with torch.no_grad():
        probs = torch.nn.functional.softmax(logits, dim=1)

    bins = torch.tensor(np.array([librosa.sequence.viterbi(sequence, viterbi.transition).astype(np.int64) for sequence in probs.cpu().numpy()]), device=probs.device)
    return bins, bins_to_frequency(bins)

def predict(audio, sample_rate, hop_length=None, fmin=50, fmax=MAX_FMAX, model='full', return_periodicity=False, batch_size=None, device='cpu', pad=True, providers=None, onnx=False):
    results = []

    if onnx:
        import onnxruntime as ort

        sess_options = ort.SessionOptions()
        sess_options.log_severity_level = 3

        session = ort.InferenceSession(os.path.join("assets", "models", "predictors", f"crepe_{model}.onnx"), sess_options=sess_options, providers=providers)

        for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
            result = postprocess(torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: frames.cpu().numpy()})[0].transpose(1, 0)[None]), fmin, fmax, return_periodicity)
            results.append((result[0], result[1]) if isinstance(result, tuple) else result)

        del session

        if return_periodicity:
            pitch, periodicity = zip(*results)
            return torch.cat(pitch, 1), torch.cat(periodicity, 1)

        return torch.cat(results, 1)
    else:
        with torch.no_grad():
            for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
                result = postprocess(infer(frames, model, device, embed=False).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2), fmin, fmax, return_periodicity)
                results.append((result[0].to(audio.device), result[1].to(audio.device)) if isinstance(result, tuple) else result.to(audio.device))

        if return_periodicity:
            pitch, periodicity = zip(*results)
            return torch.cat(pitch, 1), torch.cat(periodicity, 1)

        return torch.cat(results, 1)

def bins_to_frequency(bins):
    cents = CENTS_PER_BIN * bins + 1997.3794084376191
    return 10 * 2 ** ((cents + cents.new_tensor(scipy.stats.triang.rvs(c=0.5, loc=-CENTS_PER_BIN, scale=2 * CENTS_PER_BIN, size=cents.size()))) / 1200)

def frequency_to_bins(frequency, quantize_fn=torch.floor):
    return quantize_fn(((1200 * torch.log2(frequency / 10)) - 1997.3794084376191) / CENTS_PER_BIN).int()

def infer(frames, model='full', device='cpu', embed=False):
    if not hasattr(infer, 'model') or not hasattr(infer, 'capacity') or (hasattr(infer, 'capacity') and infer.capacity != model): load_model(device, model)
    infer.model = infer.model.to(device)

    return infer.model(frames, embed=embed)

def load_model(device, capacity='full'):
    infer.capacity = capacity
    infer.model = Crepe(capacity)
    infer.model.load_state_dict(torch.load(os.path.join("assets", "models", "predictors", f"crepe_{capacity}.pth"), map_location=device))
    infer.model = infer.model.to(torch.device(device))
    infer.model.eval()

def postprocess(probabilities, fmin=0, fmax=MAX_FMAX, return_periodicity=False):
    probabilities = probabilities.detach()

    probabilities[:, :frequency_to_bins(torch.tensor(fmin))] = -float('inf')
    probabilities[:, frequency_to_bins(torch.tensor(fmax), torch.ceil):] = -float('inf')

    bins, pitch = viterbi(probabilities)

    if not return_periodicity: return pitch
    return pitch, periodicity(probabilities, bins)

def preprocess(audio, sample_rate, hop_length=None, batch_size=None, device='cpu', pad=True):
    hop_length = sample_rate // 100 if hop_length is None else hop_length

    if sample_rate != SAMPLE_RATE:
        audio = torch.tensor(librosa.resample(audio.detach().cpu().numpy().squeeze(0), orig_sr=sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_vhq"), device=audio.device).unsqueeze(0)
        hop_length = int(hop_length * SAMPLE_RATE / sample_rate)

    if pad:
        total_frames = 1 + int(audio.size(1) // hop_length)
        audio = torch.nn.functional.pad(audio, (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
    else: total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)

    batch_size = total_frames if batch_size is None else batch_size

    for i in range(0, total_frames, batch_size):
        frames = torch.nn.functional.unfold(audio[:, None, None, max(0, i * hop_length):min(audio.size(1), (i + batch_size - 1) * hop_length + WINDOW_SIZE)], kernel_size=(1, WINDOW_SIZE), stride=(1, hop_length))
        frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(device)
        frames -= frames.mean(dim=1, keepdim=True)
        frames /= torch.max(torch.tensor(1e-10, device=frames.device), frames.std(dim=1, keepdim=True))

        yield frames

def periodicity(probabilities, bins):
    probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
    periodicity = probs_stacked.gather(1, bins.reshape(-1, 1).to(torch.int64))

    return periodicity.reshape(probabilities.size(0), probabilities.size(2))

def mean(signals, win_length=9):
    assert signals.dim() == 2

    signals = signals.unsqueeze(1)
    mask = ~torch.isnan(signals)
    padding = win_length // 2

    ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)
    avg_pooled = torch.nn.functional.conv1d(torch.where(mask, signals, torch.zeros_like(signals)), ones_kernel, stride=1, padding=padding) / torch.nn.functional.conv1d(mask.float(), ones_kernel, stride=1, padding=padding).clamp(min=1)
    avg_pooled[avg_pooled == 0] = float("nan")

    return avg_pooled.squeeze(1)

def median(signals, win_length):
    assert signals.dim() == 2

    signals = signals.unsqueeze(1)
    mask = ~torch.isnan(signals)
    padding = win_length // 2

    x = torch.nn.functional.pad(torch.where(mask, signals, torch.zeros_like(signals)), (padding, padding), mode="reflect")
    mask = torch.nn.functional.pad(mask.float(), (padding, padding), mode="constant", value=0)

    x = x.unfold(2, win_length, 1)
    mask = mask.unfold(2, win_length, 1)

    x = x.contiguous().view(x.size()[:3] + (-1,))
    mask = mask.contiguous().view(mask.size()[:3] + (-1,))

    x_sorted, _ = torch.sort(torch.where(mask.bool(), x.float(), float("inf")).to(x), dim=-1)

    median_pooled = x_sorted.gather(-1, ((mask.sum(dim=-1) - 1) // 2).clamp(min=0).unsqueeze(-1).long()).squeeze(-1)
    median_pooled[torch.isinf(median_pooled)] = float("nan")

    return median_pooled.squeeze(1)
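A minimal sketch of how this predictor's predict() entry point might be driven, assuming the CREPE weights it loads (assets/models/predictors/crepe_full.pth) are present; the path "voice.wav" and the fmax value are hypothetical:

import torch
import librosa

from main.library.predictors import CREPE

audio, sr = librosa.load("voice.wav", sr=16000)   # placeholder input path
audio = torch.from_numpy(audio).unsqueeze(0)      # predict() expects a (1, samples) tensor

f0, periodicity = CREPE.predict(audio, sr, hop_length=160, fmin=50, fmax=1100, model="full", return_periodicity=True, device="cpu")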
main/library/predictors/FCPE.py
DELETED
|
@@ -1,1097 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import io
|
| 3 |
-
import math
|
| 4 |
-
import torch
|
| 5 |
-
import librosa
|
| 6 |
-
|
| 7 |
-
import numpy as np
|
| 8 |
-
import soundfile as sf
|
| 9 |
-
import onnxruntime as ort
|
| 10 |
-
import torch.nn.functional as F
|
| 11 |
-
|
| 12 |
-
from torch import nn, einsum
|
| 13 |
-
from functools import partial
|
| 14 |
-
from Crypto.Cipher import AES
|
| 15 |
-
from Crypto.Util.Padding import unpad
|
| 16 |
-
from torchaudio.transforms import Resample
|
| 17 |
-
from einops import rearrange, repeat, pack, unpack
|
| 18 |
-
from torch.nn.utils.parametrizations import weight_norm
|
| 19 |
-
|
| 20 |
-
from librosa.filters import mel as librosa_mel_fn
|
| 21 |
-
|
| 22 |
-
os.environ["LRU_CACHE_CAPACITY"] = "3"
|
| 23 |
-
|
| 24 |
-
def exists(val):
|
| 25 |
-
return val is not None
|
| 26 |
-
|
| 27 |
-
def default(value, d):
|
| 28 |
-
return value if exists(value) else d
|
| 29 |
-
|
| 30 |
-
def max_neg_value(tensor):
|
| 31 |
-
return -torch.finfo(tensor.dtype).max
|
| 32 |
-
|
| 33 |
-
def empty(tensor):
|
| 34 |
-
return tensor.numel() == 0
|
| 35 |
-
|
| 36 |
-
def l2norm(tensor):
|
| 37 |
-
return F.normalize(tensor, dim = -1).type(tensor.dtype)
|
| 38 |
-
|
| 39 |
-
def decrypt_model(input_path):
|
| 40 |
-
with open(input_path, "rb") as f:
|
| 41 |
-
data = f.read()
|
| 42 |
-
|
| 43 |
-
with open(os.path.join("main", "configs", "decrypt.bin"), "rb") as f:
|
| 44 |
-
key = f.read()
|
| 45 |
-
|
| 46 |
-
return io.BytesIO(unpad(AES.new(key, AES.MODE_CBC, data[:16]).decrypt(data[16:]), AES.block_size)).read()
|
| 47 |
-
|
| 48 |
-
def l2_regularization(model, l2_alpha):
|
| 49 |
-
l2_loss = []
|
| 50 |
-
|
| 51 |
-
for module in model.modules():
|
| 52 |
-
if type(module) is nn.Conv2d: l2_loss.append((module.weight**2).sum() / 2.0)
|
| 53 |
-
|
| 54 |
-
return l2_alpha * sum(l2_loss)
|
| 55 |
-
|
| 56 |
-
def pad_to_multiple(tensor, multiple, dim=-1, value=0):
|
| 57 |
-
seqlen = tensor.shape[dim]
|
| 58 |
-
m = seqlen / multiple
|
| 59 |
-
|
| 60 |
-
if m.is_integer(): return False, tensor
|
| 61 |
-
return True, F.pad(tensor, (*((0,) * (-1 - dim) * 2), 0, (math.ceil(m) * multiple - seqlen)), value = value)
|
| 62 |
-
|
| 63 |
-
def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
|
| 64 |
-
t = x.shape[1]
|
| 65 |
-
dims = (len(x.shape) - dim) * (0, 0)
|
| 66 |
-
|
| 67 |
-
padded_x = F.pad(x, (*dims, backward, forward), value = pad_value)
|
| 68 |
-
return torch.cat([padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)], dim = dim)
|
| 69 |
-
|
| 70 |
-
def rotate_half(x):
|
| 71 |
-
x1, x2 = rearrange(x, 'b ... (r d) -> b ... r d', r = 2).unbind(dim = -2)
|
| 72 |
-
return torch.cat((-x2, x1), dim = -1)
|
| 73 |
-
|
| 74 |
-
def apply_rotary_pos_emb(q, k, freqs, scale = 1):
|
| 75 |
-
q_len = q.shape[-2]
|
| 76 |
-
q_freqs = freqs[..., -q_len:, :]
|
| 77 |
-
inv_scale = scale ** -1
|
| 78 |
-
|
| 79 |
-
if scale.ndim == 2: scale = scale[-q_len:, :]
|
| 80 |
-
|
| 81 |
-
q = (q * q_freqs.cos() * scale) + (rotate_half(q) * q_freqs.sin() * scale)
|
| 82 |
-
k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale)
|
| 83 |
-
|
| 84 |
-
return q, k
|
| 85 |
-
|
| 86 |
-
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
| 87 |
-
return torch.log(torch.clamp(x, min=clip_val) * C)
|
| 88 |
-
|
| 89 |
-
def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None):
|
| 90 |
-
unstructured_block = torch.randn((cols, cols), device=device)
|
| 91 |
-
|
| 92 |
-
q, r = torch.linalg.qr(unstructured_block.cpu(), mode="reduced")
|
| 93 |
-
q, r = map(lambda t: t.to(device), (q, r))
|
| 94 |
-
|
| 95 |
-
if qr_uniform_q:
|
| 96 |
-
d = torch.diag(r, 0)
|
| 97 |
-
q *= d.sign()
|
| 98 |
-
|
| 99 |
-
return q.t()
|
| 100 |
-
|
| 101 |
-
def linear_attention(q, k, v):
|
| 102 |
-
return einsum("...ed,...nd->...ne", k, q) if v is None else einsum("...de,...nd,...n->...ne", einsum("...nd,...ne->...de", k, v), q, 1.0 / (einsum("...nd,...d->...n", q, k.sum(dim=-2).type_as(q)) + 1e-8))
|
| 103 |
-
|
| 104 |
-
def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None):
|
| 105 |
-
nb_full_blocks = int(nb_rows / nb_columns)
|
| 106 |
-
block_list = []
|
| 107 |
-
|
| 108 |
-
for _ in range(nb_full_blocks):
|
| 109 |
-
block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device))
|
| 110 |
-
|
| 111 |
-
remaining_rows = nb_rows - nb_full_blocks * nb_columns
|
| 112 |
-
if remaining_rows > 0: block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)[:remaining_rows])
|
| 113 |
-
|
| 114 |
-
if scaling == 0: multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
|
| 115 |
-
elif scaling == 1: multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device=device)
|
| 116 |
-
else: raise ValueError(f"{scaling} != 0, 1")
|
| 117 |
-
|
| 118 |
-
return torch.diag(multiplier) @ torch.cat(block_list)
|
| 119 |
-
|
| 120 |
-
def calc_same_padding(kernel_size):
|
| 121 |
-
pad = kernel_size // 2
|
| 122 |
-
return (pad, pad - (kernel_size + 1) % 2)
|
| 123 |
-
|
| 124 |
-
def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None):
|
| 125 |
-
b, h, *_ = data.shape
|
| 126 |
-
|
| 127 |
-
data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.0
|
| 128 |
-
ratio = projection_matrix.shape[0] ** -0.5
|
| 129 |
-
|
| 130 |
-
data_dash = torch.einsum("...id,...jd->...ij", (data_normalizer * data), repeat(projection_matrix, "j d -> b h j d", b=b, h=h).type_as(data))
|
| 131 |
-
diag_data = ((torch.sum(data**2, dim=-1) / 2.0) * (data_normalizer**2)).unsqueeze(dim=-1)
|
| 132 |
-
|
| 133 |
-
return (ratio * (torch.exp(data_dash - diag_data - torch.max(data_dash, dim=-1, keepdim=True).values) + eps) if is_query else ratio * (torch.exp(data_dash - diag_data + eps))).type_as(data)
|
| 134 |
-
|
| 135 |
-
def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
|
| 136 |
-
try:
|
| 137 |
-
data, sample_rate = sf.read(full_path, always_2d=True)
|
| 138 |
-
except Exception as e:
|
| 139 |
-
print(f"{full_path}: {e}")
|
| 140 |
-
|
| 141 |
-
if return_empty_on_exception: return [], sample_rate or target_sr or 48000
|
| 142 |
-
else: raise
|
| 143 |
-
|
| 144 |
-
data = data[:, 0] if len(data.shape) > 1 else data
|
| 145 |
-
assert len(data) > 2
|
| 146 |
-
|
| 147 |
-
max_mag = (-np.iinfo(data.dtype).min if np.issubdtype(data.dtype, np.integer) else max(np.amax(data), -np.amin(data)))
|
| 148 |
-
data = torch.FloatTensor(data.astype(np.float32)) / ((2**31) + 1 if max_mag > (2**15) else ((2**15) + 1 if max_mag > 1.01 else 1.0))
|
| 149 |
-
|
| 150 |
-
if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception: return [], sample_rate or target_sr or 48000
|
| 151 |
-
|
| 152 |
-
if target_sr is not None and sample_rate != target_sr:
|
| 153 |
-
data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sample_rate, target_sr=target_sr))
|
| 154 |
-
sample_rate = target_sr
|
| 155 |
-
|
| 156 |
-
return data, sample_rate
|
| 157 |
-
|
| 158 |
-
def torch_interp(x, xp, fp):
|
| 159 |
-
sort_idx = torch.argsort(xp)
|
| 160 |
-
|
| 161 |
-
xp = xp[sort_idx]
|
| 162 |
-
fp = fp[sort_idx]
|
| 163 |
-
|
| 164 |
-
right_idxs = torch.searchsorted(xp, x).clamp(max=len(xp) - 1)
|
| 165 |
-
left_idxs = (right_idxs - 1).clamp(min=0)
|
| 166 |
-
|
| 167 |
-
x_left = xp[left_idxs]
|
| 168 |
-
y_left = fp[left_idxs]
|
| 169 |
-
|
| 170 |
-
interp_vals = y_left + ((x - x_left) * (fp[right_idxs] - y_left) / (xp[right_idxs] - x_left))
|
| 171 |
-
interp_vals[x < xp[0]] = fp[0]
|
| 172 |
-
interp_vals[x > xp[-1]] = fp[-1]
|
| 173 |
-
|
| 174 |
-
return interp_vals
|
| 175 |
-
|
| 176 |
-
def batch_interp_with_replacement_detach(uv, f0):
|
| 177 |
-
result = f0.clone()
|
| 178 |
-
|
| 179 |
-
for i in range(uv.shape[0]):
|
| 180 |
-
interp_vals = torch_interp(torch.where(uv[i])[-1], torch.where(~uv[i])[-1], f0[i][~uv[i]]).detach()
|
| 181 |
-
result[i][uv[i]] = interp_vals
|
| 182 |
-
|
| 183 |
-
return result
|
| 184 |
-
|
| 185 |
-
def spawn_model(args):
|
| 186 |
-
return CFNaiveMelPE(input_channels=catch_none_args_must(args.mel.num_mels, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.mel.num_mels is None"), out_dims=catch_none_args_must(args.model.out_dims, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.out_dims is None"), hidden_dims=catch_none_args_must(args.model.hidden_dims, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.hidden_dims is None"), n_layers=catch_none_args_must(args.model.n_layers, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.n_layers is None"), n_heads=catch_none_args_must(args.model.n_heads, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.n_heads is None"), f0_max=catch_none_args_must(args.model.f0_max, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.f0_max is None"), f0_min=catch_none_args_must(args.model.f0_min, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.f0_min is None"), use_fa_norm=catch_none_args_must(args.model.use_fa_norm, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.use_fa_norm is None"), conv_only=catch_none_args_opti(args.model.conv_only, default=False, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.conv_only is None"), conv_dropout=catch_none_args_opti(args.model.conv_dropout, default=0.0, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.conv_dropout is None"), atten_dropout=catch_none_args_opti(args.model.atten_dropout, default=0.0, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.atten_dropout is None"), use_harmonic_emb=catch_none_args_opti(args.model.use_harmonic_emb, default=False, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.use_harmonic_emb is None"))
|
| 187 |
-
|
| 188 |
-
def catch_none_args_must(x, func_name, warning_str):
|
| 189 |
-
level = "ERROR"
|
| 190 |
-
|
| 191 |
-
if x is None:
|
| 192 |
-
print(f' [{level}] {warning_str}')
|
| 193 |
-
print(f' [{level}] > {func_name}')
|
| 194 |
-
raise ValueError(f' [{level}] {warning_str}')
|
| 195 |
-
else: return x
|
| 196 |
-
|
| 197 |
-
def catch_none_args_opti(x, default, func_name, warning_str=None, level='WARN'):
|
| 198 |
-
return default if x is None else x
|
| 199 |
-
|
| 200 |
-
def spawn_wav2mel(args, device = None):
|
| 201 |
-
_type = args.mel.type
|
| 202 |
-
|
| 203 |
-
if (str(_type).lower() == 'none') or (str(_type).lower() == 'default'): _type = 'default'
|
| 204 |
-
elif str(_type).lower() == 'stft': _type = 'stft'
|
| 205 |
-
|
| 206 |
-
wav2mel = Wav2MelModule(sr=catch_none_args_opti(args.mel.sr, default=16000, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.sr is None'), n_mels=catch_none_args_opti(args.mel.num_mels, default=128, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.num_mels is None'), n_fft=catch_none_args_opti(args.mel.n_fft, default=1024, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.n_fft is None'), win_size=catch_none_args_opti(args.mel.win_size, default=1024, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.win_size is None'), hop_length=catch_none_args_opti(args.mel.hop_size, default=160, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.hop_size is None'), fmin=catch_none_args_opti(args.mel.fmin, default=0, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.fmin is None'), fmax=catch_none_args_opti(args.mel.fmax, default=8000, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.fmax is None'), clip_val=1e-05, mel_type=_type)
|
| 207 |
-
device = catch_none_args_opti(device, default='cpu', func_name='torchfcpe.tools.spawn_wav2mel', warning_str='.device is None')
|
| 208 |
-
|
| 209 |
-
return wav2mel.to(torch.device(device))
|
| 210 |
-
|
| 211 |
-
def ensemble_f0(f0s, key_shift_list, tta_uv_penalty):
|
| 212 |
-
device = f0s.device
|
| 213 |
-
f0s = f0s / (torch.pow(2, torch.tensor(key_shift_list, device=device).to(device).unsqueeze(0).unsqueeze(0) / 12))
|
| 214 |
-
|
| 215 |
-
notes = torch.log2(f0s / 440) * 12 + 69
|
| 216 |
-
notes[notes < 0] = 0
|
| 217 |
-
|
| 218 |
-
uv_penalty = tta_uv_penalty**2
|
| 219 |
-
dp = torch.zeros_like(notes, device=device)
|
| 220 |
-
|
| 221 |
-
backtrack = torch.zeros_like(notes, device=device).long()
|
| 222 |
-
dp[:, 0, :] = (notes[:, 0, :] <= 0) * uv_penalty
|
| 223 |
-
|
| 224 |
-
for t in range(1, notes.size(1)):
|
| 225 |
-
penalty = torch.zeros([notes.size(0), notes.size(2), notes.size(2)], device=device)
|
| 226 |
-
t_uv = notes[:, t, :] <= 0
|
| 227 |
-
penalty += uv_penalty * t_uv.unsqueeze(1)
|
| 228 |
-
|
| 229 |
-
t1_uv = notes[:, t - 1, :] <= 0
|
| 230 |
-
l2 = torch.pow((notes[:, t - 1, :].unsqueeze(-1) - notes[:, t, :].unsqueeze(1)) * (~t1_uv).unsqueeze(-1) * (~t_uv).unsqueeze(1), 2) - 0.5
|
| 231 |
-
l2 = l2 * (l2 > 0)
|
| 232 |
-
|
| 233 |
-
penalty += l2
|
| 234 |
-
penalty += t1_uv.unsqueeze(-1) * (~t_uv).unsqueeze(1) * uv_penalty * 2
|
| 235 |
-
|
| 236 |
-
min_value, min_indices = torch.min(dp[:, t - 1, :].unsqueeze(-1) + penalty, dim=1)
|
| 237 |
-
dp[:, t, :] = min_value
|
| 238 |
-
backtrack[:, t, :] = min_indices
|
| 239 |
-
|
| 240 |
-
t = f0s.size(1) - 1
|
| 241 |
-
f0_result = torch.zeros_like(f0s[:, :, 0], device=device)
|
| 242 |
-
min_indices = torch.argmin(dp[:, t, :], dim=-1)
|
| 243 |
-
|
| 244 |
-
for i in range(0, t + 1):
|
| 245 |
-
f0_result[:, t - i] = f0s[:, t - i, min_indices]
|
| 246 |
-
min_indices = backtrack[:, t - i, min_indices]
|
| 247 |
-
|
| 248 |
-
return f0_result.unsqueeze(-1)
|
| 249 |
-
|
| 250 |
-
class LocalAttention(nn.Module):
|
| 251 |
-
def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False, rel_pos_emb_config = None, dim = None, autopad = False, exact_windowsize = False, scale = None, use_rotary_pos_emb = True, use_xpos = False, xpos_scale_base = None):
|
| 252 |
-
super().__init__()
|
| 253 |
-
look_forward = default(look_forward, 0 if causal else 1)
|
| 254 |
-
assert not (causal and look_forward > 0)
|
| 255 |
-
self.scale = scale
|
| 256 |
-
self.window_size = window_size
|
| 257 |
-
self.autopad = autopad
|
| 258 |
-
self.exact_windowsize = exact_windowsize
|
| 259 |
-
self.causal = causal
|
| 260 |
-
self.look_backward = look_backward
|
| 261 |
-
self.look_forward = look_forward
|
| 262 |
-
self.dropout = nn.Dropout(dropout)
|
| 263 |
-
self.shared_qk = shared_qk
|
| 264 |
-
self.rel_pos = None
|
| 265 |
-
self.use_xpos = use_xpos
|
| 266 |
-
|
| 267 |
-
if use_rotary_pos_emb and (exists(rel_pos_emb_config) or exists(dim)):
|
| 268 |
-
if exists(rel_pos_emb_config): dim = rel_pos_emb_config[0]
|
| 269 |
-
self.rel_pos = SinusoidalEmbeddings(dim, use_xpos = use_xpos, scale_base = default(xpos_scale_base, window_size // 2))
|
| 270 |
-
|
| 271 |
-
def forward(self, q, k, v, mask = None, input_mask = None, attn_bias = None, window_size = None):
|
| 272 |
-
mask = default(mask, input_mask)
|
| 273 |
-
assert not (exists(window_size) and not self.use_xpos)
|
| 274 |
-
|
| 275 |
-
_, autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, self.autopad, -1, default(window_size, self.window_size), self.causal, self.look_backward, self.look_forward, self.shared_qk
|
| 276 |
-
(q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], '* n d'), (q, k, v))
|
| 277 |
-
|
| 278 |
-
if autopad:
|
| 279 |
-
orig_seq_len = q.shape[1]
|
| 280 |
-
(_, q), (_, k), (_, v) = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))
|
| 281 |
-
|
| 282 |
-
b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype
|
| 283 |
-
scale = default(self.scale, dim_head ** -0.5)
|
| 284 |
-
|
| 285 |
-
assert (n % window_size) == 0
|
| 286 |
-
windows = n // window_size
|
| 287 |
-
|
| 288 |
-
if shared_qk: k = l2norm(k)
|
| 289 |
-
|
| 290 |
-
seq = torch.arange(n, device = device)
|
| 291 |
-
b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size)
|
| 292 |
-
bq, bk, bv = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k, v))
|
| 293 |
-
|
| 294 |
-
bq = bq * scale
|
| 295 |
-
look_around_kwargs = dict(backward = look_backward, forward = look_forward, pad_value = pad_value)
|
| 296 |
-
|
| 297 |
-
bk = look_around(bk, **look_around_kwargs)
|
| 298 |
-
bv = look_around(bv, **look_around_kwargs)
|
| 299 |
-
|
| 300 |
-
if exists(self.rel_pos):
|
| 301 |
-
pos_emb, xpos_scale = self.rel_pos(bk)
|
| 302 |
-
bq, bk = apply_rotary_pos_emb(bq, bk, pos_emb, scale = xpos_scale)
|
| 303 |
-
|
| 304 |
-
bq_t = b_t
|
| 305 |
-
bq_k = look_around(b_t, **look_around_kwargs)
|
| 306 |
-
|
| 307 |
-
bq_t = rearrange(bq_t, '... i -> ... i 1')
|
| 308 |
-
bq_k = rearrange(bq_k, '... j -> ... 1 j')
|
| 309 |
-
|
| 310 |
-
pad_mask = bq_k == pad_value
|
| 311 |
-
sim = einsum('b h i e, b h j e -> b h i j', bq, bk)
|
| 312 |
-
|
| 313 |
-
if exists(attn_bias):
|
| 314 |
-
heads = attn_bias.shape[0]
|
| 315 |
-
assert (b % heads) == 0
|
| 316 |
-
|
| 317 |
-
attn_bias = repeat(attn_bias, 'h i j -> (b h) 1 i j', b = b // heads)
|
| 318 |
-
sim = sim + attn_bias
|
| 319 |
-
|
| 320 |
-
mask_value = max_neg_value(sim)
|
| 321 |
-
|
| 322 |
-
if shared_qk:
|
| 323 |
-
self_mask = bq_t == bq_k
|
| 324 |
-
sim = sim.masked_fill(self_mask, -5e4)
|
| 325 |
-
del self_mask
|
| 326 |
-
|
| 327 |
-
if causal:
|
| 328 |
-
causal_mask = bq_t < bq_k
|
| 329 |
-
if self.exact_windowsize: causal_mask = causal_mask | (bq_t > (bq_k + (self.window_size * self.look_backward)))
|
| 330 |
-
sim = sim.masked_fill(causal_mask, mask_value)
|
| 331 |
-
del causal_mask
|
| 332 |
-
|
| 333 |
-
sim = sim.masked_fill(((bq_k - (self.window_size * self.look_forward)) > bq_t) | (bq_t > (bq_k + (self.window_size * self.look_backward))) | pad_mask, mask_value) if not causal and self.exact_windowsize else sim.masked_fill(pad_mask, mask_value)
|
| 334 |
-
|
| 335 |
-
if exists(mask):
|
| 336 |
-
batch = mask.shape[0]
|
| 337 |
-
assert (b % batch) == 0
|
| 338 |
-
|
| 339 |
-
h = b // mask.shape[0]
|
| 340 |
-
if autopad: _, mask = pad_to_multiple(mask, window_size, dim = -1, value = False)
|
| 341 |
-
|
| 342 |
-
mask = repeat(rearrange(look_around(rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size), **{**look_around_kwargs, 'pad_value': False}), '... j -> ... 1 j'), 'b ... -> (b h) ...', h = h)
|
| 343 |
-
sim = sim.masked_fill(~mask, mask_value)
|

        del mask

        out = rearrange(einsum('b h i j, b h j e -> b h i e', self.dropout(sim.softmax(dim = -1)), bv), 'b w n d -> b (w n) d')
        if autopad: out = out[:, :orig_seq_len, :]

        out, *_ = unpack(out, packed_shape, '* n d')
        return out

class SinusoidalEmbeddings(nn.Module):
    def __init__(self, dim, scale_base = None, use_xpos = False, theta = 10000):
        super().__init__()
        inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.use_xpos = use_xpos
        self.scale_base = scale_base
        assert not (use_xpos and not exists(scale_base))
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer('scale', scale, persistent = False)

    def forward(self, x):
        seq_len, device = x.shape[-2], x.device
        t = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)

        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        if not self.use_xpos: return freqs, torch.ones(1, device = device)

        power = (t - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')

        return freqs, torch.cat((scale, scale), dim = -1)

class STFT:
    def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
        self.target_sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        self.mel_basis = {}
        self.hann_window = {}

    def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
        n_fft = self.n_fft
        win_size = self.win_size
        hop_length = self.hop_length
        fmax = self.fmax
        factor = 2 ** (keyshift / 12)
        win_size_new = int(np.round(win_size * factor))
        hop_length_new = int(np.round(hop_length * speed))
        mel_basis = self.mel_basis if not train else {}
        hann_window = self.hann_window if not train else {}
        mel_basis_key = str(fmax) + "_" + str(y.device)

        if mel_basis_key not in mel_basis: mel_basis[mel_basis_key] = torch.from_numpy(librosa_mel_fn(sr=self.target_sr, n_fft=n_fft, n_mels=self.n_mels, fmin=self.fmin, fmax=fmax)).float().to(y.device)
        keyshift_key = str(keyshift) + "_" + str(y.device)
        if keyshift_key not in hann_window: hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)

        pad_left = (win_size_new - hop_length_new) // 2
        pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)

        spec = torch.stft(F.pad(y.unsqueeze(1), (pad_left, pad_right), mode="reflect" if pad_right < y.size(-1) else "constant").squeeze(1), int(np.round(n_fft * factor)), hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
        spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))

        if keyshift != 0:
            size = n_fft // 2 + 1
            resize = spec.size(1)
            spec = (F.pad(spec, (0, 0, 0, size - resize)) if resize < size else spec[:, :size, :]) * win_size / win_size_new

        return dynamic_range_compression_torch(torch.matmul(mel_basis[mel_basis_key], spec), clip_val=self.clip_val)

    def __call__(self, audiopath):
        audio, _ = load_wav_to_torch(audiopath, target_sr=self.target_sr)
        return self.get_mel(audio.unsqueeze(0)).squeeze(0)

class PCmer(nn.Module):
    def __init__(self, num_layers, num_heads, dim_model, dim_keys, dim_values, residual_dropout, attention_dropout):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_values = dim_values
        self.dim_keys = dim_keys
        self.residual_dropout = residual_dropout
        self.attention_dropout = attention_dropout
        self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])

    def forward(self, phone, mask=None):
        for layer in self._layers:
            phone = layer(phone, mask)

        return phone

class _EncoderLayer(nn.Module):
    def __init__(self, parent):
        super().__init__()
        self.conformer = ConformerConvModule_LEGACY(parent.dim_model)
        self.norm = nn.LayerNorm(parent.dim_model)
        self.dropout = nn.Dropout(parent.residual_dropout)
        self.attn = SelfAttention(dim=parent.dim_model, heads=parent.num_heads, causal=False)

    def forward(self, phone, mask=None):
        phone = phone + (self.attn(self.norm(phone), mask=mask))
        return phone + (self.conformer(phone))

class ConformerNaiveEncoder(nn.Module):
    def __init__(self, num_layers, num_heads, dim_model, use_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.use_norm = use_norm
        self.residual_dropout = 0.1
        self.attention_dropout = 0.1
        self.encoder_layers = nn.ModuleList([CFNEncoderLayer(dim_model, num_heads, use_norm, conv_only, conv_dropout, atten_dropout) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for (_, layer) in enumerate(self.encoder_layers):
            x = layer(x, mask)

        return x

class CFNaiveMelPE(nn.Module):
    def __init__(self, input_channels, out_dims, hidden_dims = 512, n_layers = 6, n_heads = 8, f0_max = 1975.5, f0_min = 32.70, use_fa_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0, use_harmonic_emb = False):
        super().__init__()
        self.input_channels = input_channels
        self.out_dims = out_dims
        self.hidden_dims = hidden_dims
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.f0_max = f0_max
        self.f0_min = f0_min
        self.use_fa_norm = use_fa_norm
        self.residual_dropout = 0.1
        self.attention_dropout = 0.1
        self.harmonic_emb = nn.Embedding(9, hidden_dims) if use_harmonic_emb else None
        self.input_stack = nn.Sequential(nn.Conv1d(input_channels, hidden_dims, 3, 1, 1), nn.GroupNorm(4, hidden_dims), nn.LeakyReLU(), nn.Conv1d(hidden_dims, hidden_dims, 3, 1, 1))
        self.net = ConformerNaiveEncoder(num_layers=n_layers, num_heads=n_heads, dim_model=hidden_dims, use_norm=use_fa_norm, conv_only=conv_only, conv_dropout=conv_dropout, atten_dropout=atten_dropout)
        self.norm = nn.LayerNorm(hidden_dims)
        self.output_proj = weight_norm(nn.Linear(hidden_dims, out_dims))
        self.cent_table_b = torch.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims).detach()
        self.register_buffer("cent_table", self.cent_table_b)
        self.gaussian_blurred_cent_mask_b = (1200 * torch.log2(torch.Tensor([self.f0_max / 10.])))[0].detach()
        self.register_buffer("gaussian_blurred_cent_mask", self.gaussian_blurred_cent_mask_b)

    def forward(self, x, _h_emb=None):
        x = self.input_stack(x.transpose(-1, -2)).transpose(-1, -2)
        if self.harmonic_emb is not None: x = x + self.harmonic_emb(torch.LongTensor([0]).to(x.device)) if _h_emb is None else x + self.harmonic_emb(torch.LongTensor([int(_h_emb)]).to(x.device))

        return torch.sigmoid(self.output_proj(self.norm(self.net(x))))

    @torch.no_grad()
    def latent2cents_decoder(self, y, threshold = 0.05, mask = True):
        B, N, _ = y.size()
        ci = self.cent_table[None, None, :].expand(B, N, -1)
        rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)

        if mask:
            confident = torch.max(y, dim=-1, keepdim=True)[0]
            confident_mask = torch.ones_like(confident)
            confident_mask[confident <= threshold] = float("-INF")
            rtn = rtn * confident_mask

        return rtn

    @torch.no_grad()
    def latent2cents_local_decoder(self, y, threshold = 0.05, mask = True):
        B, N, _ = y.size()
        ci = self.cent_table[None, None, :].expand(B, N, -1)
        confident, max_index = torch.max(y, dim=-1, keepdim=True)

        local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4)
        local_argmax_index[local_argmax_index < 0] = 0
        local_argmax_index[local_argmax_index >= self.out_dims] = self.out_dims - 1

        y_l = torch.gather(y, -1, local_argmax_index)
        rtn = torch.sum(torch.gather(ci, -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)

        if mask:
            confident_mask = torch.ones_like(confident)
            confident_mask[confident <= threshold] = float("-INF")

            rtn = rtn * confident_mask

        return rtn

    @torch.no_grad()
    def infer(self, mel, decoder = "local_argmax", threshold = 0.05):
        latent = self.forward(mel)

        if decoder == "argmax": cents = self.latent2cents_local_decoder
        elif decoder == "local_argmax": cents = self.latent2cents_local_decoder

        return self.cent_to_f0(cents(latent, threshold=threshold))

    @torch.no_grad()
    def cent_to_f0(self, cent: torch.Tensor) -> torch.Tensor:
        return 10 * 2 ** (cent / 1200)

    @torch.no_grad()
    def f0_to_cent(self, f0):
        return 1200 * torch.log2(f0 / 10)

class CFNEncoderLayer(nn.Module):
    def __init__(self, dim_model, num_heads = 8, use_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0):
        super().__init__()

        self.conformer = nn.Sequential(ConformerConvModule(dim_model), nn.Dropout(conv_dropout)) if conv_dropout > 0 else ConformerConvModule(dim_model)
        self.norm = nn.LayerNorm(dim_model)

        self.dropout = nn.Dropout(0.1)
        self.attn = SelfAttention(dim=dim_model, heads=num_heads, causal=False, use_norm=use_norm, dropout=atten_dropout) if not conv_only else None

    def forward(self, x, mask=None):
        if self.attn is not None: x = x + (self.attn(self.norm(x), mask=mask))
        return x + (self.conformer(x))

class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

class Transpose(nn.Module):
    def __init__(self, dims):
        super().__init__()
        assert len(dims) == 2, "dims == 2"

        self.dims = dims

    def forward(self, x):
        return x.transpose(*self.dims)

class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()

class DepthWiseConv1d_LEGACY(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x):
        return self.conv(F.pad(x, self.padding))

class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding, groups):
        super().__init__()
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=kernel_size, padding=padding, groups=groups)

    def forward(self, x):
        return self.conv(x)

class ConformerConvModule_LEGACY(nn.Module):
    def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0):
        super().__init__()
        inner_dim = dim * expansion_factor
        self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), GLU(dim=1), DepthWiseConv1d_LEGACY(inner_dim, inner_dim, kernel_size=kernel_size, padding=(calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0))), Swish(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

class ConformerConvModule(nn.Module):
    def __init__(self, dim, expansion_factor=2, kernel_size=31, dropout=0):
        super().__init__()
        inner_dim = dim * expansion_factor

        self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), nn.GLU(dim=1), DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=calc_same_padding(kernel_size)[0], groups=inner_dim), nn.SiLU(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

class FastAttention(nn.Module):
    def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, causal=False, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, no_projection=False):
        super().__init__()
        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling
        self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows=self.nb_features, nb_columns=dim_heads, scaling=ortho_scaling, qr_uniform_q=qr_uniform_q)
        projection_matrix = self.create_projection()
        self.register_buffer("projection_matrix", projection_matrix)
        self.generalized_attention = generalized_attention
        self.kernel_fn = kernel_fn
        self.no_projection = no_projection
        self.causal = causal

    @torch.no_grad()
    def redraw_projection_matrix(self):
        projections = self.create_projection()
        self.projection_matrix.copy_(projections)

        del projections

    def forward(self, q, k, v):
        if self.no_projection: q, k = q.softmax(dim=-1), (torch.exp(k) if self.causal else k.softmax(dim=-2))
        else:
            create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=q.device)
            q, k = create_kernel(q, is_query=True), create_kernel(k, is_query=False)

        attn_fn = linear_attention if not self.causal else self.causal_linear_fn
        return attn_fn(q, k, None) if v is None else attn_fn(q, k, v)

class SelfAttention(nn.Module):
    def __init__(self, dim, causal=False, heads=8, dim_head=64, local_heads=0, local_window_size=256, nb_features=None, feature_redraw_interval=1000, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, dropout=0.0, no_projection=False):
        super().__init__()
        assert dim % heads == 0
        dim_head = default(dim_head, dim // heads)
        inner_dim = dim_head * heads
        self.fast_attention = FastAttention(dim_head, nb_features, causal=causal, generalized_attention=generalized_attention, kernel_fn=kernel_fn, qr_uniform_q=qr_uniform_q, no_projection=no_projection)
        self.heads = heads
        self.global_heads = heads - local_heads
        self.local_attn = (LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout, look_forward=int(not causal), rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None)
        self.to_q = nn.Linear(dim, inner_dim)
        self.to_k = nn.Linear(dim, inner_dim)
        self.to_v = nn.Linear(dim, inner_dim)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    @torch.no_grad()
    def redraw_projection_matrix(self):
        self.fast_attention.redraw_projection_matrix()

    def forward(self, x, context=None, mask=None, context_mask=None, name=None, inference=False, **kwargs):
        _, _, _, h, gh = *x.shape, self.heads, self.global_heads
        cross_attend = exists(context)

        context = default(context, x)
        context_mask = default(context_mask, mask) if not cross_attend else context_mask

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (self.to_q(x), self.to_k(context), self.to_v(context)))
        (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))

        attn_outs = []

        if not empty(q):
            if exists(context_mask): v.masked_fill_(~context_mask[:, None, :, None], 0.0)

            if cross_attend: pass
            else: out = self.fast_attention(q, k, v)

            attn_outs.append(out)

        if not empty(lq):
            assert (not cross_attend), "not cross_attend"

            out = self.local_attn(lq, lk, lv, input_mask=mask)
            attn_outs.append(out)

        return self.dropout(self.to_out(rearrange(torch.cat(attn_outs, dim=1), "b h n d -> b n (h d)")))

class HannWindow(torch.nn.Module):
    def __init__(self, win_size):
        super().__init__()
        self.register_buffer('window', torch.hann_window(win_size), persistent=False)

    def forward(self):
        return self.window

class FCPE_LEGACY(nn.Module):
    def __init__(self, input_channel=128, out_dims=360, n_layers=12, n_chans=512, use_siren=False, use_full=False, loss_mse_scale=10, loss_l2_regularization=False, loss_l2_regularization_scale=1, loss_grad1_mse=False, loss_grad1_mse_scale=1, f0_max=1975.5, f0_min=32.70, confidence=False, threshold=0.05, use_input_conv=True):
        super().__init__()
        if use_siren: raise ValueError("Siren not support")
        if use_full: raise ValueError("Model full not support")

        self.loss_mse_scale = loss_mse_scale if (loss_mse_scale is not None) else 10
        self.loss_l2_regularization = (loss_l2_regularization if (loss_l2_regularization is not None) else False)
        self.loss_l2_regularization_scale = (loss_l2_regularization_scale if (loss_l2_regularization_scale is not None) else 1)
        self.loss_grad1_mse = loss_grad1_mse if (loss_grad1_mse is not None) else False
        self.loss_grad1_mse_scale = (loss_grad1_mse_scale if (loss_grad1_mse_scale is not None) else 1)
        self.f0_max = f0_max if (f0_max is not None) else 1975.5
        self.f0_min = f0_min if (f0_min is not None) else 32.70
        self.confidence = confidence if (confidence is not None) else False
        self.threshold = threshold if (threshold is not None) else 0.05
        self.use_input_conv = use_input_conv if (use_input_conv is not None) else True
        self.cent_table_b = torch.Tensor(np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims))
        self.register_buffer("cent_table", self.cent_table_b)
        self.stack = nn.Sequential(nn.Conv1d(input_channel, n_chans, 3, 1, 1), nn.GroupNorm(4, n_chans), nn.LeakyReLU(), nn.Conv1d(n_chans, n_chans, 3, 1, 1))
        self.decoder = PCmer(num_layers=n_layers, num_heads=8, dim_model=n_chans, dim_keys=n_chans, dim_values=n_chans, residual_dropout=0.1, attention_dropout=0.1)
        self.norm = nn.LayerNorm(n_chans)
        self.n_out = out_dims
        self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))

    def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax"):
        if cdecoder == "argmax": self.cdecoder = self.cents_decoder
        elif cdecoder == "local_argmax": self.cdecoder = self.cents_local_decoder

        x = torch.sigmoid(self.dense_out(self.norm(self.decoder((self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel)))))

        if not infer:
            loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, self.gaussian_blurred_cent(self.f0_to_cent(gt_f0)))
            if self.loss_l2_regularization: loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)
            x = loss_all

        if infer:
            x = self.cent_to_f0(self.cdecoder(x))
            x = (1 + x / 700).log() if not return_hz_f0 else x

        return x

    def cents_decoder(self, y, mask=True):
        B, N, _ = y.size()
        rtn = torch.sum(self.cent_table[None, None, :].expand(B, N, -1) * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)

        if mask:
            confident = torch.max(y, dim=-1, keepdim=True)[0]
            confident_mask = torch.ones_like(confident)

            confident_mask[confident <= self.threshold] = float("-INF")
            rtn = rtn * confident_mask

        return (rtn, confident) if self.confidence else rtn

    def cents_local_decoder(self, y, mask=True):
        B, N, _ = y.size()

        confident, max_index = torch.max(y, dim=-1, keepdim=True)
        local_argmax_index = torch.clamp(torch.arange(0, 9).to(max_index.device) + (max_index - 4), 0, self.n_out - 1)

        y_l = torch.gather(y, -1, local_argmax_index)
        rtn = torch.sum(torch.gather(self.cent_table[None, None, :].expand(B, N, -1), -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)

        if mask:
            confident_mask = torch.ones_like(confident)
            confident_mask[confident <= self.threshold] = float("-INF")
            rtn = rtn * confident_mask

        return (rtn, confident) if self.confidence else rtn

    def cent_to_f0(self, cent):
        return 10.0 * 2 ** (cent / 1200.0)

    def f0_to_cent(self, f0):
        return 1200.0 * torch.log2(f0 / 10.0)

    def gaussian_blurred_cent(self, cents):
        B, N, _ = cents.size()
        return torch.exp(-torch.square(self.cent_table[None, None, :].expand(B, N, -1) - cents) / 1250) * (cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0))).float()

class InferCFNaiveMelPE(torch.nn.Module):
    def __init__(self, args, state_dict):
        super().__init__()
        self.wav2mel = spawn_wav2mel(args, device="cpu")
        self.model = spawn_model(args)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.args_dict = dict(args)
        self.register_buffer("tensor_device_marker", torch.tensor(1.0).float(), persistent=False)

    def forward(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, key_shifts = [0]):
        with torch.no_grad():
            mels = rearrange(torch.stack([self.wav2mel(wav.to(self.tensor_device_marker.device), sr, keyshift=keyshift) for keyshift in key_shifts], -1), "B T C K -> (B K) T C")
            f0s = rearrange(self.model.infer(mels, decoder=decoder_mode, threshold=threshold), "(B K) T 1 -> B T (K 1)", K=len(key_shifts))

        return f0s

    def infer(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, f0_min = None, f0_max = None, interp_uv = False, output_interp_target_length = None, return_uv = False, test_time_augmentation = False, tta_uv_penalty = 12.0, tta_key_shifts = [0, -12, 12], tta_use_origin_uv=False):
        if test_time_augmentation:
            assert len(tta_key_shifts) > 0
            flag = 0

            if tta_use_origin_uv:
                if 0 not in tta_key_shifts:
                    flag = 1
                    tta_key_shifts.append(0)

            tta_key_shifts.sort(key=lambda x: (x if x >= 0 else -x / 2))
            f0s = self.__call__(wav, sr, decoder_mode, threshold, tta_key_shifts)
            f0 = ensemble_f0(f0s[:, :, flag:], tta_key_shifts[flag:], tta_uv_penalty)

            f0_for_uv = f0s[:, :, [0]] if tta_use_origin_uv else f0
        else:
            f0 = self.__call__(wav, sr, decoder_mode, threshold)
            f0_for_uv = f0

        if f0_min is None: f0_min = self.args_dict["model"]["f0_min"]

        uv = (f0_for_uv < f0_min).type(f0_for_uv.dtype)
        f0 = f0 * (1 - uv)

        if interp_uv: f0 = batch_interp_with_replacement_detach(uv.squeeze(-1).bool(), f0.squeeze(-1)).unsqueeze(-1)
        if f0_max is not None: f0[f0 > f0_max] = f0_max
        if output_interp_target_length is not None: f0 = F.interpolate(f0.transpose(1, 2), size=int(output_interp_target_length), mode="nearest").transpose(1, 2)

        if return_uv: return f0, F.interpolate(uv.transpose(1, 2), size=int(output_interp_target_length), mode="nearest").transpose(1, 2)
        else: return f0

class FCPEInfer_LEGACY:
    def __init__(self, model_path, device=None, dtype=torch.float32, providers=None, onnx=False):
        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.dtype = dtype
        self.onnx = onnx

        if self.onnx:
            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3

            self.model = ort.InferenceSession(decrypt_model(model_path), sess_options=sess_options, providers=providers)
        else:
            ckpt = torch.load(model_path, map_location=torch.device(self.device))
            self.args = DotDict(ckpt["config"])

            model = FCPE_LEGACY(input_channel=self.args.model.input_channel, out_dims=self.args.model.out_dims, n_layers=self.args.model.n_layers, n_chans=self.args.model.n_chans, use_siren=self.args.model.use_siren, use_full=self.args.model.use_full, loss_mse_scale=self.args.loss.loss_mse_scale, loss_l2_regularization=self.args.loss.loss_l2_regularization, loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, loss_grad1_mse=self.args.loss.loss_grad1_mse, loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, f0_max=self.args.model.f0_max, f0_min=self.args.model.f0_min, confidence=self.args.model.confidence)
            model.to(self.device).to(self.dtype)
            model.load_state_dict(ckpt["model"])

            model.eval()
            self.model = model

    @torch.no_grad()
    def __call__(self, audio, sr, threshold=0.05):
        if not self.onnx: self.model.threshold = threshold
        self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)

        return (torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype).detach().cpu().numpy(), self.model.get_inputs()[1].name: np.array(threshold, dtype=np.float32)})[0], dtype=self.dtype, device=self.device) if self.onnx else self.model(mel=self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype), infer=True, return_hz_f0=True))

class FCPEInfer:
    def __init__(self, model_path, device=None, dtype=torch.float32, providers=None, onnx=False):
        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.dtype = dtype
        self.onnx = onnx

        if self.onnx:
            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3

            self.model = ort.InferenceSession(decrypt_model(model_path), sess_options=sess_options, providers=providers)
        else:
            ckpt = torch.load(model_path, map_location=torch.device(device))
            ckpt["config_dict"]["model"]["conv_dropout"] = ckpt["config_dict"]["model"]["atten_dropout"] = 0.0
            self.args = DotDict(ckpt["config_dict"])

            model = InferCFNaiveMelPE(self.args, ckpt["model"])
            model = model.to(device)

            model.eval()
            self.model = model

    @torch.no_grad()
    def __call__(self, audio, sr, threshold=0.05, f0_min=50, f0_max=1100, p_len=None):
        if self.onnx: self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)
        return (torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype).detach().cpu().numpy(), self.model.get_inputs()[1].name: np.array(threshold, dtype=np.float32)})[0], dtype=self.dtype, device=self.device) if self.onnx else self.model.infer(audio[None, :], sr, threshold=threshold, f0_min=f0_min, f0_max=f0_max, output_interp_target_length=p_len))

class MelModule(torch.nn.Module):
    def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin = None, fmax = None, clip_val = 1e-5, out_stft = False):
        super().__init__()
        if fmin is None: fmin = 0
        if fmax is None: fmax = sr / 2

        self.target_sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val

        self.register_buffer('mel_basis', torch.tensor(librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)).float(), persistent=False)
        self.hann_window = torch.nn.ModuleDict()
        self.out_stft = out_stft

    @torch.no_grad()
    def __call__(self, y, key_shift = 0, speed = 1, center = False, no_cache_window = False):
        n_fft = self.n_fft
        win_size = self.win_size
        hop_length = self.hop_length
        clip_val = self.clip_val

        factor = 2 ** (key_shift / 12)
        n_fft_new = int(np.round(n_fft * factor))
        win_size_new = int(np.round(win_size * factor))
        hop_length_new = int(np.round(hop_length * speed))

        y = y.squeeze(-1)

        if torch.min(y) < -1: print('[error with torchfcpe.mel_extractor.MelModule] min ', torch.min(y))
        if torch.max(y) > 1: print('[error with torchfcpe.mel_extractor.MelModule] max ', torch.max(y))

        key_shift_key = str(key_shift)
        if not no_cache_window:
            if key_shift_key in self.hann_window: hann_window = self.hann_window[key_shift_key]
            else:
                hann_window = HannWindow(win_size_new).to(self.mel_basis.device)
                self.hann_window[key_shift_key] = hann_window

            hann_window_tensor = hann_window()
        else: hann_window_tensor = torch.hann_window(win_size_new).to(self.mel_basis.device)

        pad_left = (win_size_new - hop_length_new) // 2
        pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)

        mode = 'reflect' if pad_right < y.size(-1) else 'constant'

        spec = torch.stft(F.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode).squeeze(1), n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window_tensor, center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
        spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9)

        if key_shift != 0:
            size = n_fft // 2 + 1
            resize = spec.size(1)

            if resize < size: spec = F.pad(spec, (0, 0, 0, size - resize))
            spec = spec[:, :size, :] * win_size / win_size_new

        spec = spec[:, :512, :] if self.out_stft else torch.matmul(self.mel_basis, spec)

        return dynamic_range_compression_torch(spec, clip_val=clip_val).transpose(-1, -2)

class Wav2MelModule(torch.nn.Module):
    def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin = None, fmax = None, clip_val = 1e-5, mel_type="default"):
        super().__init__()
        if fmin is None: fmin = 0
        if fmax is None: fmax = sr / 2

        self.sampling_rate = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val

        self.register_buffer('tensor_device_marker', torch.tensor(1.0).float(), persistent=False)
        self.resample_kernel = torch.nn.ModuleDict()

        if mel_type == "default": self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=False)
        elif mel_type == "stft": self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=True)

        self.mel_type = mel_type

    @torch.no_grad()
    def __call__(self, audio, sample_rate, keyshift = 0, no_cache_window = False):

        if sample_rate == self.sampling_rate: audio_res = audio
        else:
            key_str = str(sample_rate)

            if key_str not in self.resample_kernel:
                if len(self.resample_kernel) > 8: self.resample_kernel.clear()
                self.resample_kernel[key_str] = Resample(sample_rate, self.sampling_rate, lowpass_filter_width=128).to(self.tensor_device_marker.device)

            audio_res = self.resample_kernel[key_str](audio.squeeze(-1)).unsqueeze(-1)

        mel = self.mel_extractor(audio_res, keyshift, no_cache_window=no_cache_window)
        n_frames = int(audio.shape[1] // self.hop_size) + 1

        if n_frames > int(mel.shape[1]): mel = torch.cat((mel, mel[:, -1:, :]), 1)
        if n_frames < int(mel.shape[1]): mel = mel[:, :n_frames, :]

        return mel

class Wav2Mel:
    def __init__(self, device=None, dtype=torch.float32):
        self.sample_rate = 16000
        self.hop_size = 160
        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.dtype = dtype
        self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
        self.resample_kernel = {}

    def extract_nvstft(self, audio, keyshift=0, train=False):
        return self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)

    def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
        audio = audio.to(self.dtype).to(self.device)

        if sample_rate == self.sample_rate: audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel: self.resample_kernel[key_str] = Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)

            self.resample_kernel[key_str] = (self.resample_kernel[key_str].to(self.dtype).to(self.device))
            audio_res = self.resample_kernel[key_str](audio)

        mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
        n_frames = int(audio.shape[1] // self.hop_size) + 1

        mel = (torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel)
        return mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel

    def __call__(self, audio, sample_rate, keyshift=0, train=False):
        return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)

class DotDict(dict):
    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

class FCPE:
    def __init__(self, model_path, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sample_rate=44100, threshold=0.05, providers=None, onnx=False, legacy=False):
        self.fcpe = FCPEInfer_LEGACY(model_path, device=device, dtype=dtype, providers=providers, onnx=onnx) if legacy else FCPEInfer(model_path, device=device, dtype=dtype, providers=providers, onnx=onnx)
        self.hop_length = hop_length
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.threshold = threshold
        self.sample_rate = sample_rate
        self.dtype = dtype
        self.legacy = legacy
        self.name = "fcpe"

    def repeat_expand(self, content, target_len, mode = "nearest"):
        ndim = content.ndim
        content = (content[None, None] if ndim == 1 else content[None] if ndim == 2 else content)

        assert content.ndim == 3
        is_np = isinstance(content, np.ndarray)

        results = F.interpolate(torch.from_numpy(content) if is_np else content, size=target_len, mode=mode)
        results = results.numpy() if is_np else results
        return results[0, 0] if ndim == 1 else results[0] if ndim == 2 else results

    def post_process(self, x, sample_rate, f0, pad_to):
        f0 = (torch.from_numpy(f0).float().to(x.device) if isinstance(f0, np.ndarray) else f0)
        f0 = self.repeat_expand(f0, pad_to) if pad_to is not None else f0

        vuv_vector = torch.zeros_like(f0)
        vuv_vector[f0 > 0.0] = 1.0
        vuv_vector[f0 <= 0.0] = 0.0

        nzindex = torch.nonzero(f0).squeeze()
        f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
        vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0]

        if f0.shape[0] <= 0: return np.zeros(pad_to), vuv_vector.cpu().numpy()
        if f0.shape[0] == 1: return np.ones(pad_to) * f0[0], vuv_vector.cpu().numpy()

        return np.interp(np.arange(pad_to) * self.hop_length / sample_rate, self.hop_length / sample_rate * nzindex.cpu().numpy(), f0, left=f0[0], right=f0[-1]), vuv_vector.cpu().numpy()

    def compute_f0(self, wav, p_len=None):
        x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
        p_len = x.shape[0] // self.hop_length if p_len is None else p_len

        f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold) if self.legacy else (self.fcpe(x, sr=self.sample_rate, threshold=self.threshold, f0_min=self.f0_min, f0_max=self.f0_max, p_len=p_len))
        f0 = f0[:] if f0.dim() == 1 else f0[0, :, 0]

        if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (f0.cpu().numpy() if p_len is None else np.zeros(p_len))
        return self.post_process(x, self.sample_rate, f0, p_len)[0]
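
For orientation, a minimal sketch of how this deleted FCPE predictor was typically driven; the checkpoint name, import path, and the 16 kHz test tone are assumptions, not part of the diff:

import numpy as np
from main.library.predictors.FCPE import FCPE  # import path assumed from the repository layout

sr, hop = 16000, 160
t = np.arange(sr, dtype=np.float32) / sr
wav = 0.5 * np.sin(2 * np.pi * 220.0 * t)                    # one second of a 220 Hz tone as a stand-in input
predictor = FCPE("fcpe.pt", hop_length=hop, sample_rate=sr)  # hypothetical checkpoint path
f0 = predictor.compute_f0(wav, p_len=len(wav) // hop)        # frame-level F0 estimates in Hz (NumPy array)
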
main/library/predictors/RMVPE.py
DELETED
|
@@ -1,260 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch.nn as nn
|
| 5 |
-
import torch.nn.functional as F
|
| 6 |
-
|
| 7 |
-
from librosa.filters import mel
|
| 8 |
-
|
| 9 |
-
N_MELS, N_CLASS = 128, 360
|
| 10 |
-
|
| 11 |
-
class ConvBlockRes(nn.Module):
|
| 12 |
-
def __init__(self, in_channels, out_channels, momentum=0.01):
|
| 13 |
-
super(ConvBlockRes, self).__init__()
|
| 14 |
-
self.conv = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
|
| 15 |
-
|
| 16 |
-
if in_channels != out_channels:
|
| 17 |
-
self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
|
| 18 |
-
self.is_shortcut = True
|
| 19 |
-
else: self.is_shortcut = False
|
| 20 |
-
|
| 21 |
-
def forward(self, x):
|
| 22 |
-
return self.conv(x) + self.shortcut(x) if self.is_shortcut else self.conv(x) + x
|
| 23 |
-
|
| 24 |
-
class ResEncoderBlock(nn.Module):
|
| 25 |
-
def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
|
| 26 |
-
super(ResEncoderBlock, self).__init__()
|
| 27 |
-
self.n_blocks = n_blocks
|
| 28 |
-
self.conv = nn.ModuleList()
|
| 29 |
-
self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
|
| 30 |
-
|
| 31 |
-
for _ in range(n_blocks - 1):
|
| 32 |
-
self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
|
| 33 |
-
|
| 34 |
-
self.kernel_size = kernel_size
|
| 35 |
-
if self.kernel_size is not None: self.pool = nn.AvgPool2d(kernel_size=kernel_size)
|
| 36 |
-
|
| 37 |
-
def forward(self, x):
|
| 38 |
-
for i in range(self.n_blocks):
|
| 39 |
-
x = self.conv[i](x)
|
| 40 |
-
|
| 41 |
-
if self.kernel_size is not None: return x, self.pool(x)
|
| 42 |
-
else: return x
|
| 43 |
-
|
| 44 |
-
class Encoder(nn.Module):
|
| 45 |
-
def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
|
| 46 |
-
super(Encoder, self).__init__()
|
| 47 |
-
self.n_encoders = n_encoders
|
| 48 |
-
self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
|
| 49 |
-
self.layers = nn.ModuleList()
|
| 50 |
-
self.latent_channels = []
|
| 51 |
-
|
| 52 |
-
for _ in range(self.n_encoders):
|
| 53 |
-
self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
|
| 54 |
-
self.latent_channels.append([out_channels, in_size])
|
| 55 |
-
in_channels = out_channels
|
| 56 |
-
out_channels *= 2
|
| 57 |
-
in_size //= 2
|
| 58 |
-
|
| 59 |
-
self.out_size = in_size
|
| 60 |
-
self.out_channel = out_channels
|
| 61 |
-
|
| 62 |
-
def forward(self, x):
|
| 63 |
-
concat_tensors = []
|
| 64 |
-
x = self.bn(x)
|
| 65 |
-
|
| 66 |
-
for i in range(self.n_encoders):
|
| 67 |
-
t, x = self.layers[i](x)
|
| 68 |
-
concat_tensors.append(t)
|
| 69 |
-
|
| 70 |
-
return x, concat_tensors
|
| 71 |
-
|
| 72 |
-
class Intermediate(nn.Module):
|
| 73 |
-
def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
|
| 74 |
-
super(Intermediate, self).__init__()
|
| 75 |
-
self.n_inters = n_inters
|
| 76 |
-
self.layers = nn.ModuleList()
|
| 77 |
-
self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
|
| 78 |
-
|
| 79 |
-
for _ in range(self.n_inters - 1):
|
| 80 |
-
self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
|
| 81 |
-
|
| 82 |
-
def forward(self, x):
|
| 83 |
-
for i in range(self.n_inters):
|
| 84 |
-
x = self.layers[i](x)
|
| 85 |
-
|
| 86 |
-
return x
|
| 87 |
-
|
| 88 |
-
class ResDecoderBlock(nn.Module):
|
| 89 |
-
def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
|
| 90 |
-
super(ResDecoderBlock, self).__init__()
|
| 91 |
-
out_padding = (0, 1) if stride == (1, 2) else (1, 1)
|
| 92 |
-
self.n_blocks = n_blocks
|
| 93 |
-
self.conv1 = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), output_padding=out_padding, bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
|
| 94 |
-
self.conv2 = nn.ModuleList()
|
| 95 |
-
self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
|
| 96 |
-
|
| 97 |
-
for _ in range(n_blocks - 1):
|
| 98 |
-
self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
|
| 99 |
-
|
| 100 |
-
def forward(self, x, concat_tensor):
|
| 101 |
-
x = torch.cat((self.conv1(x), concat_tensor), dim=1)
|
| 102 |
-
|
| 103 |
-
for i in range(self.n_blocks):
|
| 104 |
-
x = self.conv2[i](x)
|
| 105 |
-
|
| 106 |
-
return x
|
| 107 |
-
|
| 108 |
-
class Decoder(nn.Module):
|
| 109 |
-
def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
|
| 110 |
-
super(Decoder, self).__init__()
|
| 111 |
-
self.layers = nn.ModuleList()
|
| 112 |
-
self.n_decoders = n_decoders
|
| 113 |
-
|
| 114 |
-
for _ in range(self.n_decoders):
|
| 115 |
-
out_channels = in_channels // 2
|
| 116 |
-
self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
|
| 117 |
-
in_channels = out_channels
|
| 118 |
-
|
| 119 |
-
def forward(self, x, concat_tensors):
|
| 120 |
-
for i in range(self.n_decoders):
|
| 121 |
-
x = self.layers[i](x, concat_tensors[-1 - i])
|
| 122 |
-
|
| 123 |
-
return x
|
| 124 |
-
|
| 125 |
-
class DeepUnet(nn.Module):
|
| 126 |
-
def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
|
| 127 |
-
super(DeepUnet, self).__init__()
|
| 128 |
-
self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
|
| 129 |
-
self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
|
| 130 |
-
self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
|
| 131 |
-
|
| 132 |
-
def forward(self, x):
|
| 133 |
-
x, concat_tensors = self.encoder(x)
|
| 134 |
-
return self.decoder(self.intermediate(x), concat_tensors)
|
| 135 |
-
|
| 136 |
-
class E2E(nn.Module):
|
| 137 |
-
def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
|
| 138 |
-
super(E2E, self).__init__()
|
| 139 |
-
self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
|
| 140 |
-
self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
|
| 141 |
-
self.fc = nn.Sequential(BiGRU(3 * 128, 256, n_gru), nn.Linear(512, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()) if n_gru else nn.Sequential(nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())
|
| 142 |
-
|
| 143 |
-
def forward(self, mel):
|
| 144 |
-
return self.fc(self.cnn(self.unet(mel.transpose(-1, -2).unsqueeze(1))).transpose(1, 2).flatten(-2))
|
| 145 |
-
|
| 146 |
-
class MelSpectrogram(torch.nn.Module):
|
| 147 |
-
def __init__(self, is_half, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
|
| 148 |
-
super().__init__()
|
| 149 |
-
n_fft = win_length if n_fft is None else n_fft
|
| 150 |
-
self.hann_window = {}
|
| 151 |
-
mel_basis = mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True)
|
| 152 |
-
mel_basis = torch.from_numpy(mel_basis).float()
|
| 153 |
-
self.register_buffer("mel_basis", mel_basis)
|
| 154 |
-
self.n_fft = win_length if n_fft is None else n_fft
|
| 155 |
-
self.hop_length = hop_length
|
| 156 |
-
self.win_length = win_length
|
| 157 |
-
self.sample_rate = sample_rate
|
| 158 |
-
self.n_mel_channels = n_mel_channels
|
| 159 |
-
self.clamp = clamp
|
| 160 |
-
self.is_half = is_half
|
| 161 |
-
|
| 162 |
-
def forward(self, audio, keyshift=0, speed=1, center=True):
|
| 163 |
-
factor = 2 ** (keyshift / 12)
|
| 164 |
-
win_length_new = int(np.round(self.win_length * factor))
|
| 165 |
-
keyshift_key = str(keyshift) + "_" + str(audio.device)
|
| 166 |
-
if keyshift_key not in self.hann_window: self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
|
| 167 |
-
|
| 168 |
-
fft = torch.stft(audio, n_fft=int(np.round(self.n_fft * factor)), hop_length=int(np.round(self.hop_length * speed)), win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
|
| 169 |
-
magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
|
| 170 |
-
|
| 171 |
-
if keyshift != 0:
|
| 172 |
-
size = self.n_fft // 2 + 1
|
| 173 |
-
resize = magnitude.size(1)
|
| 174 |
-
|
| 175 |
-
if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
|
| 176 |
-
magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
|
| 177 |
-
|
| 178 |
-
mel_output = torch.matmul(self.mel_basis, magnitude)
|
| 179 |
-
if self.is_half: mel_output = mel_output.half()
|
| 180 |
-
|
| 181 |
-
return torch.log(torch.clamp(mel_output, min=self.clamp))
|
| 182 |
-
|
| 183 |
-
class RMVPE:
|
| 184 |
-
def __init__(self, model_path, is_half, device=None, providers=None, onnx=False):
|
| 185 |
-
self.resample_kernel = {}
|
| 186 |
-
self.onnx = onnx
|
| 187 |
-
|
| 188 |
-
if self.onnx:
|
| 189 |
-
import onnxruntime as ort
|
| 190 |
-
|
| 191 |
-
sess_options = ort.SessionOptions()
|
| 192 |
-
sess_options.log_severity_level = 3
|
| 193 |
-
|
| 194 |
-
self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
|
| 195 |
-
else:
|
| 196 |
-
model = E2E(4, 1, (2, 2))
|
| 197 |
-
ckpt = torch.load(model_path, map_location="cpu")
|
| 198 |
-
model.load_state_dict(ckpt)
|
| 199 |
-
model.eval()
|
| 200 |
-
if is_half: model = model.half()
|
| 201 |
-
self.model = model.to(device)
|
| 202 |
-
|
| 203 |
-
self.resample_kernel = {}
|
| 204 |
-
self.is_half = is_half
|
| 205 |
-
self.device = device
|
| 206 |
-
self.mel_extractor = MelSpectrogram(is_half, N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
|
| 207 |
-
cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
|
| 208 |
-
self.cents_mapping = np.pad(cents_mapping, (4, 4))
|
| 209 |
-
|
| 210 |
-
def mel2hidden(self, mel):
|
| 211 |
-
with torch.no_grad():
|
| 212 |
-
n_frames = mel.shape[-1]
|
| 213 |
-
mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect")
|
| 214 |
-
hidden = self.model.run([self.model.get_outputs()[0].name], input_feed={self.model.get_inputs()[0].name: mel.cpu().numpy().astype(np.float32)})[0] if self.onnx else self.model(mel.half() if self.is_half else mel.float())
|
| 215 |
-
return hidden[:, :n_frames]
|
| 216 |
-
|
| 217 |
-
def decode(self, hidden, thred=0.03):
|
| 218 |
-
f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))
|
| 219 |
-
f0[f0 == 10] = 0
|
| 220 |
-
|
| 221 |
-
return f0
|
| 222 |
-
|
| 223 |
-
def infer_from_audio(self, audio, thred=0.03):
|
| 224 |
-
hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
|
| 225 |
-
|
| 226 |
-
return self.decode((hidden.squeeze(0).cpu().numpy().astype(np.float32) if self.is_half else hidden.squeeze(0).cpu().numpy()) if not self.onnx else hidden[0], thred=thred)
|
| 227 |
-
|
| 228 |
-
def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
|
| 229 |
-
hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
|
| 230 |
-
|
| 231 |
-
f0 = self.decode((hidden.squeeze(0).cpu().numpy().astype(np.float32) if self.is_half else hidden.squeeze(0).cpu().numpy()) if not self.onnx else hidden[0], thred=thred)
|
| 232 |
-
f0[(f0 < f0_min) | (f0 > f0_max)] = 0
|
| 233 |
-
|
| 234 |
-
return f0
|
| 235 |
-
|
| 236 |
-
def to_local_average_cents(self, salience, thred=0.05):
|
| 237 |
-
center = np.argmax(salience, axis=1)
|
| 238 |
-
salience = np.pad(salience, ((0, 0), (4, 4)))
|
| 239 |
-
center += 4
|
| 240 |
-
todo_salience, todo_cents_mapping = [], []
|
| 241 |
-
starts = center - 4
|
| 242 |
-
ends = center + 5
|
| 243 |
-
|
| 244 |
-
for idx in range(salience.shape[0]):
|
| 245 |
-
todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
|
| 246 |
-
todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
|
| 247 |
-
|
| 248 |
-
todo_salience = np.array(todo_salience)
|
| 249 |
-
devided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
|
| 250 |
-
devided[np.max(salience, axis=1) <= thred] = 0
|
| 251 |
-
|
| 252 |
-
return devided
|
| 253 |
-
|
| 254 |
-
class BiGRU(nn.Module):
|
| 255 |
-
def __init__(self, input_features, hidden_features, num_layers):
|
| 256 |
-
super(BiGRU, self).__init__()
|
| 257 |
-
self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
|
| 258 |
-
|
| 259 |
-
def forward(self, x):
|
| 260 |
-
return self.gru(x)[0]
|
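For reference, the RMVPE predictor whose tail is shown above reduces each salience frame to a locally weighted average in cents and converts it with f0 = 10 * 2 ** (cents / 1200). A minimal self-contained sketch of that decoding step (N_CLASS = 360 and the synthetic salience peak are assumptions for illustration; the 0.03 threshold matches the default in the code above):

import numpy as np

N_CLASS = 360  # assumed value; the deleted file takes N_CLASS from a constant defined elsewhere
cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191

salience = np.zeros(N_CLASS)
salience[180:189] = np.hanning(9)   # synthetic peak standing in for one model output frame

center = int(np.argmax(salience))
window = slice(center - 4, center + 5)
cents = np.sum(salience[window] * cents_mapping[window]) / np.sum(salience[window])

# frames whose peak salience is below the threshold are treated as unvoiced (f0 = 0)
f0 = 10 * 2 ** (cents / 1200) if salience.max() > 0.03 else 0.0
print(round(f0, 2))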
main/library/predictors/SWIPE.py
DELETED
|
@@ -1,140 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
|
| 5 |
-
from matplotlib import mlab
|
| 6 |
-
from scipy import interpolate
|
| 7 |
-
from decimal import Decimal, ROUND_HALF_UP
|
| 8 |
-
|
| 9 |
-
def swipe(x, fs, f0_floor=50, f0_ceil=1100, frame_period=10, sTHR=0.3):
|
| 10 |
-
plim = np.array([f0_floor, f0_ceil])
|
| 11 |
-
t = np.arange(0, int(1000 * len(x) / fs / (frame_period) + 1)) * (frame_period / 1000)
|
| 12 |
-
|
| 13 |
-
log2pc = np.arange(np.log2(plim[0]) * 96, np.log2(plim[-1]) * 96)
|
| 14 |
-
log2pc *= (1 / 96)
|
| 15 |
-
|
| 16 |
-
pc = 2 ** log2pc
|
| 17 |
-
S = np.zeros((len(pc), len(t)))
|
| 18 |
-
|
| 19 |
-
logWs = [round_matlab(elm) for elm in np.log2(4 * 2 * fs / plim)]
|
| 20 |
-
ws = 2 ** np.arange(logWs[0], logWs[1] - 1, -1)
|
| 21 |
-
p0 = 4 * 2 * fs / ws
|
| 22 |
-
|
| 23 |
-
d = 1 + log2pc - np.log2(4 * 2 * fs / ws[0])
|
| 24 |
-
fERBs = erbs2hz(np.arange(hz2erbs(pc[0] / 4), hz2erbs(fs / 2), 0.1))
|
| 25 |
-
|
| 26 |
-
for i in range(len(ws)):
|
| 27 |
-
dn = round_matlab(4 * fs / p0[i])
|
| 28 |
-
X, f, ti = mlab.specgram(x=np.r_[np.zeros(int(ws[i] / 2)), np.r_[x, np.zeros(int(dn + ws[i] / 2))]], NFFT=ws[i], Fs=fs, window=np.hanning(ws[i] + 2)[1:-1], noverlap=max(0, np.round(ws[i] - dn)), mode='complex')
|
| 29 |
-
ti = np.r_[0, ti[:-1]]
|
| 30 |
-
M = np.maximum(0, interpolate.interp1d(f, np.abs(X.T), kind='cubic')(fERBs)).T
|
| 31 |
-
|
| 32 |
-
if i == len(ws) - 1:
|
| 33 |
-
j = np.where(d - (i + 1) > -1)[0]
|
| 34 |
-
k = np.where(d[j] - (i + 1) < 0)[0]
|
| 35 |
-
elif i == 0:
|
| 36 |
-
j = np.where(d - (i + 1) < 1)[0]
|
| 37 |
-
k = np.where(d[j] - (i + 1) > 0)[0]
|
| 38 |
-
else:
|
| 39 |
-
j = np.where(np.abs(d - (i + 1)) < 1)[0]
|
| 40 |
-
k = np.arange(len(j))
|
| 41 |
-
|
| 42 |
-
Si = pitchStrengthAllCandidates(fERBs, np.sqrt(M), pc[j])
|
| 43 |
-
Si = interpolate.interp1d(ti, Si, bounds_error=False, fill_value='nan')(t) if Si.shape[1] > 1 else np.full((len(Si), len(t)), np.nan)
|
| 44 |
-
|
| 45 |
-
mu = np.ones(j.shape)
|
| 46 |
-
mu[k] = 1 - np.abs(d[j[k]] - i - 1)
|
| 47 |
-
S[j, :] = S[j, :] + np.tile(mu.reshape(-1, 1), (1, Si.shape[1])) * Si
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
p = np.full((S.shape[1], 1), np.nan)
|
| 51 |
-
s = np.full((S.shape[1], 1), np.nan)
|
| 52 |
-
|
| 53 |
-
for j in range(S.shape[1]):
|
| 54 |
-
s[j] = np.max(S[:, j])
|
| 55 |
-
i = np.argmax(S[:, j])
|
| 56 |
-
|
| 57 |
-
if s[j] < sTHR: continue
|
| 58 |
-
|
| 59 |
-
if i == 0: p[j] = pc[0]
|
| 60 |
-
elif i == len(pc) - 1: p[j] = pc[0]
|
| 61 |
-
else:
|
| 62 |
-
I = np.arange(i-1, i+2)
|
| 63 |
-
tc = 1 / pc[I]
|
| 64 |
-
|
| 65 |
-
ntc = (tc / tc[1] - 1) * 2 * np.pi
|
| 66 |
-
idx = np.isfinite(S[I, j])
|
| 67 |
-
|
| 68 |
-
c = np.zeros(len(ntc))
|
| 69 |
-
c += np.nan
|
| 70 |
-
|
| 71 |
-
I_ = I[idx]
|
| 72 |
-
|
| 73 |
-
if len(I_) < 2: c[idx] = (S[I, j])[0] / ntc[0]
|
| 74 |
-
else: c[idx] = np.polyfit(ntc[idx], (S[I_, j]), 2)
|
| 75 |
-
|
| 76 |
-
pval = np.polyval(c, ((1 / (2 ** np.arange(np.log2(pc[I[0]]), np.log2(pc[I[2]]) + 1 / 12 / 64, 1 / 12 / 64))) / tc[1] - 1) * 2 * np.pi)
|
| 77 |
-
s[j] = np.max(pval)
|
| 78 |
-
p[j] = 2 ** (np.log2(pc[I[0]]) + (np.argmax(pval)) / 12 / 64)
|
| 79 |
-
|
| 80 |
-
p = p.flatten()
|
| 81 |
-
p[np.isnan(p)] = 0
|
| 82 |
-
|
| 83 |
-
return np.array(p, dtype=np.float32), np.array(t, dtype=np.float32)
|
| 84 |
-
|
| 85 |
-
def round_matlab(n):
|
| 86 |
-
return int(Decimal(n).quantize(0, ROUND_HALF_UP))
|
| 87 |
-
|
| 88 |
-
def pitchStrengthAllCandidates(f, L, pc):
|
| 89 |
-
den = np.sqrt(np.sum(L * L, axis=0))
|
| 90 |
-
den = np.where(den == 0, 2.220446049250313e-16, den)
|
| 91 |
-
|
| 92 |
-
L = L / den
|
| 93 |
-
S = np.zeros((len(pc), L.shape[1]))
|
| 94 |
-
|
| 95 |
-
for j in range(len(pc)):
|
| 96 |
-
S[j,:] = pitchStrengthOneCandidate(f, L, pc[j])
|
| 97 |
-
|
| 98 |
-
return S
|
| 99 |
-
|
| 100 |
-
def pitchStrengthOneCandidate(f, L, pc):
|
| 101 |
-
k = np.zeros(len(f))
|
| 102 |
-
q = f / pc
|
| 103 |
-
|
| 104 |
-
for i in ([1] + sieve(int(np.fix(f[-1] / pc - 0.75)))):
|
| 105 |
-
a = np.abs(q - i)
|
| 106 |
-
p = a < 0.25
|
| 107 |
-
k[p] = np.cos(2 * np.pi * q[p])
|
| 108 |
-
|
| 109 |
-
v = np.logical_and((0.25 < a), (a < 0.75))
|
| 110 |
-
k[v] = k[v] + np.cos(2 * np.pi * q[v]) / 2
|
| 111 |
-
|
| 112 |
-
k *= np.sqrt(1 / f)
|
| 113 |
-
k /= np.linalg.norm(k[k>0])
|
| 114 |
-
|
| 115 |
-
return k @ L
|
| 116 |
-
|
| 117 |
-
def hz2erbs(hz):
|
| 118 |
-
return 21.4 * np.log10(1 + hz / 229)
|
| 119 |
-
|
| 120 |
-
def erbs2hz(erbs):
|
| 121 |
-
return (10 ** (erbs / 21.4) - 1) * 229
|
| 122 |
-
|
| 123 |
-
def sieve(n):
|
| 124 |
-
primes = list(range(2, n + 1))
|
| 125 |
-
num = 2
|
| 126 |
-
|
| 127 |
-
while num < math.sqrt(n):
|
| 128 |
-
i = num
|
| 129 |
-
|
| 130 |
-
while i <= n:
|
| 131 |
-
i += num
|
| 132 |
-
|
| 133 |
-
if i in primes: primes.remove(i)
|
| 134 |
-
|
| 135 |
-
for j in primes:
|
| 136 |
-
if j > num:
|
| 137 |
-
num = j
|
| 138 |
-
break
|
| 139 |
-
|
| 140 |
-
return primes
|
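The swipe() implementation above scores pitch candidates laid out at 96 steps per octave against a spectrum resampled onto an ERB-spaced axis. A self-contained sketch of just those two grids, using the defaults from the deleted signature (the printed sizes are illustrative, not asserted):

import numpy as np

def hz2erbs(hz): return 21.4 * np.log10(1 + hz / 229)
def erbs2hz(erbs): return (10 ** (erbs / 21.4) - 1) * 229

f0_floor, f0_ceil, fs = 50, 1100, 16000

# pitch candidates: 96 per octave on a log2 grid between f0_floor and f0_ceil
pc = 2 ** (np.arange(np.log2(f0_floor) * 96, np.log2(f0_ceil) * 96) / 96)

# ERB-spaced frequency axis onto which each spectrogram slice is interpolated
fERBs = erbs2hz(np.arange(hz2erbs(pc[0] / 4), hz2erbs(fs / 2), 0.1))

print(pc.size, float(pc[0]), float(pc[-1]))
print(fERBs.size)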
main/library/predictors/WORLD_WRAPPER.py
DELETED
|
@@ -1,90 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import ctypes
|
| 4 |
-
import platform
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class DioOption(ctypes.Structure):
|
| 11 |
-
_fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("ChannelsInOctave", ctypes.c_double), ("FramePeriod", ctypes.c_double), ("Speed", ctypes.c_int), ("AllowedRange", ctypes.c_double)]
|
| 12 |
-
|
| 13 |
-
class HarvestOption(ctypes.Structure):
|
| 14 |
-
_fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("FramePeriod", ctypes.c_double)]
|
| 15 |
-
|
| 16 |
-
class PYWORLD:
|
| 17 |
-
def __init__(self):
|
| 18 |
-
self.world_path = os.path.join("assets", "models", "predictors", "world")
|
| 19 |
-
os.makedirs(self.world_path, exist_ok=True)
|
| 20 |
-
|
| 21 |
-
model_type, suffix = (("world_64" if platform.architecture()[0] == "64bit" else "world_86"), ".dll") if platform.system() == "Windows" else ("world_linux", ".so")
|
| 22 |
-
self.world_file_path = os.path.join(self.world_path, f"{model_type}{suffix}")
|
| 23 |
-
|
| 24 |
-
if not os.path.exists(self.world_file_path):
|
| 25 |
-
model = torch.load(os.path.join("assets", "models", "predictors", "world.pth"), map_location="cpu")
|
| 26 |
-
|
| 27 |
-
with open(self.world_file_path, "wb") as w:
|
| 28 |
-
w.write(model[model_type])
|
| 29 |
-
|
| 30 |
-
self.world_dll = ctypes.CDLL(self.world_file_path)
|
| 31 |
-
|
| 32 |
-
def harvest(self, x, fs, f0_floor=50, f0_ceil=1100, frame_period=10):
|
| 33 |
-
self.world_dll.Harvest.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(HarvestOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
|
| 34 |
-
self.world_dll.Harvest.restype = None
|
| 35 |
-
|
| 36 |
-
self.world_dll.InitializeHarvestOption.argtypes = [ctypes.POINTER(HarvestOption)]
|
| 37 |
-
self.world_dll.InitializeHarvestOption.restype = None
|
| 38 |
-
|
| 39 |
-
self.world_dll.GetSamplesForHarvest.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
|
| 40 |
-
self.world_dll.GetSamplesForHarvest.restype = ctypes.c_int
|
| 41 |
-
|
| 42 |
-
option = HarvestOption()
|
| 43 |
-
self.world_dll.InitializeHarvestOption(ctypes.byref(option))
|
| 44 |
-
|
| 45 |
-
option.F0Floor = f0_floor
|
| 46 |
-
option.F0Ceil = f0_ceil
|
| 47 |
-
option.FramePeriod = frame_period
|
| 48 |
-
|
| 49 |
-
f0_length = self.world_dll.GetSamplesForHarvest(fs, len(x), option.FramePeriod)
|
| 50 |
-
f0 = (ctypes.c_double * f0_length)()
|
| 51 |
-
tpos = (ctypes.c_double * f0_length)()
|
| 52 |
-
|
| 53 |
-
self.world_dll.Harvest((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
|
| 54 |
-
return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)
|
| 55 |
-
|
| 56 |
-
def dio(self, x, fs, f0_floor=50, f0_ceil=1100, channels_in_octave=2, frame_period=10, speed=1, allowed_range=0.1):
|
| 57 |
-
self.world_dll.Dio.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(DioOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
|
| 58 |
-
self.world_dll.Dio.restype = None
|
| 59 |
-
|
| 60 |
-
self.world_dll.InitializeDioOption.argtypes = [ctypes.POINTER(DioOption)]
|
| 61 |
-
self.world_dll.InitializeDioOption.restype = None
|
| 62 |
-
|
| 63 |
-
self.world_dll.GetSamplesForDIO.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
|
| 64 |
-
self.world_dll.GetSamplesForDIO.restype = ctypes.c_int
|
| 65 |
-
|
| 66 |
-
option = DioOption()
|
| 67 |
-
self.world_dll.InitializeDioOption(ctypes.byref(option))
|
| 68 |
-
|
| 69 |
-
option.F0Floor = f0_floor
|
| 70 |
-
option.F0Ceil = f0_ceil
|
| 71 |
-
option.ChannelsInOctave = channels_in_octave
|
| 72 |
-
option.FramePeriod = frame_period
|
| 73 |
-
option.Speed = speed
|
| 74 |
-
option.AllowedRange = allowed_range
|
| 75 |
-
|
| 76 |
-
f0_length = self.world_dll.GetSamplesForDIO(fs, len(x), option.FramePeriod)
|
| 77 |
-
f0 = (ctypes.c_double * f0_length)()
|
| 78 |
-
tpos = (ctypes.c_double * f0_length)()
|
| 79 |
-
|
| 80 |
-
self.world_dll.Dio((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
|
| 81 |
-
return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)
|
| 82 |
-
|
| 83 |
-
def stonemask(self, x, fs, tpos, f0):
|
| 84 |
-
self.world_dll.StoneMask.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.POINTER(ctypes.c_double)]
|
| 85 |
-
self.world_dll.StoneMask.restype = None
|
| 86 |
-
|
| 87 |
-
out_f0 = (ctypes.c_double * len(f0))()
|
| 88 |
-
self.world_dll.StoneMask((ctypes.c_double * len(x))(*x), len(x), fs, (ctypes.c_double * len(tpos))(*tpos), (ctypes.c_double * len(f0))(*f0), len(f0), out_f0)
|
| 89 |
-
|
| 90 |
-
return np.array(out_f0, dtype=np.float32)
|
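A usage sketch for the ctypes WORLD wrapper removed above. It only runs against a checkout that still ships this module together with the assets/models/predictors/world.pth blob it unpacks into a platform-specific shared library; the input signal here is a placeholder.

import numpy as np
from main.library.predictors.WORLD_WRAPPER import PYWORLD

fs = 16000
x = np.random.randn(fs).astype(np.float64)   # one second of noise standing in for real audio

pw = PYWORLD()
f0, tpos = pw.harvest(x, fs, f0_floor=50, f0_ceil=1100, frame_period=10)
f0_refined = pw.stonemask(x, fs, tpos, f0)   # StoneMask refines the coarse Harvest contour
print(f0_refined.shape)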
main/library/speaker_diarization/ECAPA_TDNN.py
DELETED
|
@@ -1,280 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
import torch
|
| 3 |
-
|
| 4 |
-
import torch.nn as nn
|
| 5 |
-
import torch.nn.functional as F
|
| 6 |
-
|
| 7 |
-
def length_to_mask(length, max_len=None, dtype=None, device=None):
|
| 8 |
-
assert len(length.shape) == 1
|
| 9 |
-
|
| 10 |
-
if max_len is None: max_len = length.max().long().item()
|
| 11 |
-
|
| 12 |
-
mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(len(length), max_len) < length.unsqueeze(1)
|
| 13 |
-
|
| 14 |
-
if dtype is None: dtype = length.dtype
|
| 15 |
-
if device is None: device = length.device
|
| 16 |
-
|
| 17 |
-
return torch.as_tensor(mask, dtype=dtype, device=device)
|
| 18 |
-
|
| 19 |
-
def get_padding_elem(L_in, stride, kernel_size, dilation):
|
| 20 |
-
if stride > 1: padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
|
| 21 |
-
else:
|
| 22 |
-
L_out = (math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1)
|
| 23 |
-
padding = [math.floor((L_in - L_out) / 2), math.floor((L_in - L_out) / 2)]
|
| 24 |
-
|
| 25 |
-
return padding
|
| 26 |
-
|
| 27 |
-
class _BatchNorm1d(nn.Module):
|
| 28 |
-
def __init__(self, input_shape=None, input_size=None, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, combine_batch_time=False, skip_transpose=False):
|
| 29 |
-
super().__init__()
|
| 30 |
-
self.combine_batch_time = combine_batch_time
|
| 31 |
-
self.skip_transpose = skip_transpose
|
| 32 |
-
|
| 33 |
-
if input_size is None and skip_transpose: input_size = input_shape[1]
|
| 34 |
-
elif input_size is None: input_size = input_shape[-1]
|
| 35 |
-
|
| 36 |
-
self.norm = nn.BatchNorm1d(input_size, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
|
| 37 |
-
|
| 38 |
-
def forward(self, x):
|
| 39 |
-
shape_or = x.shape
|
| 40 |
-
|
| 41 |
-
if self.combine_batch_time:x = x.reshape(shape_or[0] * shape_or[1], shape_or[2]) if x.ndim == 3 else x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2])
|
| 42 |
-
elif not self.skip_transpose: x = x.transpose(-1, 1)
|
| 43 |
-
|
| 44 |
-
x_n = self.norm(x)
|
| 45 |
-
|
| 46 |
-
if self.combine_batch_time: x_n = x_n.reshape(shape_or)
|
| 47 |
-
elif not self.skip_transpose: x_n = x_n.transpose(1, -1)
|
| 48 |
-
|
| 49 |
-
return x_n
|
| 50 |
-
|
| 51 |
-
class _Conv1d(nn.Module):
|
| 52 |
-
def __init__(self, out_channels, kernel_size, input_shape=None, in_channels=None, stride=1, dilation=1, padding="same", groups=1, bias=True, padding_mode="reflect", skip_transpose=False, weight_norm=False, conv_init=None, default_padding=0):
|
| 53 |
-
super().__init__()
|
| 54 |
-
self.kernel_size = kernel_size
|
| 55 |
-
self.stride = stride
|
| 56 |
-
self.dilation = dilation
|
| 57 |
-
self.padding = padding
|
| 58 |
-
self.padding_mode = padding_mode
|
| 59 |
-
self.unsqueeze = False
|
| 60 |
-
self.skip_transpose = skip_transpose
|
| 61 |
-
|
| 62 |
-
if input_shape is None and in_channels is None: raise ValueError
|
| 63 |
-
if in_channels is None: in_channels = self._check_input_shape(input_shape)
|
| 64 |
-
|
| 65 |
-
self.in_channels = in_channels
|
| 66 |
-
self.conv = nn.Conv1d(in_channels, out_channels, self.kernel_size, stride=self.stride, dilation=self.dilation, padding=default_padding, groups=groups, bias=bias)
|
| 67 |
-
|
| 68 |
-
if conv_init == "kaiming": nn.init.kaiming_normal_(self.conv.weight)
|
| 69 |
-
elif conv_init == "zero": nn.init.zeros_(self.conv.weight)
|
| 70 |
-
elif conv_init == "normal": nn.init.normal_(self.conv.weight, std=1e-6)
|
| 71 |
-
|
| 72 |
-
if weight_norm: self.conv = nn.utils.weight_norm(self.conv)
|
| 73 |
-
|
| 74 |
-
def forward(self, x):
|
| 75 |
-
if not self.skip_transpose: x = x.transpose(1, -1)
|
| 76 |
-
if self.unsqueeze: x = x.unsqueeze(1)
|
| 77 |
-
|
| 78 |
-
if self.padding == "same": x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
|
| 79 |
-
elif self.padding == "causal": x = F.pad(x, ((self.kernel_size - 1) * self.dilation, 0))
|
| 80 |
-
elif self.padding == "valid": pass
|
| 81 |
-
else: raise ValueError
|
| 82 |
-
|
| 83 |
-
wx = self.conv(x)
|
| 84 |
-
|
| 85 |
-
if self.unsqueeze: wx = wx.squeeze(1)
|
| 86 |
-
if not self.skip_transpose: wx = wx.transpose(1, -1)
|
| 87 |
-
|
| 88 |
-
return wx
|
| 89 |
-
|
| 90 |
-
def _manage_padding(self, x, kernel_size, dilation, stride):
|
| 91 |
-
return F.pad(x, get_padding_elem(self.in_channels, stride, kernel_size, dilation), mode=self.padding_mode)
|
| 92 |
-
|
| 93 |
-
def _check_input_shape(self, shape):
|
| 94 |
-
if len(shape) == 2:
|
| 95 |
-
self.unsqueeze = True
|
| 96 |
-
in_channels = 1
|
| 97 |
-
elif self.skip_transpose: in_channels = shape[1]
|
| 98 |
-
elif len(shape) == 3: in_channels = shape[2]
|
| 99 |
-
else: raise ValueError
|
| 100 |
-
|
| 101 |
-
if not self.padding == "valid" and self.kernel_size % 2 == 0: raise ValueError
|
| 102 |
-
return in_channels
|
| 103 |
-
|
| 104 |
-
def remove_weight_norm(self):
|
| 105 |
-
self.conv = nn.utils.remove_weight_norm(self.conv)
|
| 106 |
-
|
| 107 |
-
class Linear(torch.nn.Module):
|
| 108 |
-
def __init__(self, n_neurons, input_shape=None, input_size=None, bias=True, max_norm=None, combine_dims=False):
|
| 109 |
-
super().__init__()
|
| 110 |
-
self.max_norm = max_norm
|
| 111 |
-
self.combine_dims = combine_dims
|
| 112 |
-
|
| 113 |
-
if input_shape is None and input_size is None: raise ValueError
|
| 114 |
-
if input_size is None:
|
| 115 |
-
input_size = input_shape[-1]
|
| 116 |
-
if len(input_shape) == 4 and self.combine_dims: input_size = input_shape[2] * input_shape[3]
|
| 117 |
-
|
| 118 |
-
self.w = nn.Linear(input_size, n_neurons, bias=bias)
|
| 119 |
-
|
| 120 |
-
def forward(self, x):
|
| 121 |
-
if x.ndim == 4 and self.combine_dims: x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
|
| 122 |
-
if self.max_norm is not None: self.w.weight.data = torch.renorm(self.w.weight.data, p=2, dim=0, maxnorm=self.max_norm)
|
| 123 |
-
|
| 124 |
-
return self.w(x)
|
| 125 |
-
|
| 126 |
-
class Conv1d(_Conv1d):
|
| 127 |
-
def __init__(self, *args, **kwargs):
|
| 128 |
-
super().__init__(skip_transpose=True, *args, **kwargs)
|
| 129 |
-
|
| 130 |
-
class BatchNorm1d(_BatchNorm1d):
|
| 131 |
-
def __init__(self, *args, **kwargs):
|
| 132 |
-
super().__init__(skip_transpose=True, *args, **kwargs)
|
| 133 |
-
|
| 134 |
-
class TDNNBlock(nn.Module):
|
| 135 |
-
def __init__(self, in_channels, out_channels, kernel_size, dilation, activation=nn.ReLU, groups=1, dropout=0.0):
|
| 136 |
-
super().__init__()
|
| 137 |
-
self.conv = Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, dilation=dilation, groups=groups)
|
| 138 |
-
self.activation = activation()
|
| 139 |
-
self.norm = BatchNorm1d(input_size=out_channels)
|
| 140 |
-
self.dropout = nn.Dropout1d(p=dropout)
|
| 141 |
-
|
| 142 |
-
def forward(self, x):
|
| 143 |
-
return self.dropout(self.norm(self.activation(self.conv(x))))
|
| 144 |
-
|
| 145 |
-
class Res2NetBlock(torch.nn.Module):
|
| 146 |
-
def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1, dropout=0.0):
|
| 147 |
-
super().__init__()
|
| 148 |
-
assert in_channels % scale == 0
|
| 149 |
-
assert out_channels % scale == 0
|
| 150 |
-
in_channel = in_channels // scale
|
| 151 |
-
hidden_channel = out_channels // scale
|
| 152 |
-
self.blocks = nn.ModuleList([TDNNBlock(in_channel, hidden_channel, kernel_size=kernel_size, dilation=dilation, dropout=dropout) for _ in range(scale - 1)])
|
| 153 |
-
self.scale = scale
|
| 154 |
-
|
| 155 |
-
def forward(self, x):
|
| 156 |
-
y = []
|
| 157 |
-
|
| 158 |
-
for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
|
| 159 |
-
if i == 0: y_i = x_i
|
| 160 |
-
elif i == 1: y_i = self.blocks[i - 1](x_i)
|
| 161 |
-
else: y_i = self.blocks[i - 1](x_i + y_i)
|
| 162 |
-
|
| 163 |
-
y.append(y_i)
|
| 164 |
-
|
| 165 |
-
return torch.cat(y, dim=1)
|
| 166 |
-
|
| 167 |
-
class SEBlock(nn.Module):
|
| 168 |
-
def __init__(self, in_channels, se_channels, out_channels):
|
| 169 |
-
super().__init__()
|
| 170 |
-
|
| 171 |
-
self.conv1 = Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1)
|
| 172 |
-
self.relu = torch.nn.ReLU(inplace=True)
|
| 173 |
-
self.conv2 = Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1)
|
| 174 |
-
self.sigmoid = torch.nn.Sigmoid()
|
| 175 |
-
|
| 176 |
-
def forward(self, x, lengths=None):
|
| 177 |
-
L = x.shape[-1]
|
| 178 |
-
|
| 179 |
-
if lengths is not None:
|
| 180 |
-
mask = length_to_mask(lengths * L, max_len=L, device=x.device).unsqueeze(1)
|
| 181 |
-
s = (x * mask).sum(dim=2, keepdim=True) / mask.sum(dim=2, keepdim=True)
|
| 182 |
-
else: s = x.mean(dim=2, keepdim=True)
|
| 183 |
-
|
| 184 |
-
return self.sigmoid(self.conv2(self.relu(self.conv1(s)))) * x
|
| 185 |
-
|
| 186 |
-
class AttentiveStatisticsPooling(nn.Module):
|
| 187 |
-
def __init__(self, channels, attention_channels=128, global_context=True):
|
| 188 |
-
super().__init__()
|
| 189 |
-
self.eps = 1e-12
|
| 190 |
-
self.global_context = global_context
|
| 191 |
-
self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1) if global_context else TDNNBlock(channels, attention_channels, 1, 1)
|
| 192 |
-
self.tanh = nn.Tanh()
|
| 193 |
-
self.conv = Conv1d(in_channels=attention_channels, out_channels=channels, kernel_size=1)
|
| 194 |
-
|
| 195 |
-
def forward(self, x, lengths=None):
|
| 196 |
-
L = x.shape[-1]
|
| 197 |
-
|
| 198 |
-
def _compute_statistics(x, m, dim=2, eps=self.eps):
|
| 199 |
-
mean = (m * x).sum(dim)
|
| 200 |
-
return mean, torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
|
| 201 |
-
|
| 202 |
-
if lengths is None: lengths = torch.ones(x.shape[0], device=x.device)
|
| 203 |
-
mask = length_to_mask(lengths * L, max_len=L, device=x.device).unsqueeze(1)
|
| 204 |
-
|
| 205 |
-
if self.global_context:
|
| 206 |
-
mean, std = _compute_statistics(x, mask / mask.sum(dim=2, keepdim=True).float())
|
| 207 |
-
attn = torch.cat([x, mean.unsqueeze(2).repeat(1, 1, L), std.unsqueeze(2).repeat(1, 1, L)], dim=1)
|
| 208 |
-
else: attn = x
|
| 209 |
-
|
| 210 |
-
mean, std = _compute_statistics(x, F.softmax(self.conv(self.tanh(self.tdnn(attn))).masked_fill(mask == 0, float("-inf")), dim=2))
|
| 211 |
-
return torch.cat((mean, std), dim=1).unsqueeze(2)
|
| 212 |
-
|
| 213 |
-
class SERes2NetBlock(nn.Module):
|
| 214 |
-
def __init__(self, in_channels, out_channels, res2net_scale=8, se_channels=128, kernel_size=1, dilation=1, activation=torch.nn.ReLU, groups=1, dropout=0.0):
|
| 215 |
-
super().__init__()
|
| 216 |
-
self.out_channels = out_channels
|
| 217 |
-
self.tdnn1 = TDNNBlock(in_channels, out_channels, kernel_size=1, dilation=1, activation=activation, groups=groups, dropout=dropout)
|
| 218 |
-
self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
|
| 219 |
-
self.tdnn2 = TDNNBlock(out_channels, out_channels, kernel_size=1, dilation=1, activation=activation, groups=groups, dropout=dropout)
|
| 220 |
-
self.se_block = SEBlock(out_channels, se_channels, out_channels)
|
| 221 |
-
|
| 222 |
-
self.shortcut = None
|
| 223 |
-
if in_channels != out_channels: self.shortcut = Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
|
| 224 |
-
|
| 225 |
-
def forward(self, x, lengths=None):
|
| 226 |
-
residual = x
|
| 227 |
-
if self.shortcut: residual = self.shortcut(x)
|
| 228 |
-
|
| 229 |
-
return self.se_block(self.tdnn2(self.res2net_block(self.tdnn1(x))), lengths) + residual
|
| 230 |
-
|
| 231 |
-
class ECAPA_TDNN(torch.nn.Module):
|
| 232 |
-
def __init__(self, input_size, device="cpu", lin_neurons=192, activation=torch.nn.ReLU, channels=[512, 512, 512, 512, 1536], kernel_sizes=[5, 3, 3, 3, 1], dilations=[1, 2, 3, 4, 1], attention_channels=128, res2net_scale=8, se_channels=128, global_context=True, groups=[1, 1, 1, 1, 1], dropout=0.0):
|
| 233 |
-
super().__init__()
|
| 234 |
-
assert len(channels) == len(kernel_sizes)
|
| 235 |
-
assert len(channels) == len(dilations)
|
| 236 |
-
|
| 237 |
-
self.channels = channels
|
| 238 |
-
self.blocks = nn.ModuleList()
|
| 239 |
-
|
| 240 |
-
self.blocks.append(TDNNBlock(input_size, channels[0], kernel_sizes[0], dilations[0], activation, groups[0], dropout))
|
| 241 |
-
|
| 242 |
-
for i in range(1, len(channels) - 1):
|
| 243 |
-
self.blocks.append(SERes2NetBlock(channels[i - 1], channels[i], res2net_scale=res2net_scale, se_channels=se_channels, kernel_size=kernel_sizes[i], dilation=dilations[i], activation=activation, groups=groups[i], dropout=dropout))
|
| 244 |
-
|
| 245 |
-
self.mfa = TDNNBlock(channels[-2] * (len(channels) - 2), channels[-1], kernel_sizes[-1], dilations[-1], activation, groups=groups[-1], dropout=dropout)
|
| 246 |
-
self.asp = AttentiveStatisticsPooling(channels[-1], attention_channels=attention_channels, global_context=global_context)
|
| 247 |
-
self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
|
| 248 |
-
self.fc = Conv1d(in_channels=channels[-1] * 2, out_channels=lin_neurons, kernel_size=1)
|
| 249 |
-
|
| 250 |
-
def forward(self, x, lengths=None):
|
| 251 |
-
x = x.transpose(1, 2)
|
| 252 |
-
|
| 253 |
-
xl = []
|
| 254 |
-
for layer in self.blocks:
|
| 255 |
-
try:
|
| 256 |
-
x = layer(x, lengths=lengths)
|
| 257 |
-
except TypeError:
|
| 258 |
-
x = layer(x)
|
| 259 |
-
|
| 260 |
-
xl.append(x)
|
| 261 |
-
|
| 262 |
-
return self.fc(self.asp_bn(self.asp(self.mfa(torch.cat(xl[1:], dim=1)), lengths=lengths))).transpose(1, 2)
|
| 263 |
-
|
| 264 |
-
class Classifier(torch.nn.Module):
|
| 265 |
-
def __init__(self, input_size, device="cpu", lin_blocks=0, lin_neurons=192, out_neurons=1211):
|
| 266 |
-
super().__init__()
|
| 267 |
-
self.blocks = nn.ModuleList()
|
| 268 |
-
|
| 269 |
-
for _ in range(lin_blocks):
|
| 270 |
-
self.blocks.extend([_BatchNorm1d(input_size=input_size), Linear(input_size=input_size, n_neurons=lin_neurons)])
|
| 271 |
-
input_size = lin_neurons
|
| 272 |
-
|
| 273 |
-
self.weight = nn.Parameter(torch.FloatTensor(out_neurons, input_size, device=device))
|
| 274 |
-
nn.init.xavier_uniform_(self.weight)
|
| 275 |
-
|
| 276 |
-
def forward(self, x):
|
| 277 |
-
for layer in self.blocks:
|
| 278 |
-
x = layer(x)
|
| 279 |
-
|
| 280 |
-
return F.linear(F.normalize(x.squeeze(1)), F.normalize(self.weight)).unsqueeze(1)
|
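A shape-check sketch for the ECAPA_TDNN embedder removed above, meaningful only against a revision that still contains the module; the 80-dimensional filterbank input is an assumption, not something fixed by the class.

import torch
from main.library.speaker_diarization.ECAPA_TDNN import ECAPA_TDNN

model = ECAPA_TDNN(input_size=80, lin_neurons=192).eval()
feats = torch.randn(2, 200, 80)       # (batch, frames, features); forward() transposes internally
with torch.no_grad():
    emb = model(feats)
print(emb.shape)                      # expected: torch.Size([2, 1, 192])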
main/library/speaker_diarization/audio.py
DELETED
|
@@ -1,170 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import math
|
| 3 |
-
import random
|
| 4 |
-
import torchaudio
|
| 5 |
-
|
| 6 |
-
from io import IOBase
|
| 7 |
-
from torch.nn.functional import pad
|
| 8 |
-
|
| 9 |
-
def get_torchaudio_info(file, backend = None):
|
| 10 |
-
if not backend:
|
| 11 |
-
backends = (torchaudio.list_audio_backends())
|
| 12 |
-
backend = "soundfile" if "soundfile" in backends else backends[0]
|
| 13 |
-
|
| 14 |
-
info = torchaudio.info(file["audio"], backend=backend)
|
| 15 |
-
if isinstance(file["audio"], IOBase): file["audio"].seek(0)
|
| 16 |
-
|
| 17 |
-
return info
|
| 18 |
-
|
| 19 |
-
class Audio:
|
| 20 |
-
@staticmethod
|
| 21 |
-
def power_normalize(waveform):
|
| 22 |
-
return waveform / (waveform.square().mean(dim=-1, keepdim=True).sqrt() + 1e-8)
|
| 23 |
-
|
| 24 |
-
@staticmethod
|
| 25 |
-
def validate_file(file):
|
| 26 |
-
if isinstance(file, (str, os.PathLike)): file = {"audio": str(file), "uri": os.path.splitext(os.path.basename(file))[0]}
|
| 27 |
-
elif isinstance(file, IOBase): return {"audio": file, "uri": "stream"}
|
| 28 |
-
else: raise ValueError
|
| 29 |
-
|
| 30 |
-
if "waveform" in file:
|
| 31 |
-
waveform = file["waveform"]
|
| 32 |
-
if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]: raise ValueError
|
| 33 |
-
|
| 34 |
-
sample_rate: int = file.get("sample_rate", None)
|
| 35 |
-
if sample_rate is None: raise ValueError
|
| 36 |
-
|
| 37 |
-
file.setdefault("uri", "waveform")
|
| 38 |
-
|
| 39 |
-
elif "audio" in file:
|
| 40 |
-
if isinstance(file["audio"], IOBase): return file
|
| 41 |
-
|
| 42 |
-
path = os.path.abspath(file["audio"])
|
| 43 |
-
file.setdefault("uri", os.path.splitext(os.path.basename(path))[0])
|
| 44 |
-
|
| 45 |
-
else: raise ValueError
|
| 46 |
-
|
| 47 |
-
return file
|
| 48 |
-
|
| 49 |
-
def __init__(self, sample_rate: int = None, mono=None, backend: str = None):
|
| 50 |
-
super().__init__()
|
| 51 |
-
self.sample_rate = sample_rate
|
| 52 |
-
self.mono = mono
|
| 53 |
-
|
| 54 |
-
if not backend:
|
| 55 |
-
backends = (torchaudio.list_audio_backends())
|
| 56 |
-
backend = "soundfile" if "soundfile" in backends else backends[0]
|
| 57 |
-
|
| 58 |
-
self.backend = backend
|
| 59 |
-
|
| 60 |
-
def downmix_and_resample(self, waveform, sample_rate):
|
| 61 |
-
num_channels = waveform.shape[0]
|
| 62 |
-
|
| 63 |
-
if num_channels > 1:
|
| 64 |
-
if self.mono == "random":
|
| 65 |
-
channel = random.randint(0, num_channels - 1)
|
| 66 |
-
waveform = waveform[channel : channel + 1]
|
| 67 |
-
elif self.mono == "downmix": waveform = waveform.mean(dim=0, keepdim=True)
|
| 68 |
-
|
| 69 |
-
if (self.sample_rate is not None) and (self.sample_rate != sample_rate):
|
| 70 |
-
waveform = torchaudio.functional.resample(waveform, sample_rate, self.sample_rate)
|
| 71 |
-
sample_rate = self.sample_rate
|
| 72 |
-
|
| 73 |
-
return waveform, sample_rate
|
| 74 |
-
|
| 75 |
-
def get_duration(self, file):
|
| 76 |
-
file = self.validate_file(file)
|
| 77 |
-
|
| 78 |
-
if "waveform" in file:
|
| 79 |
-
frames = len(file["waveform"].T)
|
| 80 |
-
sample_rate = file["sample_rate"]
|
| 81 |
-
else:
|
| 82 |
-
info = file["torchaudio.info"] if "torchaudio.info" in file else get_torchaudio_info(file, backend=self.backend)
|
| 83 |
-
frames = info.num_frames
|
| 84 |
-
sample_rate = info.sample_rate
|
| 85 |
-
|
| 86 |
-
return frames / sample_rate
|
| 87 |
-
|
| 88 |
-
def get_num_samples(self, duration, sample_rate = None):
|
| 89 |
-
sample_rate = sample_rate or self.sample_rate
|
| 90 |
-
if sample_rate is None: raise ValueError
|
| 91 |
-
|
| 92 |
-
return math.floor(duration * sample_rate)
|
| 93 |
-
|
| 94 |
-
def __call__(self, file):
|
| 95 |
-
file = self.validate_file(file)
|
| 96 |
-
|
| 97 |
-
if "waveform" in file:
|
| 98 |
-
waveform = file["waveform"]
|
| 99 |
-
sample_rate = file["sample_rate"]
|
| 100 |
-
elif "audio" in file:
|
| 101 |
-
waveform, sample_rate = torchaudio.load(file["audio"], backend=self.backend)
|
| 102 |
-
if isinstance(file["audio"], IOBase): file["audio"].seek(0)
|
| 103 |
-
|
| 104 |
-
channel = file.get("channel", None)
|
| 105 |
-
if channel is not None: waveform = waveform[channel : channel + 1]
|
| 106 |
-
|
| 107 |
-
return self.downmix_and_resample(waveform, sample_rate)
|
| 108 |
-
|
| 109 |
-
def crop(self, file, segment, duration = None, mode="raise"):
|
| 110 |
-
file = self.validate_file(file)
|
| 111 |
-
|
| 112 |
-
if "waveform" in file:
|
| 113 |
-
waveform = file["waveform"]
|
| 114 |
-
frames = waveform.shape[1]
|
| 115 |
-
sample_rate = file["sample_rate"]
|
| 116 |
-
elif "torchaudio.info" in file:
|
| 117 |
-
info = file["torchaudio.info"]
|
| 118 |
-
frames = info.num_frames
|
| 119 |
-
sample_rate = info.sample_rate
|
| 120 |
-
else:
|
| 121 |
-
info = get_torchaudio_info(file, backend=self.backend)
|
| 122 |
-
frames = info.num_frames
|
| 123 |
-
sample_rate = info.sample_rate
|
| 124 |
-
|
| 125 |
-
channel = file.get("channel", None)
|
| 126 |
-
start_frame = math.floor(segment.start * sample_rate)
|
| 127 |
-
|
| 128 |
-
if duration:
|
| 129 |
-
num_frames = math.floor(duration * sample_rate)
|
| 130 |
-
end_frame = start_frame + num_frames
|
| 131 |
-
else:
|
| 132 |
-
end_frame = math.floor(segment.end * sample_rate)
|
| 133 |
-
num_frames = end_frame - start_frame
|
| 134 |
-
|
| 135 |
-
if mode == "raise":
|
| 136 |
-
if num_frames > frames: raise ValueError
|
| 137 |
-
|
| 138 |
-
if end_frame > frames + math.ceil(0.001 * sample_rate): raise ValueError
|
| 139 |
-
else:
|
| 140 |
-
end_frame = min(end_frame, frames)
|
| 141 |
-
start_frame = end_frame - num_frames
|
| 142 |
-
|
| 143 |
-
if start_frame < 0: raise ValueError
|
| 144 |
-
elif mode == "pad":
|
| 145 |
-
pad_start = -min(0, start_frame)
|
| 146 |
-
pad_end = max(end_frame, frames) - frames
|
| 147 |
-
|
| 148 |
-
start_frame = max(0, start_frame)
|
| 149 |
-
end_frame = min(end_frame, frames)
|
| 150 |
-
|
| 151 |
-
num_frames = end_frame - start_frame
|
| 152 |
-
|
| 153 |
-
if "waveform" in file: data = file["waveform"][:, start_frame:end_frame]
|
| 154 |
-
else:
|
| 155 |
-
try:
|
| 156 |
-
data, _ = torchaudio.load(file["audio"], frame_offset=start_frame, num_frames=num_frames, backend=self.backend)
|
| 157 |
-
if isinstance(file["audio"], IOBase): file["audio"].seek(0)
|
| 158 |
-
except RuntimeError:
|
| 159 |
-
if isinstance(file["audio"], IOBase): raise RuntimeError
|
| 160 |
-
|
| 161 |
-
waveform, sample_rate = self.__call__(file)
|
| 162 |
-
data = waveform[:, start_frame:end_frame]
|
| 163 |
-
|
| 164 |
-
file["waveform"] = waveform
|
| 165 |
-
file["sample_rate"] = sample_rate
|
| 166 |
-
|
| 167 |
-
if channel is not None: data = data[channel : channel + 1, :]
|
| 168 |
-
if mode == "pad": data = pad(data, (pad_start, pad_end))
|
| 169 |
-
|
| 170 |
-
return self.downmix_and_resample(data, sample_rate)
|
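The Audio helper removed above wraps torchaudio loading with optional downmixing and resampling. A sketch of the common path ("some_recording.wav" is a placeholder, not a file in this repo):

from main.library.speaker_diarization.audio import Audio

audio = Audio(sample_rate=16000, mono="downmix")
waveform, sample_rate = audio("some_recording.wav")   # mono (1, samples) tensor at 16 kHz
print(waveform.shape, sample_rate)
print(audio.get_duration("some_recording.wav"))       # duration in seconds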
main/library/speaker_diarization/embedding.py
DELETED
|
@@ -1,90 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
|
| 8 |
-
from functools import cached_property
|
| 9 |
-
from torch.nn.utils.rnn import pad_sequence
|
| 10 |
-
|
| 11 |
-
sys.path.append(os.getcwd())
|
| 12 |
-
|
| 13 |
-
from main.library.speaker_diarization.speechbrain import EncoderClassifier
|
| 14 |
-
|
| 15 |
-
class BaseInference:
|
| 16 |
-
pass
|
| 17 |
-
|
| 18 |
-
class SpeechBrainPretrainedSpeakerEmbedding(BaseInference):
|
| 19 |
-
def __init__(self, embedding = "assets/models/speaker_diarization/models/speechbrain", device = None):
|
| 20 |
-
super().__init__()
|
| 21 |
-
|
| 22 |
-
self.embedding = embedding
|
| 23 |
-
self.device = device or torch.device("cpu")
|
| 24 |
-
self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": self.device})
|
| 25 |
-
|
| 26 |
-
def to(self, device):
|
| 27 |
-
if not isinstance(device, torch.device): raise TypeError
|
| 28 |
-
|
| 29 |
-
self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": device})
|
| 30 |
-
self.device = device
|
| 31 |
-
return self
|
| 32 |
-
|
| 33 |
-
@cached_property
|
| 34 |
-
def sample_rate(self):
|
| 35 |
-
return self.classifier_.audio_normalizer.sample_rate
|
| 36 |
-
|
| 37 |
-
@cached_property
|
| 38 |
-
def dimension(self):
|
| 39 |
-
*_, dimension = self.classifier_.encode_batch(torch.rand(1, 16000).to(self.device)).shape
|
| 40 |
-
return dimension
|
| 41 |
-
|
| 42 |
-
@cached_property
|
| 43 |
-
def metric(self):
|
| 44 |
-
return "cosine"
|
| 45 |
-
|
| 46 |
-
@cached_property
|
| 47 |
-
def min_num_samples(self):
|
| 48 |
-
with torch.inference_mode():
|
| 49 |
-
lower, upper = 2, round(0.5 * self.sample_rate)
|
| 50 |
-
middle = (lower + upper) // 2
|
| 51 |
-
|
| 52 |
-
while lower + 1 < upper:
|
| 53 |
-
try:
|
| 54 |
-
_ = self.classifier_.encode_batch(torch.randn(1, middle).to(self.device))
|
| 55 |
-
upper = middle
|
| 56 |
-
except RuntimeError:
|
| 57 |
-
lower = middle
|
| 58 |
-
|
| 59 |
-
middle = (lower + upper) // 2
|
| 60 |
-
|
| 61 |
-
return upper
|
| 62 |
-
|
| 63 |
-
def __call__(self, waveforms, masks = None):
|
| 64 |
-
batch_size, num_channels, num_samples = waveforms.shape
|
| 65 |
-
assert num_channels == 1
|
| 66 |
-
|
| 67 |
-
waveforms = waveforms.squeeze(dim=1)
|
| 68 |
-
|
| 69 |
-
if masks is None:
|
| 70 |
-
signals = waveforms.squeeze(dim=1)
|
| 71 |
-
wav_lens = signals.shape[1] * torch.ones(batch_size)
|
| 72 |
-
else:
|
| 73 |
-
batch_size_masks, _ = masks.shape
|
| 74 |
-
assert batch_size == batch_size_masks
|
| 75 |
-
|
| 76 |
-
imasks = F.interpolate(masks.unsqueeze(dim=1), size=num_samples, mode="nearest").squeeze(dim=1) > 0.5
|
| 77 |
-
signals = pad_sequence([waveform[imask].contiguous() for waveform, imask in zip(waveforms, imasks)], batch_first=True)
|
| 78 |
-
wav_lens = imasks.sum(dim=1)
|
| 79 |
-
|
| 80 |
-
max_len = wav_lens.max()
|
| 81 |
-
if max_len < self.min_num_samples: return np.nan * np.zeros((batch_size, self.dimension))
|
| 82 |
-
|
| 83 |
-
too_short = wav_lens < self.min_num_samples
|
| 84 |
-
wav_lens = wav_lens / max_len
|
| 85 |
-
wav_lens[too_short] = 1.0
|
| 86 |
-
|
| 87 |
-
embeddings = (self.classifier_.encode_batch(signals, wav_lens=wav_lens).squeeze(dim=1).cpu().numpy())
|
| 88 |
-
embeddings[too_short.cpu().numpy()] = np.nan
|
| 89 |
-
|
| 90 |
-
return embeddings
|
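A sketch for the speaker-embedding wrapper removed above; it requires the bundled speechbrain checkpoint under assets/models/speaker_diarization/models/speechbrain, so treat it as illustrative only.

import torch
from main.library.speaker_diarization.embedding import SpeechBrainPretrainedSpeakerEmbedding

emb = SpeechBrainPretrainedSpeakerEmbedding(device=torch.device("cpu"))
chunks = torch.randn(4, 1, emb.sample_rate * 2)   # (batch, channel, samples): four 2-second chunks
vectors = emb(chunks)                             # numpy array, shape (4, emb.dimension)
print(vectors.shape, emb.metric)                  # metric is "cosine"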
main/library/speaker_diarization/encoder.py
DELETED
|
@@ -1,250 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import ast
|
| 4 |
-
import torch
|
| 5 |
-
import itertools
|
| 6 |
-
import collections
|
| 7 |
-
|
| 8 |
-
sys.path.append(os.getcwd())
|
| 9 |
-
|
| 10 |
-
from main.library.speaker_diarization.speechbrain import if_main_process, ddp_barrier
|
| 11 |
-
from main.library.speaker_diarization.features import register_checkpoint_hooks, mark_as_saver, mark_as_loader
|
| 12 |
-
|
| 13 |
-
DEFAULT_UNK = "<unk>"
|
| 14 |
-
DEFAULT_BOS = "<bos>"
|
| 15 |
-
DEFAULT_EOS = "<eos>"
|
| 16 |
-
DEFAULT_BLANK = "<blank>"
|
| 17 |
-
|
| 18 |
-
@register_checkpoint_hooks
|
| 19 |
-
class CategoricalEncoder:
|
| 20 |
-
VALUE_SEPARATOR = " => "
|
| 21 |
-
EXTRAS_SEPARATOR = "================\n"
|
| 22 |
-
|
| 23 |
-
def __init__(self, starting_index=0, **special_labels):
|
| 24 |
-
self.lab2ind = {}
|
| 25 |
-
self.ind2lab = {}
|
| 26 |
-
self.starting_index = starting_index
|
| 27 |
-
self.handle_special_labels(special_labels)
|
| 28 |
-
|
| 29 |
-
def handle_special_labels(self, special_labels):
|
| 30 |
-
if "unk_label" in special_labels: self.add_unk(special_labels["unk_label"])
|
| 31 |
-
|
| 32 |
-
def __len__(self):
|
| 33 |
-
return len(self.lab2ind)
|
| 34 |
-
|
| 35 |
-
@classmethod
|
| 36 |
-
def from_saved(cls, path):
|
| 37 |
-
obj = cls()
|
| 38 |
-
obj.load(path)
|
| 39 |
-
return obj
|
| 40 |
-
|
| 41 |
-
def update_from_iterable(self, iterable, sequence_input=False):
|
| 42 |
-
label_iterator = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
|
| 43 |
-
for label in label_iterator:
|
| 44 |
-
self.ensure_label(label)
|
| 45 |
-
|
| 46 |
-
def update_from_didataset(self, didataset, output_key, sequence_input=False):
|
| 47 |
-
with didataset.output_keys_as([output_key]):
|
| 48 |
-
self.update_from_iterable((data_point[output_key] for data_point in didataset), sequence_input=sequence_input)
|
| 49 |
-
|
| 50 |
-
def limited_labelset_from_iterable(self, iterable, sequence_input=False, n_most_common=None, min_count=1):
|
| 51 |
-
label_iterator = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
|
| 52 |
-
counts = collections.Counter(label_iterator)
|
| 53 |
-
|
| 54 |
-
for label, count in counts.most_common(n_most_common):
|
| 55 |
-
if count < min_count: break
|
| 56 |
-
self.add_label(label)
|
| 57 |
-
|
| 58 |
-
return counts
|
| 59 |
-
|
| 60 |
-
def load_or_create(self, path, from_iterables=[], from_didatasets=[], sequence_input=False, output_key=None, special_labels={}):
|
| 61 |
-
try:
|
| 62 |
-
if if_main_process():
|
| 63 |
-
if not self.load_if_possible(path):
|
| 64 |
-
for iterable in from_iterables:
|
| 65 |
-
self.update_from_iterable(iterable, sequence_input)
|
| 66 |
-
|
| 67 |
-
for didataset in from_didatasets:
|
| 68 |
-
if output_key is None: raise ValueError
|
| 69 |
-
self.update_from_didataset(didataset, output_key, sequence_input)
|
| 70 |
-
|
| 71 |
-
self.handle_special_labels(special_labels)
|
| 72 |
-
self.save(path)
|
| 73 |
-
finally:
|
| 74 |
-
ddp_barrier()
|
| 75 |
-
self.load(path)
|
| 76 |
-
|
| 77 |
-
def add_label(self, label):
|
| 78 |
-
if label in self.lab2ind: raise KeyError
|
| 79 |
-
index = self._next_index()
|
| 80 |
-
|
| 81 |
-
self.lab2ind[label] = index
|
| 82 |
-
self.ind2lab[index] = label
|
| 83 |
-
|
| 84 |
-
return index
|
| 85 |
-
|
| 86 |
-
def ensure_label(self, label):
|
| 87 |
-
if label in self.lab2ind: return self.lab2ind[label]
|
| 88 |
-
else: return self.add_label(label)
|
| 89 |
-
|
| 90 |
-
def insert_label(self, label, index):
|
| 91 |
-
if label in self.lab2ind: raise KeyError
|
| 92 |
-
else: self.enforce_label(label, index)
|
| 93 |
-
|
| 94 |
-
def enforce_label(self, label, index):
|
| 95 |
-
index = int(index)
|
| 96 |
-
|
| 97 |
-
if label in self.lab2ind:
|
| 98 |
-
if index == self.lab2ind[label]: return
|
| 99 |
-
else: del self.ind2lab[self.lab2ind[label]]
|
| 100 |
-
|
| 101 |
-
if index in self.ind2lab:
|
| 102 |
-
saved_label = self.ind2lab[index]
|
| 103 |
-
moving_other = True
|
| 104 |
-
else: moving_other = False
|
| 105 |
-
|
| 106 |
-
self.lab2ind[label] = index
|
| 107 |
-
self.ind2lab[index] = label
|
| 108 |
-
|
| 109 |
-
if moving_other:
|
| 110 |
-
new_index = self._next_index()
|
| 111 |
-
self.lab2ind[saved_label] = new_index
|
| 112 |
-
self.ind2lab[new_index] = saved_label
|
| 113 |
-
|
| 114 |
-
def add_unk(self, unk_label=DEFAULT_UNK):
|
| 115 |
-
self.unk_label = unk_label
|
| 116 |
-
return self.add_label(unk_label)
|
| 117 |
-
|
| 118 |
-
def _next_index(self):
|
| 119 |
-
index = self.starting_index
|
| 120 |
-
while index in self.ind2lab:
|
| 121 |
-
index += 1
|
| 122 |
-
|
| 123 |
-
return index
|
| 124 |
-
|
| 125 |
-
def is_continuous(self):
|
| 126 |
-
indices = sorted(self.ind2lab.keys())
|
| 127 |
-
return self.starting_index in indices and all(j - i == 1 for i, j in zip(indices[:-1], indices[1:]))
|
| 128 |
-
|
| 129 |
-
def encode_label(self, label, allow_unk=True):
|
| 130 |
-
self._assert_len()
|
| 131 |
-
|
| 132 |
-
try:
|
| 133 |
-
return self.lab2ind[label]
|
| 134 |
-
except KeyError:
|
| 135 |
-
if hasattr(self, "unk_label") and allow_unk: return self.lab2ind[self.unk_label]
|
| 136 |
-
elif hasattr(self, "unk_label") and not allow_unk: raise KeyError
|
| 137 |
-
elif not hasattr(self, "unk_label") and allow_unk: raise KeyError
|
| 138 |
-
else: raise KeyError
|
| 139 |
-
|
| 140 |
-
def encode_label_torch(self, label, allow_unk=True):
|
| 141 |
-
return torch.LongTensor([self.encode_label(label, allow_unk)])
|
| 142 |
-
|
| 143 |
-
def encode_sequence(self, sequence, allow_unk=True):
|
| 144 |
-
self._assert_len()
|
| 145 |
-
return [self.encode_label(label, allow_unk) for label in sequence]
|
| 146 |
-
|
| 147 |
-
def encode_sequence_torch(self, sequence, allow_unk=True):
|
| 148 |
-
return torch.LongTensor([self.encode_label(label, allow_unk) for label in sequence])
|
| 149 |
-
|
| 150 |
-
def decode_torch(self, x):
|
| 151 |
-
self._assert_len()
|
| 152 |
-
decoded = []
|
| 153 |
-
|
| 154 |
-
if x.ndim == 1:
|
| 155 |
-
for element in x:
|
| 156 |
-
decoded.append(self.ind2lab[int(element)])
|
| 157 |
-
else:
|
| 158 |
-
for subtensor in x:
|
| 159 |
-
decoded.append(self.decode_torch(subtensor))
|
| 160 |
-
|
| 161 |
-
return decoded
|
| 162 |
-
|
| 163 |
-
def decode_ndim(self, x):
|
| 164 |
-
self._assert_len()
|
| 165 |
-
try:
|
| 166 |
-
decoded = []
|
| 167 |
-
for subtensor in x:
|
| 168 |
-
decoded.append(self.decode_ndim(subtensor))
|
| 169 |
-
|
| 170 |
-
return decoded
|
| 171 |
-
except TypeError:
|
| 172 |
-
return self.ind2lab[int(x)]
|
| 173 |
-
|
| 174 |
-
@mark_as_saver
|
| 175 |
-
def save(self, path):
|
| 176 |
-
self._save_literal(path, self.lab2ind, self._get_extras())
|
| 177 |
-
|
| 178 |
-
def load(self, path):
|
| 179 |
-
lab2ind, ind2lab, extras = self._load_literal(path)
|
| 180 |
-
self.lab2ind = lab2ind
|
| 181 |
-
self.ind2lab = ind2lab
|
| 182 |
-
self._set_extras(extras)
|
| 183 |
-
|
| 184 |
-
@mark_as_loader
|
| 185 |
-
def load_if_possible(self, path, end_of_epoch=False):
|
| 186 |
-
del end_of_epoch
|
| 187 |
-
|
| 188 |
-
try:
|
| 189 |
-
self.load(path)
|
| 190 |
-
except FileNotFoundError:
|
| 191 |
-
return False
|
| 192 |
-
except (ValueError, SyntaxError):
|
| 193 |
-
return False
|
| 194 |
-
|
| 195 |
-
return True
|
| 196 |
-
|
| 197 |
-
def expect_len(self, expected_len):
|
| 198 |
-
self.expected_len = expected_len
|
| 199 |
-
|
| 200 |
-
def ignore_len(self):
|
| 201 |
-
self.expected_len = None
|
| 202 |
-
|
| 203 |
-
def _assert_len(self):
|
| 204 |
-
if hasattr(self, "expected_len"):
|
| 205 |
-
if self.expected_len is None: return
|
| 206 |
-
if len(self) != self.expected_len: raise RuntimeError
|
| 207 |
-
else:
|
| 208 |
-
self.ignore_len()
|
| 209 |
-
return
|
| 210 |
-
|
| 211 |
-
def _get_extras(self):
|
| 212 |
-
extras = {"starting_index": self.starting_index}
|
| 213 |
-
if hasattr(self, "unk_label"): extras["unk_label"] = self.unk_label
|
| 214 |
-
|
| 215 |
-
return extras
|
| 216 |
-
|
| 217 |
-
def _set_extras(self, extras):
|
| 218 |
-
if "unk_label" in extras: self.unk_label = extras["unk_label"]
|
| 219 |
-
self.starting_index = extras["starting_index"]
|
| 220 |
-
|
| 221 |
-
@staticmethod
|
| 222 |
-
def _save_literal(path, lab2ind, extras):
|
| 223 |
-
with open(path, "w", encoding="utf-8") as f:
|
| 224 |
-
for label, ind in lab2ind.items():
|
| 225 |
-
f.write(repr(label) + CategoricalEncoder.VALUE_SEPARATOR + str(ind) + "\n")
|
| 226 |
-
|
| 227 |
-
f.write(CategoricalEncoder.EXTRAS_SEPARATOR)
|
| 228 |
-
|
| 229 |
-
for key, value in extras.items():
|
| 230 |
-
f.write(repr(key) + CategoricalEncoder.VALUE_SEPARATOR + repr(value) + "\n")
|
| 231 |
-
|
| 232 |
-
f.flush()
|
| 233 |
-
|
| 234 |
-
@staticmethod
|
| 235 |
-
def _load_literal(path):
|
| 236 |
-
lab2ind, ind2lab, extras = {}, {}, {}
|
| 237 |
-
|
| 238 |
-
with open(path, encoding="utf-8") as f:
|
| 239 |
-
for line in f:
|
| 240 |
-
if line == CategoricalEncoder.EXTRAS_SEPARATOR: break
|
| 241 |
-
literal, ind = line.strip().split(CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1)
|
| 242 |
-
label = ast.literal_eval(literal)
|
| 243 |
-
lab2ind[label] = int(ind)
|
| 244 |
-
ind2lab[ind] = label
|
| 245 |
-
|
| 246 |
-
for line in f:
|
| 247 |
-
literal_key, literal_value = line.strip().split(CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1)
|
| 248 |
-
extras[ast.literal_eval(literal_key)] = ast.literal_eval(literal_value)
|
| 249 |
-
|
| 250 |
-
return lab2ind, ind2lab, extras
|
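A round-trip sketch for the CategoricalEncoder removed above (the speaker ids below are made up; the import assumes a revision that still contains this module and its speechbrain helpers):

from main.library.speaker_diarization.encoder import CategoricalEncoder

enc = CategoricalEncoder()
enc.add_unk()                                        # reserve "<unk>" for labels never seen
enc.update_from_iterable(["spk1", "spk2", "spk1", "spk3"])
ids = enc.encode_sequence(["spk2", "spk3", "spk9"])  # unseen "spk9" falls back to <unk>
print(ids)
print(enc.decode_ndim(ids))                          # back to label strings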
main/library/speaker_diarization/features.py
DELETED
|
@@ -1,520 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import math
|
| 4 |
-
import torch
|
| 5 |
-
import inspect
|
| 6 |
-
import functools
|
| 7 |
-
|
| 8 |
-
sys.path.append(os.getcwd())
|
| 9 |
-
|
| 10 |
-
from main.library.speaker_diarization.speechbrain import MAIN_PROC_ONLY, is_distributed_initialized, main_process_only
|
| 11 |
-
|
| 12 |
-
KEYS_MAPPING = {".mutihead_attn": ".multihead_attn", ".convs_intermedite": ".convs_intermediate"}
|
| 13 |
-
|
| 14 |
-
def map_old_state_dict_weights(state_dict, mapping):
|
| 15 |
-
for replacement_old, replacement_new in mapping.items():
|
| 16 |
-
for old_key in list(state_dict.keys()):
|
| 17 |
-
if replacement_old in old_key: state_dict[old_key.replace(replacement_old, replacement_new)] = state_dict.pop(old_key)
|
| 18 |
-
|
| 19 |
-
return state_dict
|
| 20 |
-
|
| 21 |
-
def hook_on_loading_state_dict_checkpoint(state_dict):
|
| 22 |
-
return map_old_state_dict_weights(state_dict, KEYS_MAPPING)
|
| 23 |
-
|
| 24 |
-
def torch_patched_state_dict_load(path, device="cpu"):
|
| 25 |
-
return hook_on_loading_state_dict_checkpoint(torch.load(path, map_location=device))
|
| 26 |
-
|
| 27 |
-
@main_process_only
|
| 28 |
-
def torch_save(obj, path):
|
| 29 |
-
state_dict = obj.state_dict()
|
| 30 |
-
torch.save(state_dict, path)
|
| 31 |
-
|
| 32 |
-
def torch_recovery(obj, path, end_of_epoch):
|
| 33 |
-
del end_of_epoch
|
| 34 |
-
|
| 35 |
-
state_dict = torch_patched_state_dict_load(path, "cpu")
|
| 36 |
-
try:
|
| 37 |
-
obj.load_state_dict(state_dict, strict=True)
|
| 38 |
-
except TypeError:
|
| 39 |
-
obj.load_state_dict(state_dict)
|
| 40 |
-
|
| 41 |
-
def torch_parameter_transfer(obj, path):
|
| 42 |
-
incompatible_keys = obj.load_state_dict(torch_patched_state_dict_load(path, "cpu"), strict=False)
|
| 43 |
-
|
| 44 |
-
for missing_key in incompatible_keys.missing_keys:
|
| 45 |
-
pass
|
| 46 |
-
for unexpected_key in incompatible_keys.unexpected_keys:
|
| 47 |
-
pass
|
| 48 |
-
|
| 49 |
-
WEAKREF_MARKER = "WEAKREF"
|
| 50 |
-
|
| 51 |
-
def _cycliclrsaver(obj, path):
|
| 52 |
-
state_dict = obj.state_dict()
|
| 53 |
-
if state_dict.get("_scale_fn_ref") is not None: state_dict["_scale_fn_ref"] = WEAKREF_MARKER
|
| 54 |
-
|
| 55 |
-
torch.save(state_dict, path)
|
| 56 |
-
|
| 57 |
-
def _cycliclrloader(obj, path, end_of_epoch):
|
| 58 |
-
del end_of_epoch
|
| 59 |
-
|
| 60 |
-
try:
|
| 61 |
-
obj.load_state_dict(torch.load(path, map_location="cpu"), strict=True)
|
| 62 |
-
except TypeError:
|
| 63 |
-
obj.load_state_dict(torch.load(path, map_location="cpu"))
|
| 64 |
-
|
| 65 |
-
DEFAULT_LOAD_HOOKS = {torch.nn.Module: torch_recovery, torch.optim.Optimizer: torch_recovery, torch.optim.lr_scheduler.ReduceLROnPlateau: torch_recovery, torch.cuda.amp.grad_scaler.GradScaler: torch_recovery}
|
| 66 |
-
DEFAULT_SAVE_HOOKS = { torch.nn.Module: torch_save, torch.optim.Optimizer: torch_save, torch.optim.lr_scheduler.ReduceLROnPlateau: torch_save, torch.cuda.amp.grad_scaler.GradScaler: torch_save}
|
| 67 |
-
DEFAULT_LOAD_HOOKS[torch.optim.lr_scheduler.LRScheduler] = torch_recovery
|
| 68 |
-
DEFAULT_SAVE_HOOKS[torch.optim.lr_scheduler.LRScheduler] = torch_save
|
| 69 |
-
DEFAULT_TRANSFER_HOOKS = {torch.nn.Module: torch_parameter_transfer}
|
| 70 |
-
DEFAULT_SAVE_HOOKS[torch.optim.lr_scheduler.CyclicLR] = _cycliclrsaver
|
| 71 |
-
DEFAULT_LOAD_HOOKS[torch.optim.lr_scheduler.CyclicLR] = _cycliclrloader
|
| 72 |
-
|
| 73 |
-
def register_checkpoint_hooks(cls, save_on_main_only=True):
|
| 74 |
-
global DEFAULT_LOAD_HOOKS, DEFAULT_SAVE_HOOKS, DEFAULT_TRANSFER_HOOKS
|
| 75 |
-
|
| 76 |
-
for name, method in cls.__dict__.items():
|
| 77 |
-
if hasattr(method, "_speechbrain_saver"): DEFAULT_SAVE_HOOKS[cls] = main_process_only(method) if save_on_main_only else method
|
| 78 |
-
if hasattr(method, "_speechbrain_loader"): DEFAULT_LOAD_HOOKS[cls] = method
|
| 79 |
-
if hasattr(method, "_speechbrain_transfer"): DEFAULT_TRANSFER_HOOKS[cls] = method
|
| 80 |
-
|
| 81 |
-
return cls
|
| 82 |
-
|
| 83 |
-
def mark_as_saver(method):
|
| 84 |
-
sig = inspect.signature(method)
|
| 85 |
-
|
| 86 |
-
try:
|
| 87 |
-
sig.bind(object(), os.path.join("testpath"))
|
| 88 |
-
except TypeError:
|
| 89 |
-
raise TypeError
|
| 90 |
-
|
| 91 |
-
method._speechbrain_saver = True
|
| 92 |
-
return method
|
| 93 |
-
|
| 94 |
-
def mark_as_transfer(method):
|
| 95 |
-
sig = inspect.signature(method)
|
| 96 |
-
|
| 97 |
-
try:
|
| 98 |
-
sig.bind(object(), os.path.join("testpath"))
|
| 99 |
-
except TypeError:
|
| 100 |
-
raise TypeError
|
| 101 |
-
|
| 102 |
-
method._speechbrain_transfer = True
|
| 103 |
-
return method
|
| 104 |
-
|
| 105 |
-
def mark_as_loader(method):
|
| 106 |
-
sig = inspect.signature(method)
|
| 107 |
-
|
| 108 |
-
try:
|
| 109 |
-
sig.bind(object(), os.path.join("testpath"), True)
|
| 110 |
-
except TypeError:
|
| 111 |
-
raise TypeError
|
| 112 |
-
|
| 113 |
-
method._speechbrain_loader = True
|
| 114 |
-
return method
|
| 115 |
-
|
| 116 |
-
def ddp_all_reduce(communication_object, reduce_op):
|
| 117 |
-
if MAIN_PROC_ONLY >= 1 or not is_distributed_initialized(): return communication_object
|
| 118 |
-
torch.distributed.all_reduce(communication_object, op=reduce_op)
|
| 119 |
-
|
| 120 |
-
return communication_object
|
| 121 |
-
|
| 122 |
-
def fwd_default_precision(fwd = None, cast_inputs = torch.float32):
|
| 123 |
-
if fwd is None: return functools.partial(fwd_default_precision, cast_inputs=cast_inputs)
|
| 124 |
-
|
| 125 |
-
wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
|
| 126 |
-
|
| 127 |
-
@functools.wraps(fwd)
|
| 128 |
-
def wrapper(*args, force_allow_autocast = False, **kwargs):
|
| 129 |
-
return fwd(*args, **kwargs) if force_allow_autocast else wrapped_fwd(*args, **kwargs)
|
| 130 |
-
|
| 131 |
-
return wrapper
|
| 132 |
-
|
| 133 |
-
def spectral_magnitude(stft, power = 1, log = False, eps = 1e-14):
|
| 134 |
-
spectr = stft.pow(2).sum(-1)
|
| 135 |
-
|
| 136 |
-
if power < 1: spectr = spectr + eps
|
| 137 |
-
spectr = spectr.pow(power)
|
| 138 |
-
|
| 139 |
-
if log: return torch.log(spectr + eps)
|
| 140 |
-
return spectr
|
| 141 |
-
|
| 142 |
-
class Filterbank(torch.nn.Module):
|
| 143 |
-
def __init__(self, n_mels=40, log_mel=True, filter_shape="triangular", f_min=0, f_max=8000, n_fft=400, sample_rate=16000, power_spectrogram=2, amin=1e-10, ref_value=1.0, top_db=80.0, param_change_factor=1.0, param_rand_factor=0.0, freeze=True):
|
| 144 |
-
super().__init__()
|
| 145 |
-
self.n_mels = n_mels
|
| 146 |
-
self.log_mel = log_mel
|
| 147 |
-
self.filter_shape = filter_shape
|
| 148 |
-
self.f_min = f_min
|
| 149 |
-
self.f_max = f_max
|
| 150 |
-
self.n_fft = n_fft
|
| 151 |
-
self.sample_rate = sample_rate
|
| 152 |
-
self.power_spectrogram = power_spectrogram
|
| 153 |
-
self.amin = amin
|
| 154 |
-
self.ref_value = ref_value
|
| 155 |
-
self.top_db = top_db
|
| 156 |
-
self.freeze = freeze
|
| 157 |
-
self.n_stft = self.n_fft // 2 + 1
|
| 158 |
-
self.db_multiplier = math.log10(max(self.amin, self.ref_value))
|
| 159 |
-
self.device_inp = torch.device("cpu")
|
| 160 |
-
self.param_change_factor = param_change_factor
|
| 161 |
-
self.param_rand_factor = param_rand_factor
|
| 162 |
-
self.multiplier = 10 if self.power_spectrogram == 2 else 20
|
| 163 |
-
|
| 164 |
-
hz = self._to_hz(torch.linspace(self._to_mel(self.f_min), self._to_mel(self.f_max), self.n_mels + 2))
|
| 165 |
-
|
| 166 |
-
band = hz[1:] - hz[:-1]
|
| 167 |
-
self.band = band[:-1]
|
| 168 |
-
self.f_central = hz[1:-1]
|
| 169 |
-
|
| 170 |
-
if not self.freeze:
|
| 171 |
-
self.f_central = torch.nn.Parameter(self.f_central / (self.sample_rate * self.param_change_factor))
|
| 172 |
-
self.band = torch.nn.Parameter(self.band / (self.sample_rate * self.param_change_factor))
|
| 173 |
-
|
| 174 |
-
self.all_freqs_mat = torch.linspace(0, self.sample_rate // 2, self.n_stft).repeat(self.f_central.shape[0], 1)
|
| 175 |
-
|
| 176 |
-
def forward(self, spectrogram):
|
| 177 |
-
f_central_mat = self.f_central.repeat(self.all_freqs_mat.shape[1], 1).transpose(0, 1)
|
| 178 |
-
band_mat = self.band.repeat(self.all_freqs_mat.shape[1], 1).transpose(0, 1)
|
| 179 |
-
|
| 180 |
-
if not self.freeze:
|
| 181 |
-
f_central_mat = f_central_mat * (self.sample_rate * self.param_change_factor * self.param_change_factor)
|
| 182 |
-
band_mat = band_mat * (self.sample_rate * self.param_change_factor * self.param_change_factor)
|
| 183 |
-
elif self.param_rand_factor != 0 and self.training:
|
| 184 |
-
rand_change = (1.0 + torch.rand(2) * 2 * self.param_rand_factor - self.param_rand_factor)
|
| 185 |
-
f_central_mat = f_central_mat * rand_change[0]
|
| 186 |
-
band_mat = band_mat * rand_change[1]
|
| 187 |
-
|
| 188 |
-
fbank_matrix = self._create_fbank_matrix(f_central_mat, band_mat).to(spectrogram.device)
|
| 189 |
-
sp_shape = spectrogram.shape
|
| 190 |
-
if len(sp_shape) == 4: spectrogram = spectrogram.permute(0, 3, 1, 2).reshape(sp_shape[0] * sp_shape[3], sp_shape[1], sp_shape[2])
|
| 191 |
-
|
| 192 |
-
fbanks = torch.matmul(spectrogram, fbank_matrix)
|
| 193 |
-
if self.log_mel: fbanks = self._amplitude_to_DB(fbanks)
|
| 194 |
-
|
| 195 |
-
if len(sp_shape) == 4:
|
| 196 |
-
fb_shape = fbanks.shape
|
| 197 |
-
fbanks = fbanks.reshape(sp_shape[0], sp_shape[3], fb_shape[1], fb_shape[2]).permute(0, 2, 3, 1)
|
| 198 |
-
|
| 199 |
-
return fbanks
|
| 200 |
-
|
| 201 |
-
@staticmethod
|
| 202 |
-
def _to_mel(hz):
|
| 203 |
-
return 2595 * math.log10(1 + hz / 700)
|
| 204 |
-
|
| 205 |
-
@staticmethod
|
| 206 |
-
def _to_hz(mel):
|
| 207 |
-
return 700 * (10 ** (mel / 2595) - 1)
|
| 208 |
-
|
| 209 |
-
def _triangular_filters(self, all_freqs, f_central, band):
|
| 210 |
-
slope = (all_freqs - f_central) / band
|
| 211 |
-
return torch.max(torch.zeros(1, device=self.device_inp), torch.min(slope + 1.0, -slope + 1.0)).transpose(0, 1)
|
| 212 |
-
|
| 213 |
-
def _rectangular_filters(self, all_freqs, f_central, band):
|
| 214 |
-
left_side = right_size = all_freqs.ge(f_central - band)
|
| 215 |
-
right_size = all_freqs.le(f_central + band)
|
| 216 |
-
|
| 217 |
-
return (left_side * right_size).float().transpose(0, 1)
|
| 218 |
-
|
| 219 |
-
def _gaussian_filters(self, all_freqs, f_central, band, smooth_factor=torch.tensor(2)):
|
| 220 |
-
return torch.exp(-0.5 * ((all_freqs - f_central) / (band / smooth_factor)) ** 2).transpose(0, 1)
|
| 221 |
-
|
| 222 |
-
def _create_fbank_matrix(self, f_central_mat, band_mat):
|
| 223 |
-
if self.filter_shape == "triangular": fbank_matrix = self._triangular_filters(self.all_freqs_mat, f_central_mat, band_mat)
|
| 224 |
-
elif self.filter_shape == "rectangular": fbank_matrix = self._rectangular_filters(self.all_freqs_mat, f_central_mat, band_mat)
|
| 225 |
-
else: fbank_matrix = self._gaussian_filters(self.all_freqs_mat, f_central_mat, band_mat)
|
| 226 |
-
|
| 227 |
-
return fbank_matrix
|
| 228 |
-
|
| 229 |
-
def _amplitude_to_DB(self, x):
|
| 230 |
-
x_db = self.multiplier * torch.log10(torch.clamp(x, min=self.amin))
|
| 231 |
-
x_db -= self.multiplier * self.db_multiplier
|
| 232 |
-
|
| 233 |
-
return torch.max(x_db, (x_db.amax(dim=(-2, -1)) - self.top_db).view(x_db.shape[0], 1, 1))
|
| 234 |
-
|
| 235 |
-
class ContextWindow(torch.nn.Module):
|
| 236 |
-
def __init__(self, left_frames=0, right_frames=0):
|
| 237 |
-
super().__init__()
|
| 238 |
-
self.left_frames = left_frames
|
| 239 |
-
self.right_frames = right_frames
|
| 240 |
-
self.context_len = self.left_frames + self.right_frames + 1
|
| 241 |
-
self.kernel_len = 2 * max(self.left_frames, self.right_frames) + 1
|
| 242 |
-
self.kernel = torch.eye(self.context_len, self.kernel_len)
|
| 243 |
-
|
| 244 |
-
if self.right_frames > self.left_frames: self.kernel = torch.roll(self.kernel, self.right_frames - self.left_frames, 1)
|
| 245 |
-
self.first_call = True
|
| 246 |
-
|
| 247 |
-
def forward(self, x):
|
| 248 |
-
x = x.transpose(1, 2)
|
| 249 |
-
if self.first_call:
|
| 250 |
-
self.first_call = False
|
| 251 |
-
self.kernel = (self.kernel.repeat(x.shape[1], 1, 1).view(x.shape[1] * self.context_len, self.kernel_len).unsqueeze(1))
|
| 252 |
-
|
| 253 |
-
or_shape = x.shape
|
| 254 |
-
if len(or_shape) == 4: x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3])
|
| 255 |
-
|
| 256 |
-
cw_x = torch.nn.functional.conv1d(x, self.kernel.to(x.device), groups=x.shape[1], padding=max(self.left_frames, self.right_frames))
|
| 257 |
-
if len(or_shape) == 4: cw_x = cw_x.reshape(or_shape[0], cw_x.shape[1], or_shape[2], cw_x.shape[-1])
|
| 258 |
-
|
| 259 |
-
return cw_x.transpose(1, 2)
|
| 260 |
-
|
| 261 |
-
class FilterProperties:
|
| 262 |
-
def __init__(self, window_size = 0, stride = 1, dilation = 1, causal = False):
|
| 263 |
-
self.window_size = window_size
|
| 264 |
-
self.stride = stride
|
| 265 |
-
self.dilation = dilation
|
| 266 |
-
self.causal = causal
|
| 267 |
-
|
| 268 |
-
def __post_init__(self):
|
| 269 |
-
assert self.window_size > 0
|
| 270 |
-
assert self.stride > 0
|
| 271 |
-
assert (self.dilation > 0)
|
| 272 |
-
|
| 273 |
-
@staticmethod
|
| 274 |
-
def pointwise_filter():
|
| 275 |
-
return FilterProperties(window_size=1, stride=1)
|
| 276 |
-
|
| 277 |
-
def get_effective_size(self):
|
| 278 |
-
return 1 + ((self.window_size - 1) * self.dilation)
|
| 279 |
-
|
| 280 |
-
def get_convolution_padding(self):
|
| 281 |
-
if self.window_size % 2 == 0: raise ValueError
|
| 282 |
-
if self.causal: return self.get_effective_size() - 1
|
| 283 |
-
|
| 284 |
-
return (self.get_effective_size() - 1) // 2
|
| 285 |
-
|
| 286 |
-
def get_noncausal_equivalent(self):
|
| 287 |
-
if not self.causal: return self
|
| 288 |
-
return FilterProperties(window_size=(self.window_size - 1) * 2 + 1, stride=self.stride, dilation=self.dilation, causal=False)
|
| 289 |
-
|
| 290 |
-
def with_on_top(self, other, allow_approximate=True):
|
| 291 |
-
self_size = self.window_size
|
| 292 |
-
|
| 293 |
-
if other.window_size % 2 == 0:
|
| 294 |
-
if allow_approximate: other_size = other.window_size + 1
|
| 295 |
-
else: raise ValueError
|
| 296 |
-
else: other_size = other.window_size
|
| 297 |
-
|
| 298 |
-
if (self.causal or other.causal) and not (self.causal and other.causal):
|
| 299 |
-
if allow_approximate: return self.get_noncausal_equivalent().with_on_top(other.get_noncausal_equivalent())
|
| 300 |
-
else: raise ValueError
|
| 301 |
-
|
| 302 |
-
return FilterProperties(self_size + (self.stride * (other_size - 1)), self.stride * other.stride, self.dilation * other.dilation, self.causal)
|
| 303 |
-
|
| 304 |
-
class STFT(torch.nn.Module):
|
| 305 |
-
def __init__(self, sample_rate, win_length=25, hop_length=10, n_fft=400, window_fn=torch.hamming_window, normalized_stft=False, center=True, pad_mode="constant", onesided=True):
|
| 306 |
-
super().__init__()
|
| 307 |
-
self.sample_rate = sample_rate
|
| 308 |
-
self.win_length = win_length
|
| 309 |
-
self.hop_length = hop_length
|
| 310 |
-
self.n_fft = n_fft
|
| 311 |
-
self.normalized_stft = normalized_stft
|
| 312 |
-
self.center = center
|
| 313 |
-
self.pad_mode = pad_mode
|
| 314 |
-
self.onesided = onesided
|
| 315 |
-
self.win_length = int(round((self.sample_rate / 1000.0) * self.win_length))
|
| 316 |
-
self.hop_length = int(round((self.sample_rate / 1000.0) * self.hop_length))
|
| 317 |
-
self.window = window_fn(self.win_length)
|
| 318 |
-
|
| 319 |
-
def forward(self, x):
|
| 320 |
-
or_shape = x.shape
|
| 321 |
-
if len(or_shape) == 3: x = x.transpose(1, 2).reshape(or_shape[0] * or_shape[2], or_shape[1])
|
| 322 |
-
|
| 323 |
-
stft = torch.view_as_real(torch.stft(x, self.n_fft, self.hop_length, self.win_length, self.window.to(x.device), self.center, self.pad_mode, self.normalized_stft, self.onesided, return_complex=True))
|
| 324 |
-
stft = stft.reshape(or_shape[0], or_shape[2], stft.shape[1], stft.shape[2], stft.shape[3]).permute(0, 3, 2, 4, 1) if len(or_shape) == 3 else stft.transpose(2, 1)
|
| 325 |
-
|
| 326 |
-
return stft
|
| 327 |
-
|
| 328 |
-
def get_filter_properties(self):
|
| 329 |
-
if not self.center: raise ValueError
|
| 330 |
-
return FilterProperties(window_size=self.win_length, stride=self.hop_length)
|
| 331 |
-
|
| 332 |
-
class Deltas(torch.nn.Module):
|
| 333 |
-
def __init__(self, input_size, window_length=5):
|
| 334 |
-
super().__init__()
|
| 335 |
-
self.n = (window_length - 1) // 2
|
| 336 |
-
self.denom = self.n * (self.n + 1) * (2 * self.n + 1) / 3
|
| 337 |
-
self.register_buffer("kernel", torch.arange(-self.n, self.n + 1, dtype=torch.float32).repeat(input_size, 1, 1),)
|
| 338 |
-
|
| 339 |
-
def forward(self, x):
|
| 340 |
-
x = x.transpose(1, 2).transpose(2, -1)
|
| 341 |
-
or_shape = x.shape
|
| 342 |
-
|
| 343 |
-
if len(or_shape) == 4: x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3])
|
| 344 |
-
|
| 345 |
-
x = torch.nn.functional.pad(x, (self.n, self.n), mode="replicate")
|
| 346 |
-
delta_coeff = (torch.nn.functional.conv1d(x, self.kernel.to(x.device), groups=x.shape[1]) / self.denom)
|
| 347 |
-
|
| 348 |
-
if len(or_shape) == 4: delta_coeff = delta_coeff.reshape(or_shape[0], or_shape[1], or_shape[2], or_shape[3])
|
| 349 |
-
return delta_coeff.transpose(1, -1).transpose(2, -1)
|
| 350 |
-
|
| 351 |
-
class Fbank(torch.nn.Module):
|
| 352 |
-
def __init__(self, deltas=False, context=False, requires_grad=False, sample_rate=16000, f_min=0, f_max=None, n_fft=400, n_mels=40, filter_shape="triangular", param_change_factor=1.0, param_rand_factor=0.0, left_frames=5, right_frames=5, win_length=25, hop_length=10):
|
| 353 |
-
super().__init__()
|
| 354 |
-
self.deltas = deltas
|
| 355 |
-
self.context = context
|
| 356 |
-
self.requires_grad = requires_grad
|
| 357 |
-
if f_max is None: f_max = sample_rate / 2
|
| 358 |
-
self.compute_STFT = STFT(sample_rate=sample_rate,n_fft=n_fft,win_length=win_length,hop_length=hop_length)
|
| 359 |
-
self.compute_fbanks = Filterbank(sample_rate=sample_rate,n_fft=n_fft,n_mels=n_mels,f_min=f_min,f_max=f_max,freeze=not requires_grad,filter_shape=filter_shape,param_change_factor=param_change_factor,param_rand_factor=param_rand_factor)
|
| 360 |
-
self.compute_deltas = Deltas(input_size=n_mels)
|
| 361 |
-
self.context_window = ContextWindow(left_frames=left_frames, right_frames=right_frames)
|
| 362 |
-
|
| 363 |
-
@fwd_default_precision(cast_inputs=torch.float32)
|
| 364 |
-
def forward(self, wav):
|
| 365 |
-
fbanks = self.compute_fbanks(spectral_magnitude(self.compute_STFT(wav)))
|
| 366 |
-
if self.deltas:
|
| 367 |
-
delta1 = self.compute_deltas(fbanks)
|
| 368 |
-
fbanks = torch.cat([fbanks, delta1, self.compute_deltas(delta1)], dim=2)
|
| 369 |
-
|
| 370 |
-
if self.context: fbanks = self.context_window(fbanks)
|
| 371 |
-
return fbanks
|
| 372 |
-
|
| 373 |
-
def get_filter_properties(self):
|
| 374 |
-
return self.compute_STFT.get_filter_properties()
|
| 375 |
-
|
| 376 |
-
@register_checkpoint_hooks
|
| 377 |
-
class InputNormalization(torch.nn.Module):
|
| 378 |
-
def __init__(self, mean_norm=True, std_norm=True, norm_type="global", avg_factor=None, requires_grad=False, update_until_epoch=3):
|
| 379 |
-
super().__init__()
|
| 380 |
-
self.mean_norm = mean_norm
|
| 381 |
-
self.std_norm = std_norm
|
| 382 |
-
self.norm_type = norm_type
|
| 383 |
-
self.avg_factor = avg_factor
|
| 384 |
-
self.requires_grad = requires_grad
|
| 385 |
-
self.glob_mean = torch.tensor([0])
|
| 386 |
-
self.glob_std = torch.tensor([0])
|
| 387 |
-
self.spk_dict_mean = {}
|
| 388 |
-
self.spk_dict_std = {}
|
| 389 |
-
self.spk_dict_count = {}
|
| 390 |
-
self.weight = 1.0
|
| 391 |
-
self.count = 0
|
| 392 |
-
self.eps = 1e-10
|
| 393 |
-
self.update_until_epoch = update_until_epoch
|
| 394 |
-
|
| 395 |
-
def forward(self, x, lengths, spk_ids = torch.tensor([]), epoch=0):
|
| 396 |
-
N_batches = x.shape[0]
|
| 397 |
-
current_means, current_stds = [], []
|
| 398 |
-
|
| 399 |
-
if self.norm_type == "sentence" or self.norm_type == "speaker": out = torch.empty_like(x)
|
| 400 |
-
|
| 401 |
-
for snt_id in range(N_batches):
|
| 402 |
-
actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()
|
| 403 |
-
current_mean, current_std = self._compute_current_stats(x[snt_id, 0:actual_size, ...])
|
| 404 |
-
|
| 405 |
-
current_means.append(current_mean)
|
| 406 |
-
current_stds.append(current_std)
|
| 407 |
-
|
| 408 |
-
if self.norm_type == "sentence": out[snt_id] = (x[snt_id] - current_mean.data) / current_std.data
|
| 409 |
-
|
| 410 |
-
if self.norm_type == "speaker":
|
| 411 |
-
spk_id = int(spk_ids[snt_id][0])
|
| 412 |
-
|
| 413 |
-
if self.training:
|
| 414 |
-
if spk_id not in self.spk_dict_mean:
|
| 415 |
-
self.spk_dict_mean[spk_id] = current_mean
|
| 416 |
-
self.spk_dict_std[spk_id] = current_std
|
| 417 |
-
self.spk_dict_count[spk_id] = 1
|
| 418 |
-
else:
|
| 419 |
-
self.spk_dict_count[spk_id] = (self.spk_dict_count[spk_id] + 1)
|
| 420 |
-
self.weight = (1 / self.spk_dict_count[spk_id]) if self.avg_factor is None else self.avg_factor
|
| 421 |
-
|
| 422 |
-
self.spk_dict_mean[spk_id] = (1 - self.weight) * self.spk_dict_mean[spk_id].to(current_mean) + self.weight * current_mean
|
| 423 |
-
self.spk_dict_std[spk_id] = (1 - self.weight) * self.spk_dict_std[spk_id].to(current_std) + self.weight * current_std
|
| 424 |
-
|
| 425 |
-
self.spk_dict_mean[spk_id].detach()
|
| 426 |
-
self.spk_dict_std[spk_id].detach()
|
| 427 |
-
|
| 428 |
-
speaker_mean = self.spk_dict_mean[spk_id].data
|
| 429 |
-
speaker_std = self.spk_dict_std[spk_id].data
|
| 430 |
-
else:
|
| 431 |
-
if spk_id in self.spk_dict_mean:
|
| 432 |
-
speaker_mean = self.spk_dict_mean[spk_id].data
|
| 433 |
-
speaker_std = self.spk_dict_std[spk_id].data
|
| 434 |
-
else:
|
| 435 |
-
speaker_mean = current_mean.data
|
| 436 |
-
speaker_std = current_std.data
|
| 437 |
-
|
| 438 |
-
out[snt_id] = (x[snt_id] - speaker_mean) / speaker_std
|
| 439 |
-
|
| 440 |
-
if self.norm_type == "batch" or self.norm_type == "global":
|
| 441 |
-
current_mean = ddp_all_reduce(torch.mean(torch.stack(current_means), dim=0), torch.distributed.ReduceOp.AVG)
|
| 442 |
-
current_std = ddp_all_reduce(torch.mean(torch.stack(current_stds), dim=0), torch.distributed.ReduceOp.AVG)
|
| 443 |
-
|
| 444 |
-
if self.norm_type == "batch": out = (x - current_mean.data) / (current_std.data)
|
| 445 |
-
|
| 446 |
-
if self.norm_type == "global":
|
| 447 |
-
if self.training:
|
| 448 |
-
if self.count == 0:
|
| 449 |
-
self.glob_mean = current_mean
|
| 450 |
-
self.glob_std = current_std
|
| 451 |
-
elif epoch is None or epoch < self.update_until_epoch:
|
| 452 |
-
self.weight = (1 / (self.count + 1)) if self.avg_factor is None else self.avg_factor
|
| 453 |
-
self.glob_mean = (1 - self.weight) * self.glob_mean.to(current_mean) + self.weight * current_mean
|
| 454 |
-
self.glob_std = (1 - self.weight) * self.glob_std.to(current_std) + self.weight * current_std
|
| 455 |
-
|
| 456 |
-
self.glob_mean.detach()
|
| 457 |
-
self.glob_std.detach()
|
| 458 |
-
self.count = self.count + 1
|
| 459 |
-
|
| 460 |
-
out = (x - self.glob_mean.data.to(x)) / (self.glob_std.data.to(x))
|
| 461 |
-
|
| 462 |
-
return out
|
| 463 |
-
|
| 464 |
-
def _compute_current_stats(self, x):
|
| 465 |
-
current_std = torch.std(x, dim=0).detach().data if self.std_norm else torch.tensor([1.0], device=x.device)
|
| 466 |
-
return torch.mean(x, dim=0).detach().data if self.mean_norm else torch.tensor([0.0], device=x.device), torch.max(current_std, self.eps * torch.ones_like(current_std))
|
| 467 |
-
|
| 468 |
-
def _statistics_dict(self):
|
| 469 |
-
state = {}
|
| 470 |
-
state["count"] = self.count
|
| 471 |
-
state["glob_mean"] = self.glob_mean
|
| 472 |
-
state["glob_std"] = self.glob_std
|
| 473 |
-
state["spk_dict_mean"] = self.spk_dict_mean
|
| 474 |
-
state["spk_dict_std"] = self.spk_dict_std
|
| 475 |
-
state["spk_dict_count"] = self.spk_dict_count
|
| 476 |
-
|
| 477 |
-
return state
|
| 478 |
-
|
| 479 |
-
def _load_statistics_dict(self, state):
|
| 480 |
-
self.count = state["count"]
|
| 481 |
-
|
| 482 |
-
if isinstance(state["glob_mean"], int):
|
| 483 |
-
self.glob_mean = state["glob_mean"]
|
| 484 |
-
self.glob_std = state["glob_std"]
|
| 485 |
-
else:
|
| 486 |
-
self.glob_mean = state["glob_mean"]
|
| 487 |
-
self.glob_std = state["glob_std"]
|
| 488 |
-
|
| 489 |
-
self.spk_dict_mean = {}
|
| 490 |
-
for spk in state["spk_dict_mean"]:
|
| 491 |
-
self.spk_dict_mean[spk] = state["spk_dict_mean"][spk]
|
| 492 |
-
|
| 493 |
-
self.spk_dict_std = {}
|
| 494 |
-
for spk in state["spk_dict_std"]:
|
| 495 |
-
self.spk_dict_std[spk] = state["spk_dict_std"][spk]
|
| 496 |
-
|
| 497 |
-
self.spk_dict_count = state["spk_dict_count"]
|
| 498 |
-
return state
|
| 499 |
-
|
| 500 |
-
def to(self, device):
|
| 501 |
-
self = super(InputNormalization, self).to(device)
|
| 502 |
-
self.glob_mean = self.glob_mean.to(device)
|
| 503 |
-
self.glob_std = self.glob_std.to(device)
|
| 504 |
-
|
| 505 |
-
for spk in self.spk_dict_mean:
|
| 506 |
-
self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
|
| 507 |
-
self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
|
| 508 |
-
|
| 509 |
-
return self
|
| 510 |
-
|
| 511 |
-
@mark_as_saver
|
| 512 |
-
def _save(self, path):
|
| 513 |
-
torch.save(self._statistics_dict(), path)
|
| 514 |
-
|
| 515 |
-
@mark_as_transfer
|
| 516 |
-
@mark_as_loader
|
| 517 |
-
def _load(self, path, end_of_epoch=False):
|
| 518 |
-
del end_of_epoch
|
| 519 |
-
stats = torch.load(path, map_location="cpu")
|
| 520 |
-
self._load_statistics_dict(stats)
|
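The feature front-end deleted above (STFT, Filterbank, optional Deltas/ContextWindow, and InputNormalization) is only exercised indirectly elsewhere in the repository. Below is a minimal usage sketch, not part of the original file: it assumes the pre-deletion repository layout is on the import path and a recent PyTorch build with torch.distributed present; the batch size, audio length, and parameter values are illustrative only.

# Sketch only: chain the removed Fbank front-end with InputNormalization on dummy audio.
import torch

from main.library.speaker_diarization.features import Fbank, InputNormalization

fbank = Fbank(n_mels=40, sample_rate=16000)      # STFT -> mel filterbank -> log-mel features
norm = InputNormalization(norm_type="global")    # running global mean/std normalization

waveform = torch.randn(2, 16000)                 # (batch, samples): 1 s of dummy 16 kHz audio
lengths = torch.tensor([1.0, 0.8])               # relative utterance lengths within the batch

feats = fbank(waveform)                          # (batch, frames, n_mels) log-mel features
feats = norm(feats, lengths)                     # normalized with the module's running statistics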
main/library/speaker_diarization/parameter_transfer.py
DELETED
|
@@ -1,120 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import inspect
|
| 4 |
-
|
| 5 |
-
sys.path.append(os.getcwd())
|
| 6 |
-
|
| 7 |
-
from main.library.speaker_diarization.speechbrain import fetch, run_on_main
|
| 8 |
-
from main.library.speaker_diarization.features import DEFAULT_TRANSFER_HOOKS, DEFAULT_LOAD_HOOKS
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def get_default_hook(obj, default_hooks):
|
| 12 |
-
for cls in inspect.getmro(type(obj)):
|
| 13 |
-
if cls in default_hooks: return default_hooks[cls]
|
| 14 |
-
|
| 15 |
-
return None
|
| 16 |
-
|
| 17 |
-
class Pretrainer:
|
| 18 |
-
def __init__(self, loadables=None, paths=None, custom_hooks=None, conditions=None):
|
| 19 |
-
self.loadables = {}
|
| 20 |
-
|
| 21 |
-
if loadables is not None: self.add_loadables(loadables)
|
| 22 |
-
self.paths = {}
|
| 23 |
-
|
| 24 |
-
if paths is not None: self.add_paths(paths)
|
| 25 |
-
self.custom_hooks = {}
|
| 26 |
-
|
| 27 |
-
if custom_hooks is not None: self.add_custom_hooks(custom_hooks)
|
| 28 |
-
self.conditions = {}
|
| 29 |
-
|
| 30 |
-
if conditions is not None: self.add_conditions(conditions)
|
| 31 |
-
self.is_local = []
|
| 32 |
-
|
| 33 |
-
def add_loadables(self, loadables):
|
| 34 |
-
self.loadables.update(loadables)
|
| 35 |
-
|
| 36 |
-
def add_paths(self, paths):
|
| 37 |
-
self.paths.update(paths)
|
| 38 |
-
|
| 39 |
-
def add_custom_hooks(self, custom_hooks):
|
| 40 |
-
self.custom_hooks.update(custom_hooks)
|
| 41 |
-
|
| 42 |
-
def add_conditions(self, conditions):
|
| 43 |
-
self.conditions.update(conditions)
|
| 44 |
-
|
| 45 |
-
@staticmethod
|
| 46 |
-
def split_path(path):
|
| 47 |
-
def split(src):
|
| 48 |
-
if "/" in src: return src.rsplit("/", maxsplit=1)
|
| 49 |
-
else: return "./", src
|
| 50 |
-
|
| 51 |
-
return split(path)
|
| 52 |
-
|
| 53 |
-
def collect_files(self, default_source=None):
|
| 54 |
-
loadable_paths = {}
|
| 55 |
-
for name in self.loadables:
|
| 56 |
-
if not self.is_loadable(name): continue
|
| 57 |
-
save_filename = name + ".ckpt"
|
| 58 |
-
|
| 59 |
-
if name in self.paths: source, filename = self.split_path(self.paths[name])
|
| 60 |
-
elif default_source is not None:
|
| 61 |
-
filename = save_filename
|
| 62 |
-
source = default_source
|
| 63 |
-
else: raise ValueError
|
| 64 |
-
|
| 65 |
-
fetch_kwargs = {"filename": filename, "source": source}
|
| 66 |
-
path = None
|
| 67 |
-
|
| 68 |
-
def run_fetch(**kwargs):
|
| 69 |
-
nonlocal path
|
| 70 |
-
|
| 71 |
-
path = fetch(**kwargs)
|
| 72 |
-
|
| 73 |
-
run_on_main(run_fetch, kwargs=fetch_kwargs, post_func=run_fetch, post_kwargs=fetch_kwargs)
|
| 74 |
-
|
| 75 |
-
loadable_paths[name] = path
|
| 76 |
-
self.paths[name] = str(path)
|
| 77 |
-
self.is_local.append(name)
|
| 78 |
-
|
| 79 |
-
return loadable_paths
|
| 80 |
-
|
| 81 |
-
def is_loadable(self, name):
|
| 82 |
-
if name not in self.conditions: return True
|
| 83 |
-
condition = self.conditions[name]
|
| 84 |
-
|
| 85 |
-
if callable(condition): return condition()
|
| 86 |
-
else: return bool(condition)
|
| 87 |
-
|
| 88 |
-
def load_collected(self):
|
| 89 |
-
paramfiles = {}
|
| 90 |
-
for name in self.loadables:
|
| 91 |
-
if not self.is_loadable(name): continue
|
| 92 |
-
|
| 93 |
-
if name in self.is_local: paramfiles[name] = self.paths[name]
|
| 94 |
-
else: raise ValueError
|
| 95 |
-
|
| 96 |
-
self._call_load_hooks(paramfiles)
|
| 97 |
-
|
| 98 |
-
def _call_load_hooks(self, paramfiles):
|
| 99 |
-
for name, obj in self.loadables.items():
|
| 100 |
-
if not self.is_loadable(name): continue
|
| 101 |
-
loadpath = paramfiles[name]
|
| 102 |
-
|
| 103 |
-
if name in self.custom_hooks:
|
| 104 |
-
self.custom_hooks[name](obj, loadpath)
|
| 105 |
-
continue
|
| 106 |
-
|
| 107 |
-
default_hook = get_default_hook(obj, DEFAULT_TRANSFER_HOOKS)
|
| 108 |
-
|
| 109 |
-
if default_hook is not None:
|
| 110 |
-
default_hook(obj, loadpath)
|
| 111 |
-
continue
|
| 112 |
-
|
| 113 |
-
default_hook = get_default_hook(obj, DEFAULT_LOAD_HOOKS)
|
| 114 |
-
|
| 115 |
-
if default_hook is not None:
|
| 116 |
-
end_of_epoch = False
|
| 117 |
-
default_hook(obj, loadpath, end_of_epoch)
|
| 118 |
-
continue
|
| 119 |
-
|
| 120 |
-
raise RuntimeError
|
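Pretrainer ties the hook tables defined in features.py (DEFAULT_TRANSFER_HOOKS, DEFAULT_LOAD_HOOKS) to concrete checkpoint files. A minimal sketch of that flow follows; it is not part of the original file, the Linear stand-in module and the checkpoint path are hypothetical, and the path must point at an existing state_dict file for the load to succeed.

# Sketch only: resolve and load a (hypothetical) checkpoint into a stand-in module.
import torch

from main.library.speaker_diarization.parameter_transfer import Pretrainer

model = torch.nn.Linear(10, 10)                               # stand-in torch.nn.Module
pretrainer = Pretrainer(
    loadables={"model": model},
    paths={"model": "assets/models/embedding/model.ckpt"},    # hypothetical local path; must exist
)

pretrainer.collect_files()    # resolves each loadable to a local file via fetch()/run_on_main()
pretrainer.load_collected()   # dispatches to DEFAULT_TRANSFER_HOOKS -> torch_parameter_transfer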
main/library/speaker_diarization/segment.py
DELETED
|
@@ -1,540 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
|
| 3 |
-
from sortedcontainers import SortedList
|
| 4 |
-
|
| 5 |
-
PYANNOTE_SEGMENT = 'segment'
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class Timeline:
|
| 9 |
-
@classmethod
|
| 10 |
-
def from_df(cls, df, uri = None):
|
| 11 |
-
return cls(segments=list(df[PYANNOTE_SEGMENT]), uri=uri)
|
| 12 |
-
|
| 13 |
-
def __init__(self, segments = None, uri = None):
|
| 14 |
-
if segments is None: segments = ()
|
| 15 |
-
segments_set = set([segment for segment in segments if segment])
|
| 16 |
-
|
| 17 |
-
self.segments_set_ = segments_set
|
| 18 |
-
self.segments_list_ = SortedList(segments_set)
|
| 19 |
-
self.segments_boundaries_ = SortedList((boundary for segment in segments_set for boundary in segment))
|
| 20 |
-
self.uri = uri
|
| 21 |
-
|
| 22 |
-
def __len__(self):
|
| 23 |
-
return len(self.segments_set_)
|
| 24 |
-
|
| 25 |
-
def __nonzero__(self):
|
| 26 |
-
return self.__bool__()
|
| 27 |
-
|
| 28 |
-
def __bool__(self):
|
| 29 |
-
return len(self.segments_set_) > 0
|
| 30 |
-
|
| 31 |
-
def __iter__(self):
|
| 32 |
-
return iter(self.segments_list_)
|
| 33 |
-
|
| 34 |
-
def __getitem__(self, k):
|
| 35 |
-
return self.segments_list_[k]
|
| 36 |
-
|
| 37 |
-
def __eq__(self, other):
|
| 38 |
-
return self.segments_set_ == other.segments_set_
|
| 39 |
-
|
| 40 |
-
def __ne__(self, other):
|
| 41 |
-
return self.segments_set_ != other.segments_set_
|
| 42 |
-
|
| 43 |
-
def index(self, segment):
|
| 44 |
-
return self.segments_list_.index(segment)
|
| 45 |
-
|
| 46 |
-
def add(self, segment):
|
| 47 |
-
segments_set_ = self.segments_set_
|
| 48 |
-
if segment in segments_set_ or not segment: return self
|
| 49 |
-
|
| 50 |
-
segments_set_.add(segment)
|
| 51 |
-
self.segments_list_.add(segment)
|
| 52 |
-
|
| 53 |
-
segments_boundaries_ = self.segments_boundaries_
|
| 54 |
-
segments_boundaries_.add(segment.start)
|
| 55 |
-
segments_boundaries_.add(segment.end)
|
| 56 |
-
|
| 57 |
-
return self
|
| 58 |
-
|
| 59 |
-
def remove(self, segment):
|
| 60 |
-
segments_set_ = self.segments_set_
|
| 61 |
-
if segment not in segments_set_: return self
|
| 62 |
-
|
| 63 |
-
segments_set_.remove(segment)
|
| 64 |
-
self.segments_list_.remove(segment)
|
| 65 |
-
|
| 66 |
-
segments_boundaries_ = self.segments_boundaries_
|
| 67 |
-
segments_boundaries_.remove(segment.start)
|
| 68 |
-
segments_boundaries_.remove(segment.end)
|
| 69 |
-
|
| 70 |
-
return self
|
| 71 |
-
|
| 72 |
-
def discard(self, segment):
|
| 73 |
-
return self.remove(segment)
|
| 74 |
-
|
| 75 |
-
def __ior__(self, timeline):
|
| 76 |
-
return self.update(timeline)
|
| 77 |
-
|
| 78 |
-
def update(self, timeline):
|
| 79 |
-
segments_set = self.segments_set_
|
| 80 |
-
segments_set |= timeline.segments_set_
|
| 81 |
-
|
| 82 |
-
self.segments_list_ = SortedList(segments_set)
|
| 83 |
-
self.segments_boundaries_ = SortedList((boundary for segment in segments_set for boundary in segment))
|
| 84 |
-
|
| 85 |
-
return self
|
| 86 |
-
|
| 87 |
-
def __or__(self, timeline):
|
| 88 |
-
return self.union(timeline)
|
| 89 |
-
|
| 90 |
-
def union(self, timeline):
|
| 91 |
-
return Timeline(segments=self.segments_set_ | timeline.segments_set_, uri=self.uri)
|
| 92 |
-
|
| 93 |
-
def co_iter(self, other):
|
| 94 |
-
for segment in self.segments_list_:
|
| 95 |
-
temp = Segment(start=segment.end, end=segment.end)
|
| 96 |
-
|
| 97 |
-
for other_segment in other.segments_list_.irange(maximum=temp):
|
| 98 |
-
if segment.intersects(other_segment): yield segment, other_segment
|
| 99 |
-
|
| 100 |
-
def crop_iter(self, support, mode = 'intersection', returns_mapping = False):
|
| 101 |
-
if mode not in {'loose', 'strict', 'intersection'}: raise ValueError
|
| 102 |
-
if not isinstance(support, (Segment, Timeline)): raise TypeError
|
| 103 |
-
|
| 104 |
-
if isinstance(support, Segment):
|
| 105 |
-
support = Timeline(segments=([support] if support else []), uri=self.uri)
|
| 106 |
-
|
| 107 |
-
for yielded in self.crop_iter(support, mode=mode, returns_mapping=returns_mapping):
|
| 108 |
-
yield yielded
|
| 109 |
-
|
| 110 |
-
return
|
| 111 |
-
|
| 112 |
-
support = support.support()
|
| 113 |
-
|
| 114 |
-
if mode == 'loose':
|
| 115 |
-
for segment, _ in self.co_iter(support):
|
| 116 |
-
yield segment
|
| 117 |
-
|
| 118 |
-
return
|
| 119 |
-
|
| 120 |
-
if mode == 'strict':
|
| 121 |
-
for segment, other_segment in self.co_iter(support):
|
| 122 |
-
if segment in other_segment: yield segment
|
| 123 |
-
|
| 124 |
-
return
|
| 125 |
-
|
| 126 |
-
for segment, other_segment in self.co_iter(support):
|
| 127 |
-
mapped_to = segment & other_segment
|
| 128 |
-
if not mapped_to: continue
|
| 129 |
-
|
| 130 |
-
if returns_mapping: yield segment, mapped_to
|
| 131 |
-
else: yield mapped_to
|
| 132 |
-
|
| 133 |
-
def crop(self, support, mode = 'intersection', returns_mapping = False):
|
| 134 |
-
if mode == 'intersection' and returns_mapping:
|
| 135 |
-
segments, mapping = [], {}
|
| 136 |
-
|
| 137 |
-
for segment, mapped_to in self.crop_iter(support, mode='intersection', returns_mapping=True):
|
| 138 |
-
segments.append(mapped_to)
|
| 139 |
-
mapping[mapped_to] = mapping.get(mapped_to, list()) + [segment]
|
| 140 |
-
|
| 141 |
-
return Timeline(segments=segments, uri=self.uri), mapping
|
| 142 |
-
|
| 143 |
-
return Timeline(segments=self.crop_iter(support, mode=mode), uri=self.uri)
|
| 144 |
-
|
| 145 |
-
def overlapping(self, t):
|
| 146 |
-
return list(self.overlapping_iter(t))
|
| 147 |
-
|
| 148 |
-
def overlapping_iter(self, t):
|
| 149 |
-
for segment in self.segments_list_.irange(maximum=Segment(start=t, end=t)):
|
| 150 |
-
if segment.overlaps(t): yield segment
|
| 151 |
-
|
| 152 |
-
def get_overlap(self):
|
| 153 |
-
overlaps_tl = Timeline(uri=self.uri)
|
| 154 |
-
|
| 155 |
-
for s1, s2 in self.co_iter(self):
|
| 156 |
-
if s1 == s2: continue
|
| 157 |
-
|
| 158 |
-
overlaps_tl.add(s1 & s2)
|
| 159 |
-
|
| 160 |
-
return overlaps_tl.support()
|
| 161 |
-
|
| 162 |
-
def extrude(self, removed, mode = 'intersection'):
|
| 163 |
-
if isinstance(removed, Segment): removed = Timeline([removed])
|
| 164 |
-
|
| 165 |
-
if mode == "loose": mode = "strict"
|
| 166 |
-
elif mode == "strict": mode = "loose"
|
| 167 |
-
|
| 168 |
-
return self.crop(removed.gaps(support=Timeline([self.extent()], uri=self.uri)), mode=mode)
|
| 169 |
-
|
| 170 |
-
def __str__(self):
|
| 171 |
-
n = len(self.segments_list_)
|
| 172 |
-
string = "["
|
| 173 |
-
|
| 174 |
-
for i, segment in enumerate(self.segments_list_):
|
| 175 |
-
string += str(segment)
|
| 176 |
-
string += "\n " if i + 1 < n else ""
|
| 177 |
-
|
| 178 |
-
string += "]"
|
| 179 |
-
return string
|
| 180 |
-
|
| 181 |
-
def __repr__(self):
|
| 182 |
-
return "<Timeline(uri=%s, segments=%s)>" % (self.uri, list(self.segments_list_))
|
| 183 |
-
|
| 184 |
-
def __contains__(self, included):
|
| 185 |
-
if isinstance(included, Segment): return included in self.segments_set_
|
| 186 |
-
elif isinstance(included, Timeline): return self.segments_set_.issuperset(included.segments_set_)
|
| 187 |
-
else: raise TypeError
|
| 188 |
-
|
| 189 |
-
def empty(self):
|
| 190 |
-
return Timeline(uri=self.uri)
|
| 191 |
-
|
| 192 |
-
def covers(self, other):
|
| 193 |
-
gaps = self.gaps(support=other.extent())
|
| 194 |
-
|
| 195 |
-
for _ in gaps.co_iter(other):
|
| 196 |
-
return False
|
| 197 |
-
|
| 198 |
-
return True
|
| 199 |
-
|
| 200 |
-
def copy(self, segment_func = None):
|
| 201 |
-
if segment_func is None: return Timeline(segments=self.segments_list_, uri=self.uri)
|
| 202 |
-
return Timeline(segments=[segment_func(s) for s in self.segments_list_], uri=self.uri)
|
| 203 |
-
|
| 204 |
-
def extent(self):
|
| 205 |
-
if self.segments_set_:
|
| 206 |
-
segments_boundaries_ = self.segments_boundaries_
|
| 207 |
-
return Segment(start=segments_boundaries_[0], end=segments_boundaries_[-1])
|
| 208 |
-
|
| 209 |
-
return Segment(start=0.0, end=0.0)
|
| 210 |
-
|
| 211 |
-
def support_iter(self, collar = 0.0):
|
| 212 |
-
if not self: return
|
| 213 |
-
|
| 214 |
-
new_segment = self.segments_list_[0]
|
| 215 |
-
|
| 216 |
-
for segment in self:
|
| 217 |
-
possible_gap = segment ^ new_segment
|
| 218 |
-
|
| 219 |
-
if not possible_gap or possible_gap.duration < collar: new_segment |= segment
|
| 220 |
-
else:
|
| 221 |
-
yield new_segment
|
| 222 |
-
new_segment = segment
|
| 223 |
-
|
| 224 |
-
yield new_segment
|
| 225 |
-
|
| 226 |
-
def support(self, collar = 0.):
|
| 227 |
-
return Timeline(segments=self.support_iter(collar), uri=self.uri)
|
| 228 |
-
|
| 229 |
-
def duration(self):
|
| 230 |
-
return sum(s.duration for s in self.support_iter())
|
| 231 |
-
|
| 232 |
-
def gaps_iter(self, support = None):
|
| 233 |
-
if support is None: support = self.extent()
|
| 234 |
-
if not isinstance(support, (Segment, Timeline)): raise TypeError
|
| 235 |
-
|
| 236 |
-
if isinstance(support, Segment):
|
| 237 |
-
end = support.start
|
| 238 |
-
|
| 239 |
-
for segment in self.crop(support, mode='intersection').support():
|
| 240 |
-
gap = Segment(start=end, end=segment.start)
|
| 241 |
-
if gap: yield gap
|
| 242 |
-
|
| 243 |
-
end = segment.end
|
| 244 |
-
|
| 245 |
-
gap = Segment(start=end, end=support.end)
|
| 246 |
-
if gap: yield gap
|
| 247 |
-
elif isinstance(support, Timeline):
|
| 248 |
-
for segment in support.support():
|
| 249 |
-
for gap in self.gaps_iter(support=segment):
|
| 250 |
-
yield gap
|
| 251 |
-
|
| 252 |
-
def gaps(self, support = None):
|
| 253 |
-
return Timeline(segments=self.gaps_iter(support=support), uri=self.uri)
|
| 254 |
-
|
| 255 |
-
def segmentation(self):
|
| 256 |
-
support = self.support()
|
| 257 |
-
timestamps = set([])
|
| 258 |
-
|
| 259 |
-
for (start, end) in self:
|
| 260 |
-
timestamps.add(start)
|
| 261 |
-
timestamps.add(end)
|
| 262 |
-
|
| 263 |
-
timestamps = sorted(timestamps)
|
| 264 |
-
if len(timestamps) == 0: return Timeline(uri=self.uri)
|
| 265 |
-
|
| 266 |
-
segments = []
|
| 267 |
-
start = timestamps[0]
|
| 268 |
-
|
| 269 |
-
for end in timestamps[1:]:
|
| 270 |
-
segment = Segment(start=start, end=end)
|
| 271 |
-
|
| 272 |
-
if segment and support.overlapping(segment.middle): segments.append(segment)
|
| 273 |
-
start = end
|
| 274 |
-
|
| 275 |
-
return Timeline(segments=segments, uri=self.uri)
|
| 276 |
-
|
| 277 |
-
def _iter_uem(self):
|
| 278 |
-
uri = self.uri if self.uri else "<NA>"
|
| 279 |
-
|
| 280 |
-
for segment in self:
|
| 281 |
-
yield f"{uri} 1 {segment.start:.3f} {segment.end:.3f}\n"
|
| 282 |
-
|
| 283 |
-
def to_uem(self):
|
| 284 |
-
return "".join([line for line in self._iter_uem()])
|
| 285 |
-
|
| 286 |
-
def write_uem(self, file):
|
| 287 |
-
for line in self._iter_uem():
|
| 288 |
-
file.write(line)
|
| 289 |
-
|
| 290 |
-
def _repr_png_(self):
|
| 291 |
-
return None
|
| 292 |
-
|
| 293 |
-
class Segment:
|
| 294 |
-
def __init__(self, start, end):
|
| 295 |
-
self.start = start
|
| 296 |
-
self.end = end
|
| 297 |
-
|
| 298 |
-
@staticmethod
|
| 299 |
-
def set_precision(ndigits = None):
|
| 300 |
-
global AUTO_ROUND_TIME, SEGMENT_PRECISION
|
| 301 |
-
|
| 302 |
-
if ndigits is None:
|
| 303 |
-
AUTO_ROUND_TIME = False
|
| 304 |
-
SEGMENT_PRECISION = 1e-6
|
| 305 |
-
else:
|
| 306 |
-
AUTO_ROUND_TIME = True
|
| 307 |
-
SEGMENT_PRECISION = 10 ** (-ndigits)
|
| 308 |
-
|
| 309 |
-
def __bool__(self):
|
| 310 |
-
return bool((self.end - self.start) > SEGMENT_PRECISION)
|
| 311 |
-
|
| 312 |
-
def __post_init__(self):
|
| 313 |
-
if AUTO_ROUND_TIME:
|
| 314 |
-
object.__setattr__(self, 'start', int(self.start / SEGMENT_PRECISION + 0.5) * SEGMENT_PRECISION)
|
| 315 |
-
object.__setattr__(self, 'end', int(self.end / SEGMENT_PRECISION + 0.5) * SEGMENT_PRECISION)
|
| 316 |
-
|
| 317 |
-
@property
|
| 318 |
-
def duration(self):
|
| 319 |
-
return self.end - self.start if self else 0.
|
| 320 |
-
|
| 321 |
-
@property
|
| 322 |
-
def middle(self):
|
| 323 |
-
return .5 * (self.start + self.end)
|
| 324 |
-
|
| 325 |
-
def __iter__(self):
|
| 326 |
-
yield self.start
|
| 327 |
-
yield self.end
|
| 328 |
-
|
| 329 |
-
def copy(self):
|
| 330 |
-
return Segment(start=self.start, end=self.end)
|
| 331 |
-
|
| 332 |
-
def __contains__(self, other):
|
| 333 |
-
return (self.start <= other.start) and (self.end >= other.end)
|
| 334 |
-
|
| 335 |
-
def __and__(self, other):
|
| 336 |
-
return Segment(start=max(self.start, other.start), end=min(self.end, other.end))
|
| 337 |
-
|
| 338 |
-
def intersects(self, other):
|
| 339 |
-
return (self.start < other.start and other.start < self.end - SEGMENT_PRECISION) or (self.start > other.start and self.start < other.end - SEGMENT_PRECISION) or (self.start == other.start)
|
| 340 |
-
|
| 341 |
-
def overlaps(self, t):
|
| 342 |
-
return self.start <= t and self.end >= t
|
| 343 |
-
|
| 344 |
-
def __or__(self, other):
|
| 345 |
-
if not self: return other
|
| 346 |
-
if not other: return self
|
| 347 |
-
|
| 348 |
-
return Segment(start=min(self.start, other.start), end=max(self.end, other.end))
|
| 349 |
-
|
| 350 |
-
def __xor__(self, other):
|
| 351 |
-
if (not self) or (not other): raise ValueError
|
| 352 |
-
|
| 353 |
-
return Segment(start=min(self.end, other.end), end=max(self.start, other.start))
|
| 354 |
-
|
| 355 |
-
def _str_helper(self, seconds):
|
| 356 |
-
from datetime import timedelta
|
| 357 |
-
|
| 358 |
-
negative = seconds < 0
|
| 359 |
-
td = timedelta(seconds=abs(seconds))
|
| 360 |
-
|
| 361 |
-
hours, remainder = divmod(td.seconds + 86400 * td.days, 3600)
|
| 362 |
-
minutes, seconds = divmod(remainder, 60)
|
| 363 |
-
|
| 364 |
-
return '%s%02d:%02d:%02d.%03d' % ('-' if negative else ' ', hours, minutes, seconds, td.microseconds / 1000)
|
| 365 |
-
|
| 366 |
-
def __str__(self):
|
| 367 |
-
if self: return '[%s --> %s]' % (self._str_helper(self.start), self._str_helper(self.end))
|
| 368 |
-
return '[]'
|
| 369 |
-
|
| 370 |
-
def __repr__(self):
|
| 371 |
-
return '<Segment(%g, %g)>' % (self.start, self.end)
|
| 372 |
-
|
| 373 |
-
def _repr_png_(self):
|
| 374 |
-
return None
|
| 375 |
-
|
| 376 |
-
class SlidingWindow:
|
| 377 |
-
def __init__(self, duration=0.030, step=0.010, start=0.000, end=None):
|
| 378 |
-
if duration <= 0: raise ValueError
|
| 379 |
-
self.__duration = duration
|
| 380 |
-
if step <= 0: raise ValueError
|
| 381 |
-
|
| 382 |
-
self.__step = step
|
| 383 |
-
self.__start = start
|
| 384 |
-
|
| 385 |
-
if end is None: self.__end = np.inf
|
| 386 |
-
else:
|
| 387 |
-
if end <= start: raise ValueError
|
| 388 |
-
self.__end = end
|
| 389 |
-
|
| 390 |
-
self.__i = -1
|
| 391 |
-
|
| 392 |
-
@property
|
| 393 |
-
def start(self):
|
| 394 |
-
return self.__start
|
| 395 |
-
|
| 396 |
-
@property
|
| 397 |
-
def end(self):
|
| 398 |
-
return self.__end
|
| 399 |
-
|
| 400 |
-
@property
|
| 401 |
-
def step(self):
|
| 402 |
-
return self.__step
|
| 403 |
-
|
| 404 |
-
@property
|
| 405 |
-
def duration(self):
|
| 406 |
-
return self.__duration
|
| 407 |
-
|
| 408 |
-
def closest_frame(self, t):
|
| 409 |
-
return int(np.rint((t - self.__start - .5 * self.__duration) / self.__step))
|
| 410 |
-
|
| 411 |
-
def samples(self, from_duration, mode = 'strict'):
|
| 412 |
-
if mode == 'strict': return int(np.floor((from_duration - self.duration) / self.step)) + 1
|
| 413 |
-
elif mode == 'loose': return int(np.floor((from_duration + self.duration) / self.step))
|
| 414 |
-
elif mode == 'center': return int(np.rint((from_duration / self.step)))
|
| 415 |
-
|
| 416 |
-
def crop(self, focus, mode = 'loose', fixed = None, return_ranges = False):
|
| 417 |
-
if not isinstance(focus, (Segment, Timeline)): raise TypeError
|
| 418 |
-
|
| 419 |
-
if isinstance(focus, Timeline):
|
| 420 |
-
if fixed is not None: raise ValueError
|
| 421 |
-
|
| 422 |
-
if return_ranges:
|
| 423 |
-
ranges = []
|
| 424 |
-
|
| 425 |
-
for i, s in enumerate(focus.support()):
|
| 426 |
-
rng = self.crop(s, mode=mode, fixed=fixed, return_ranges=True)
|
| 427 |
-
|
| 428 |
-
if i == 0 or rng[0][0] > ranges[-1][1]: ranges += rng
|
| 429 |
-
else: ranges[-1][1] = rng[0][1]
|
| 430 |
-
|
| 431 |
-
return ranges
|
| 432 |
-
|
| 433 |
-
return np.unique(np.hstack([self.crop(s, mode=mode, fixed=fixed, return_ranges=False) for s in focus.support()]))
|
| 434 |
-
|
| 435 |
-
if mode == 'loose':
|
| 436 |
-
i = int(np.ceil((focus.start - self.duration - self.start) / self.step))
|
| 437 |
-
|
| 438 |
-
if fixed is None:
|
| 439 |
-
j = int(np.floor((focus.end - self.start) / self.step))
|
| 440 |
-
rng = (i, j + 1)
|
| 441 |
-
else:
|
| 442 |
-
n = self.samples(fixed, mode='loose')
|
| 443 |
-
rng = (i, i + n)
|
| 444 |
-
elif mode == 'strict':
|
| 445 |
-
i = int(np.ceil((focus.start - self.start) / self.step))
|
| 446 |
-
|
| 447 |
-
if fixed is None:
|
| 448 |
-
j = int(np.floor((focus.end - self.duration - self.start) / self.step))
|
| 449 |
-
rng = (i, j + 1)
|
| 450 |
-
else:
|
| 451 |
-
n = self.samples(fixed, mode='strict')
|
| 452 |
-
rng = (i, i + n)
|
| 453 |
-
elif mode == 'center':
|
| 454 |
-
i = self.closest_frame(focus.start)
|
| 455 |
-
|
| 456 |
-
if fixed is None:
|
| 457 |
-
j = self.closest_frame(focus.end)
|
| 458 |
-
rng = (i, j + 1)
|
| 459 |
-
else:
|
| 460 |
-
n = self.samples(fixed, mode='center')
|
| 461 |
-
rng = (i, i + n)
|
| 462 |
-
else: raise ValueError
|
| 463 |
-
|
| 464 |
-
if return_ranges: return [list(rng)]
|
| 465 |
-
return np.array(range(*rng), dtype=np.int64)
|
| 466 |
-
|
| 467 |
-
def segmentToRange(self, segment):
|
| 468 |
-
return self.segment_to_range(segment)
|
| 469 |
-
|
| 470 |
-
def segment_to_range(self, segment):
|
| 471 |
-
return self.closest_frame(segment.start), int(segment.duration / self.step) + 1
|
| 472 |
-
|
| 473 |
-
def rangeToSegment(self, i0, n):
|
| 474 |
-
return self.range_to_segment(i0, n)
|
| 475 |
-
|
| 476 |
-
def range_to_segment(self, i0, n):
|
| 477 |
-
start = self.__start + (i0 - .5) * self.__step + .5 * self.__duration
|
| 478 |
-
|
| 479 |
-
if i0 == 0: start = self.start
|
| 480 |
-
return Segment(start, start + (n * self.__step))
|
| 481 |
-
|
| 482 |
-
def samplesToDuration(self, nSamples):
|
| 483 |
-
return self.samples_to_duration(nSamples)
|
| 484 |
-
|
| 485 |
-
def samples_to_duration(self, n_samples):
|
| 486 |
-
return self.range_to_segment(0, n_samples).duration
|
| 487 |
-
|
| 488 |
-
def durationToSamples(self, duration):
|
| 489 |
-
return self.duration_to_samples(duration)
|
| 490 |
-
|
| 491 |
-
def duration_to_samples(self, duration):
|
| 492 |
-
return self.segment_to_range(Segment(0, duration))[1]
|
| 493 |
-
|
| 494 |
-
def __getitem__(self, i):
|
| 495 |
-
start = self.__start + i * self.__step
|
| 496 |
-
if start >= self.__end: return None
|
| 497 |
-
|
| 498 |
-
return Segment(start=start, end=start + self.__duration)
|
| 499 |
-
|
| 500 |
-
def next(self):
|
| 501 |
-
return self.__next__()
|
| 502 |
-
|
| 503 |
-
def __next__(self):
|
| 504 |
-
self.__i += 1
|
| 505 |
-
window = self[self.__i]
|
| 506 |
-
|
| 507 |
-
if window: return window
|
| 508 |
-
else: raise StopIteration()
|
| 509 |
-
|
| 510 |
-
def __iter__(self):
|
| 511 |
-
self.__i = -1
|
| 512 |
-
return self
|
| 513 |
-
|
| 514 |
-
def __len__(self):
|
| 515 |
-
if np.isinf(self.__end): raise ValueError
|
| 516 |
-
i = self.closest_frame(self.__end)
|
| 517 |
-
|
| 518 |
-
while (self[i]):
|
| 519 |
-
i += 1
|
| 520 |
-
|
| 521 |
-
length = i
|
| 522 |
-
return length
|
| 523 |
-
|
| 524 |
-
def copy(self):
|
| 525 |
-
return self.__class__(duration=self.duration, step=self.step, start=self.start, end=self.end)
|
| 526 |
-
|
| 527 |
-
def __call__(self, support, align_last = False):
|
| 528 |
-
if isinstance(support, Timeline): segments = support
|
| 529 |
-
elif isinstance(support, Segment): segments = Timeline(segments=[support])
|
| 530 |
-
else: raise TypeError
|
| 531 |
-
|
| 532 |
-
for segment in segments:
|
| 533 |
-
if segment.duration < self.duration: continue
|
| 534 |
-
|
| 535 |
-
for s in SlidingWindow(duration=self.duration, step=self.step, start=segment.start, end=segment.end):
|
| 536 |
-
if s in segment:
|
| 537 |
-
yield s
|
| 538 |
-
last = s
|
| 539 |
-
|
| 540 |
-
if align_last and last.end < segment.end: yield Segment(start=segment.end - self.duration, end=segment.end)
|
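Segment and SlidingWindow above are self-contained time-bookkeeping helpers. A minimal sketch of how they can be exercised follows; it is not part of the original file and assumes the pre-deletion layout is importable. Segment.set_precision() is called first because Segment.__bool__ reads the module-level precision globals that call defines.

# Sketch only: slide fixed-size windows over a single speech segment.
from main.library.speaker_diarization.segment import Segment, SlidingWindow

Segment.set_precision(3)                      # defines the AUTO_ROUND_TIME / SEGMENT_PRECISION globals

speech = Segment(0.0, 1.5)                    # 1.5 s region of interest
window = SlidingWindow(duration=0.5, step=0.25)

for chunk in window(speech):                  # 0.5 s windows with a 0.25 s hop, fully inside `speech`
    print(chunk)                              # e.g. [ 00:00:00.000 -->  00:00:00.500]

print(window.closest_frame(0.6))              # index of the window whose center is nearest to t = 0.6 s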
main/library/speaker_diarization/speechbrain.py
DELETED
|
@@ -1,220 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import torchaudio
|
| 4 |
-
|
| 5 |
-
from functools import wraps
|
| 6 |
-
from types import SimpleNamespace
|
| 7 |
-
from torch.nn import SyncBatchNorm
|
| 8 |
-
from hyperpyyaml import load_hyperpyyaml
|
| 9 |
-
|
| 10 |
-
from torch.nn import DataParallel as DP
|
| 11 |
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 12 |
-
|
| 13 |
-
MAIN_PROC_ONLY = 0
|
| 14 |
-
|
| 15 |
-
def fetch(filename, source):
|
| 16 |
-
return os.path.abspath(os.path.join(source, filename))
|
| 17 |
-
|
| 18 |
-
def run_on_main(func, args=None, kwargs=None, post_func=None, post_args=None, post_kwargs=None, run_post_on_main=False):
|
| 19 |
-
if args is None: args = []
|
| 20 |
-
if kwargs is None: kwargs = {}
|
| 21 |
-
if post_args is None: post_args = []
|
| 22 |
-
if post_kwargs is None: post_kwargs = {}
|
| 23 |
-
|
| 24 |
-
main_process_only(func)(*args, **kwargs)
|
| 25 |
-
ddp_barrier()
|
| 26 |
-
|
| 27 |
-
if post_func is not None:
|
| 28 |
-
if run_post_on_main: post_func(*post_args, **post_kwargs)
|
| 29 |
-
else:
|
| 30 |
-
if not if_main_process(): post_func(*post_args, **post_kwargs)
|
| 31 |
-
ddp_barrier()
|
| 32 |
-
|
| 33 |
-
def is_distributed_initialized():
|
| 34 |
-
return (torch.distributed.is_available() and torch.distributed.is_initialized())
|
| 35 |
-
|
| 36 |
-
def if_main_process():
|
| 37 |
-
if is_distributed_initialized(): return torch.distributed.get_rank() == 0
|
| 38 |
-
else: return True
|
| 39 |
-
|
| 40 |
-
class MainProcessContext:
|
| 41 |
-
def __enter__(self):
|
| 42 |
-
global MAIN_PROC_ONLY
|
| 43 |
-
|
| 44 |
-
MAIN_PROC_ONLY += 1
|
| 45 |
-
return self
|
| 46 |
-
|
| 47 |
-
def __exit__(self, exc_type, exc_value, traceback):
|
| 48 |
-
global MAIN_PROC_ONLY
|
| 49 |
-
|
| 50 |
-
MAIN_PROC_ONLY -= 1
|
| 51 |
-
|
| 52 |
-
def main_process_only(function):
|
| 53 |
-
@wraps(function)
|
| 54 |
-
def main_proc_wrapped_func(*args, **kwargs):
|
| 55 |
-
with MainProcessContext():
|
| 56 |
-
return function(*args, **kwargs) if if_main_process() else None
|
| 57 |
-
|
| 58 |
-
return main_proc_wrapped_func
|
| 59 |
-
|
| 60 |
-
def ddp_barrier():
|
| 61 |
-
if MAIN_PROC_ONLY >= 1 or not is_distributed_initialized(): return
|
| 62 |
-
|
| 63 |
-
if torch.distributed.get_backend() == torch.distributed.Backend.NCCL: torch.distributed.barrier(device_ids=[torch.cuda.current_device()])
|
| 64 |
-
else: torch.distributed.barrier()
|
| 65 |
-
|
| 66 |
-
class Resample(torch.nn.Module):
|
| 67 |
-
def __init__(self, orig_freq=16000, new_freq=16000, *args, **kwargs):
|
| 68 |
-
super().__init__()
|
| 69 |
-
|
| 70 |
-
self.orig_freq = orig_freq
|
| 71 |
-
self.new_freq = new_freq
|
| 72 |
-
self.resampler = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq, *args, **kwargs)
|
| 73 |
-
|
| 74 |
-
def forward(self, waveforms):
|
| 75 |
-
if self.orig_freq == self.new_freq: return waveforms
|
| 76 |
-
|
| 77 |
-
unsqueezed = False
|
| 78 |
-
if len(waveforms.shape) == 2:
|
| 79 |
-
waveforms = waveforms.unsqueeze(1)
|
| 80 |
-
unsqueezed = True
|
| 81 |
-
elif len(waveforms.shape) == 3: waveforms = waveforms.transpose(1, 2)
|
| 82 |
-
else: raise ValueError
|
| 83 |
-
|
| 84 |
-
self.resampler.to(waveforms.device)
|
| 85 |
-
resampled_waveform = self.resampler(waveforms)
|
| 86 |
-
|
| 87 |
-
return resampled_waveform.squeeze(1) if unsqueezed else resampled_waveform.transpose(1, 2)
|
| 88 |
-
|
| 89 |
-
class AudioNormalizer:
|
| 90 |
-
def __init__(self, sample_rate=16000, mix="avg-to-mono"):
|
| 91 |
-
self.sample_rate = sample_rate
|
| 92 |
-
|
| 93 |
-
if mix not in ["avg-to-mono", "keep"]: raise ValueError
|
| 94 |
-
|
| 95 |
-
self.mix = mix
|
| 96 |
-
self._cached_resamplers = {}
|
| 97 |
-
|
| 98 |
-
def __call__(self, audio, sample_rate):
|
| 99 |
-
if sample_rate not in self._cached_resamplers: self._cached_resamplers[sample_rate] = Resample(sample_rate, self.sample_rate)
|
| 100 |
-
return self._mix(self._cached_resamplers[sample_rate](audio.unsqueeze(0)).squeeze(0))
|
| 101 |
-
|
| 102 |
-
def _mix(self, audio):
|
| 103 |
-
flat_input = audio.dim() == 1
|
| 104 |
-
|
| 105 |
-
if self.mix == "avg-to-mono":
|
| 106 |
-
if flat_input: return audio
|
| 107 |
-
return torch.mean(audio, 1)
|
| 108 |
-
|
| 109 |
-
if self.mix == "keep": return audio
|
| 110 |
-
|
| 111 |
-
class Pretrained(torch.nn.Module):
|
| 112 |
-
HPARAMS_NEEDED, MODULES_NEEDED = [], []
|
| 113 |
-
def __init__(self, modules=None, hparams=None, run_opts=None, freeze_params=True):
|
| 114 |
-
super().__init__()
|
| 115 |
-
|
| 116 |
-
for arg, default in {"device": "cpu", "data_parallel_count": -1, "data_parallel_backend": False, "distributed_launch": False, "distributed_backend": "nccl", "jit": False, "jit_module_keys": None, "compile": False, "compile_module_keys": None, "compile_mode": "reduce-overhead", "compile_using_fullgraph": False, "compile_using_dynamic_shape_tracing": False}.items():
|
| 117 |
-
if run_opts is not None and arg in run_opts: setattr(self, arg, run_opts[arg])
|
| 118 |
-
elif hparams is not None and arg in hparams: setattr(self, arg, hparams[arg])
|
| 119 |
-
else: setattr(self, arg, default)
|
| 120 |
-
|
| 121 |
-
self.mods = torch.nn.ModuleDict(modules)
|
| 122 |
-
|
| 123 |
-
for module in self.mods.values():
|
| 124 |
-
if module is not None: module.to(self.device)
|
| 125 |
-
|
| 126 |
-
if self.HPARAMS_NEEDED and hparams is None: raise ValueError
|
| 127 |
-
|
| 128 |
-
if hparams is not None:
|
| 129 |
-
for hp in self.HPARAMS_NEEDED:
|
| 130 |
-
if hp not in hparams: raise ValueError
|
| 131 |
-
|
| 132 |
-
self.hparams = SimpleNamespace(**hparams)
|
| 133 |
-
|
| 134 |
-
self._prepare_modules(freeze_params)
|
| 135 |
-
self.audio_normalizer = hparams.get("audio_normalizer", AudioNormalizer())
|
| 136 |
-
|
| 137 |
-
def _prepare_modules(self, freeze_params):
|
| 138 |
-
self._compile()
|
| 139 |
-
self._wrap_distributed()
|
| 140 |
-
|
| 141 |
-
if freeze_params:
|
| 142 |
-
self.mods.eval()
|
| 143 |
-
for p in self.mods.parameters():
|
| 144 |
-
p.requires_grad = False
|
| 145 |
-
|
| 146 |
-
def _compile(self):
|
| 147 |
-
compile_available = hasattr(torch, "compile")
|
| 148 |
-
if not compile_available and self.compile_module_keys is not None: raise ValueError
|
| 149 |
-
|
| 150 |
-
compile_module_keys = set()
|
| 151 |
-
if self.compile: compile_module_keys = set(self.mods) if self.compile_module_keys is None else set(self.compile_module_keys)
|
| 152 |
-
|
| 153 |
-
jit_module_keys = set()
|
| 154 |
-
if self.jit: jit_module_keys = set(self.mods) if self.jit_module_keys is None else set(self.jit_module_keys)
|
| 155 |
-
|
| 156 |
-
for name in compile_module_keys | jit_module_keys:
|
| 157 |
-
if name not in self.mods: raise ValueError
|
| 158 |
-
|
| 159 |
-
for name in compile_module_keys:
|
| 160 |
-
try:
|
| 161 |
-
module = torch.compile(self.mods[name], mode=self.compile_mode, fullgraph=self.compile_using_fullgraph, dynamic=self.compile_using_dynamic_shape_tracing)
|
| 162 |
-
except Exception:
|
| 163 |
-
continue
|
| 164 |
-
|
| 165 |
-
self.mods[name] = module.to(self.device)
|
| 166 |
-
jit_module_keys.discard(name)
|
| 167 |
-
|
| 168 |
-
for name in jit_module_keys:
|
| 169 |
-
module = torch.jit.script(self.mods[name])
|
| 170 |
-
self.mods[name] = module.to(self.device)
|
| 171 |
-
|
| 172 |
-
def _compile_jit(self):
|
| 173 |
-
self._compile()
|
| 174 |
-
|
| 175 |
-
def _wrap_distributed(self):
|
| 176 |
-
if not self.distributed_launch and not self.data_parallel_backend: return
|
| 177 |
-
elif self.distributed_launch:
|
| 178 |
-
for name, module in self.mods.items():
|
| 179 |
-
if any(p.requires_grad for p in module.parameters()): self.mods[name] = DDP(SyncBatchNorm.convert_sync_batchnorm(module), device_ids=[self.device])
|
| 180 |
-
else:
|
| 181 |
-
for name, module in self.mods.items():
|
| 182 |
-
if any(p.requires_grad for p in module.parameters()): self.mods[name] = DP(module) if self.data_parallel_count == -1 else DP(module, [i for i in range(self.data_parallel_count)])
|
| 183 |
-
|
| 184 |
-
@classmethod
|
| 185 |
-
def from_hparams(cls, source, hparams_file="hyperparams.yaml", overrides={}, download_only=False, overrides_must_match=True, **kwargs):
|
| 186 |
-
with open(fetch(filename=hparams_file, source=source)) as fin:
|
| 187 |
-
hparams = load_hyperpyyaml(fin, overrides, overrides_must_match=overrides_must_match)
|
| 188 |
-
|
| 189 |
-
pretrainer = hparams.get("pretrainer", None)
|
| 190 |
-
|
| 191 |
-
if pretrainer is not None:
|
| 192 |
-
run_on_main(pretrainer.collect_files, kwargs={"default_source": source})
|
| 193 |
-
if not download_only:
|
| 194 |
-
pretrainer.load_collected()
|
| 195 |
-
return cls(hparams["modules"], hparams, **kwargs)
|
| 196 |
-
else: return cls(hparams["modules"], hparams, **kwargs)
|
| 197 |
-
|
| 198 |
-
class EncoderClassifier(Pretrained):
|
| 199 |
-
MODULES_NEEDED = ["compute_features", "mean_var_norm", "embedding_model", "classifier"]
|
| 200 |
-
|
| 201 |
-
def encode_batch(self, wavs, wav_lens=None, normalize=False):
|
| 202 |
-
if len(wavs.shape) == 1: wavs = wavs.unsqueeze(0)
|
| 203 |
-
if wav_lens is None: wav_lens = torch.ones(wavs.shape[0], device=self.device)
|
| 204 |
-
|
| 205 |
-
wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
|
| 206 |
-
wavs = wavs.float()
|
| 207 |
-
|
| 208 |
-
embeddings = self.mods.embedding_model(self.mods.mean_var_norm(self.mods.compute_features(wavs), wav_lens), wav_lens)
|
| 209 |
-
|
| 210 |
-
if normalize: embeddings = self.hparams.mean_var_norm_emb(embeddings, torch.ones(embeddings.shape[0], device=self.device))
|
| 211 |
-
return embeddings
|
| 212 |
-
|
| 213 |
-
def classify_batch(self, wavs, wav_lens=None):
|
| 214 |
-
out_prob = self.mods.classifier(self.encode_batch(wavs, wav_lens)).squeeze(1)
|
| 215 |
-
score, index = torch.max(out_prob, dim=-1)
|
| 216 |
-
|
| 217 |
-
return out_prob, score, index, self.hparams.label_encoder.decode_torch(index)
|
| 218 |
-
|
| 219 |
-
def forward(self, wavs, wav_lens=None):
|
| 220 |
-
return self.classify_batch(wavs, wav_lens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main/library/speaker_diarization/whisper.py
DELETED
|
@@ -1,1290 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import gzip
|
| 4 |
-
import zlib
|
| 5 |
-
import tqdm
|
| 6 |
-
import torch
|
| 7 |
-
import base64
|
| 8 |
-
import string
|
| 9 |
-
import logging
|
| 10 |
-
import tiktoken
|
| 11 |
-
import itertools
|
| 12 |
-
|
| 13 |
-
import numba as nb
|
| 14 |
-
import numpy as np
|
| 15 |
-
import torch.nn as nn
|
| 16 |
-
import torch.nn.functional as F
|
| 17 |
-
|
| 18 |
-
from contextlib import contextmanager
|
| 19 |
-
from torch.distributions import Categorical
|
| 20 |
-
from functools import cached_property, lru_cache
|
| 21 |
-
from dataclasses import dataclass, replace
|
| 22 |
-
from torch.nn.functional import scaled_dot_product_attention
|
| 23 |
-
|
| 24 |
-
sys.path.append(os.getcwd())
|
| 25 |
-
|
| 26 |
-
from main.library.utils import load_audio
|
| 27 |
-
|
| 28 |
-
LANGUAGES = {"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish", "pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese", "he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian", "th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian", "br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili", "gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian", "be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy", "as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese", "yue": "cantonese"}
|
| 29 |
-
TO_LANGUAGE_CODE = {**{language: code for code, language in LANGUAGES.items()}, "burmese": "my", "valencian": "ca", "flemish": "nl", "haitian": "ht", "letzeburgesch": "lb", "pushto": "ps", "panjabi": "pa", "moldavian": "ro", "moldovan": "ro", "sinhalese": "si", "castilian": "es", "mandarin": "zh"}
|
| 30 |
-
_ALIGNMENT_HEADS = {"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO", "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m", "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000", "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj", "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`"}
|
| 31 |
-
|
| 32 |
-
SAMPLE_RATE, N_FFT, HOP_LENGTH, CHUNK_LENGTH = 16000, 400, 160, 30
|
| 33 |
-
|
| 34 |
-
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE
|
| 35 |
-
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2
|
| 36 |
-
|
| 37 |
-
def exact_div(x, y):
|
| 38 |
-
assert x % y == 0
|
| 39 |
-
return x // y
|
| 40 |
-
|
| 41 |
-
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)
|
| 42 |
-
|
| 43 |
-
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)
|
| 44 |
-
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def load_model(name = "base", device = "cpu"):
|
| 48 |
-
checkpoint_file = os.path.join("assets", "models", "speaker_diarization", "models", name + ".pt")
|
| 49 |
-
alignment_heads = _ALIGNMENT_HEADS[name]
|
| 50 |
-
|
| 51 |
-
with open(checkpoint_file, "rb") as fp:
|
| 52 |
-
checkpoint = torch.load(fp, map_location=device)
|
| 53 |
-
|
| 54 |
-
del checkpoint_file
|
| 55 |
-
|
| 56 |
-
model = Whisper(ModelDimensions(**checkpoint["dims"]))
|
| 57 |
-
model.load_state_dict(checkpoint["model_state_dict"])
|
| 58 |
-
model.set_alignment_heads(alignment_heads)
|
| 59 |
-
|
| 60 |
-
return model.to(device)
|
| 61 |
-
|
| 62 |
-
def merge_punctuations(alignment, prepended, appended):
|
| 63 |
-
i = len(alignment) - 2
|
| 64 |
-
j = len(alignment) - 1
|
| 65 |
-
|
| 66 |
-
while i >= 0:
|
| 67 |
-
previous = alignment[i]
|
| 68 |
-
following = alignment[j]
|
| 69 |
-
|
| 70 |
-
if previous.word.startswith(" ") and previous.word.strip() in prepended:
|
| 71 |
-
following.word = previous.word + following.word
|
| 72 |
-
following.tokens = previous.tokens + following.tokens
|
| 73 |
-
|
| 74 |
-
previous.word = ""
|
| 75 |
-
previous.tokens = []
|
| 76 |
-
else: j = i
|
| 77 |
-
|
| 78 |
-
i -= 1
|
| 79 |
-
|
| 80 |
-
i = 0
|
| 81 |
-
j = 1
|
| 82 |
-
|
| 83 |
-
while j < len(alignment):
|
| 84 |
-
previous = alignment[i]
|
| 85 |
-
following = alignment[j]
|
| 86 |
-
|
| 87 |
-
if not previous.word.endswith(" ") and following.word in appended:
|
| 88 |
-
previous.word = previous.word + following.word
|
| 89 |
-
previous.tokens = previous.tokens + following.tokens
|
| 90 |
-
|
| 91 |
-
following.word = ""
|
| 92 |
-
following.tokens = []
|
| 93 |
-
else: i = j
|
| 94 |
-
|
| 95 |
-
j += 1
|
| 96 |
-
|
| 97 |
-
class WordTiming:
|
| 98 |
-
def __init__(self, word, tokens, start, end, probability):
|
| 99 |
-
self.word = word
|
| 100 |
-
self.tokens = tokens
|
| 101 |
-
self.start = start
|
| 102 |
-
self.end = end
|
| 103 |
-
self.probability = probability
|
| 104 |
-
|
| 105 |
-
@contextmanager
|
| 106 |
-
def disable_sdpa():
|
| 107 |
-
prev_state = MultiHeadAttention.use_sdpa
|
| 108 |
-
try:
|
| 109 |
-
MultiHeadAttention.use_sdpa = False
|
| 110 |
-
yield
|
| 111 |
-
finally:
|
| 112 |
-
MultiHeadAttention.use_sdpa = prev_state
|
| 113 |
-
|
| 114 |
-
def median_filter(x, filter_width):
|
| 115 |
-
pad_width = filter_width // 2
|
| 116 |
-
|
| 117 |
-
if x.shape[-1] <= pad_width: return x
|
| 118 |
-
if (ndim := x.ndim) <= 2: x = x[None, None, :]
|
| 119 |
-
|
| 120 |
-
assert (filter_width > 0 and filter_width % 2 == 1)
|
| 121 |
-
|
| 122 |
-
result = None
|
| 123 |
-
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
|
| 124 |
-
|
| 125 |
-
if result is None: result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
|
| 126 |
-
if ndim <= 2: result = result[0, 0]
|
| 127 |
-
|
| 128 |
-
return result
|
| 129 |
-
|
| 130 |
-
@nb.jit(nopython=True)
|
| 131 |
-
def backtrace(trace):
|
| 132 |
-
i = trace.shape[0] - 1
|
| 133 |
-
j = trace.shape[1] - 1
|
| 134 |
-
|
| 135 |
-
trace[0, :] = 2
|
| 136 |
-
trace[:, 0] = 1
|
| 137 |
-
|
| 138 |
-
result = []
|
| 139 |
-
while i > 0 or j > 0:
|
| 140 |
-
result.append((i - 1, j - 1))
|
| 141 |
-
|
| 142 |
-
if trace[i, j] == 0:
|
| 143 |
-
i -= 1
|
| 144 |
-
j -= 1
|
| 145 |
-
elif trace[i, j] == 1: i -= 1
|
| 146 |
-
elif trace[i, j] == 2: j -= 1
|
| 147 |
-
else: raise ValueError
|
| 148 |
-
|
| 149 |
-
return np.array(result)[::-1, :].T
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
@nb.jit(nopython=True, parallel=True)
|
| 153 |
-
def dtw_cpu(x):
|
| 154 |
-
N, M = x.shape
|
| 155 |
-
|
| 156 |
-
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
|
| 157 |
-
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
|
| 158 |
-
cost[0, 0] = 0
|
| 159 |
-
|
| 160 |
-
for j in range(1, M + 1):
|
| 161 |
-
for i in range(1, N + 1):
|
| 162 |
-
c0 = cost[i - 1, j - 1]
|
| 163 |
-
c1 = cost[i - 1, j]
|
| 164 |
-
c2 = cost[i, j - 1]
|
| 165 |
-
|
| 166 |
-
if c0 < c1 and c0 < c2: c, t = c0, 0
|
| 167 |
-
elif c1 < c0 and c1 < c2: c, t = c1, 1
|
| 168 |
-
else: c, t = c2, 2
|
| 169 |
-
|
| 170 |
-
cost[i, j] = x[i - 1, j - 1] + c
|
| 171 |
-
trace[i, j] = t
|
| 172 |
-
|
| 173 |
-
return backtrace(trace)
|
| 174 |
-
|
| 175 |
-
def dtw(x):
|
| 176 |
-
return dtw_cpu(x.double().cpu().numpy())
|
| 177 |
-
|
| 178 |
-
def find_alignment(model, tokenizer, text_tokens, mel, num_frames, *, medfilt_width = 7, qk_scale = 1.0):
|
| 179 |
-
if len(text_tokens) == 0: return []
|
| 180 |
-
|
| 181 |
-
tokens = torch.tensor([*tokenizer.sot_sequence, tokenizer.no_timestamps, *text_tokens, tokenizer.eot]).to(model.device)
|
| 182 |
-
|
| 183 |
-
QKs = [None] * model.dims.n_text_layer
|
| 184 |
-
hooks = [block.cross_attn.register_forward_hook(lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])) for i, block in enumerate(model.decoder.blocks)]
|
| 185 |
-
|
| 186 |
-
with torch.no_grad(), disable_sdpa():
|
| 187 |
-
token_probs = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0][len(tokenizer.sot_sequence) :, : tokenizer.eot].softmax(dim=-1)
|
| 188 |
-
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()
|
| 189 |
-
|
| 190 |
-
for hook in hooks:
|
| 191 |
-
hook.remove()
|
| 192 |
-
|
| 193 |
-
weights = (torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])[:, :, : num_frames // 2] * qk_scale).softmax(dim=-1)
|
| 194 |
-
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
| 195 |
-
weights = median_filter((weights - mean) / std, medfilt_width)
|
| 196 |
-
|
| 197 |
-
text_indices, time_indices = dtw(-weights.mean(axis=0)[len(tokenizer.sot_sequence) : -1])
|
| 198 |
-
|
| 199 |
-
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
| 200 |
-
if len(word_tokens) <= 1: return []
|
| 201 |
-
|
| 202 |
-
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
| 203 |
-
jump_times = time_indices[np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)] / TOKENS_PER_SECOND
|
| 204 |
-
|
| 205 |
-
return [WordTiming(word, tokens, start, end, probability) for word, tokens, start, end, probability in zip(words, word_tokens, jump_times[word_boundaries[:-1]], jump_times[word_boundaries[1:]], [np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])])]
|
| 206 |
-
|
| 207 |
-
def add_word_timestamps(*, segments, model, tokenizer, mel, num_frames, prepend_punctuations = "\"'“¿([{-", append_punctuations = "\"'.。,,!!??::”)]}、", last_speech_timestamp, **kwargs):
|
| 208 |
-
if len(segments) == 0: return
|
| 209 |
-
|
| 210 |
-
text_tokens_per_segment = [[token for token in segment["tokens"] if token < tokenizer.eot] for segment in segments]
|
| 211 |
-
|
| 212 |
-
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
|
| 213 |
-
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
|
| 214 |
-
|
| 215 |
-
word_durations = np.array([t.end - t.start for t in alignment])
|
| 216 |
-
word_durations = word_durations[word_durations.nonzero()]
|
| 217 |
-
|
| 218 |
-
median_duration = min(0.7, float(np.median(word_durations) if len(word_durations) > 0 else 0.0))
|
| 219 |
-
max_duration = median_duration * 2
|
| 220 |
-
|
| 221 |
-
if len(word_durations) > 0:
|
| 222 |
-
sentence_end_marks = ".。!!??"
|
| 223 |
-
for i in range(1, len(alignment)):
|
| 224 |
-
if alignment[i].end - alignment[i].start > max_duration:
|
| 225 |
-
if alignment[i].word in sentence_end_marks: alignment[i].end = alignment[i].start + max_duration
|
| 226 |
-
elif alignment[i - 1].word in sentence_end_marks: alignment[i].start = alignment[i].end - max_duration
|
| 227 |
-
|
| 228 |
-
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
| 229 |
-
|
| 230 |
-
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
|
| 231 |
-
word_index = 0
|
| 232 |
-
|
| 233 |
-
for segment, text_tokens in zip(segments, text_tokens_per_segment):
|
| 234 |
-
saved_tokens = 0
|
| 235 |
-
words = []
|
| 236 |
-
|
| 237 |
-
while word_index < len(alignment) and saved_tokens < len(text_tokens):
|
| 238 |
-
timing = alignment[word_index]
|
| 239 |
-
|
| 240 |
-
if timing.word: words.append(dict(word=timing.word, start=round(time_offset + timing.start, 2), end=round(time_offset + timing.end, 2), probability=timing.probability))
|
| 241 |
-
|
| 242 |
-
saved_tokens += len(timing.tokens)
|
| 243 |
-
word_index += 1
|
| 244 |
-
|
| 245 |
-
if len(words) > 0:
|
| 246 |
-
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (words[0]["end"] - words[0]["start"] > max_duration or (len(words) > 1 and words[1]["end"] - words[0]["start"] > max_duration * 2)):
|
| 247 |
-
if (len(words) > 1 and words[1]["end"] - words[1]["start"] > max_duration): words[0]["end"] = words[1]["start"] = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
|
| 248 |
-
words[0]["start"] = max(0, words[0]["end"] - max_duration)
|
| 249 |
-
|
| 250 |
-
if (segment["start"] < words[0]["end"] and segment["start"] - 0.5 > words[0]["start"]): words[0]["start"] = max(0, min(words[0]["end"] - median_duration, segment["start"]))
|
| 251 |
-
else: segment["start"] = words[0]["start"]
|
| 252 |
-
|
| 253 |
-
if (segment["end"] > words[-1]["start"] and segment["end"] + 0.5 < words[-1]["end"]): words[-1]["end"] = max(words[-1]["start"] + median_duration, segment["end"])
|
| 254 |
-
else: segment["end"] = words[-1]["end"]
|
| 255 |
-
|
| 256 |
-
last_speech_timestamp = segment["end"]
|
| 257 |
-
|
| 258 |
-
segment["words"] = words
|
| 259 |
-
|
| 260 |
-
@lru_cache(maxsize=None)
|
| 261 |
-
def mel_filters(device, n_mels):
|
| 262 |
-
assert n_mels in {80, 128}
|
| 263 |
-
|
| 264 |
-
with np.load(os.path.join("assets", "models", "speaker_diarization", "assets", "mel_filters.npz"), allow_pickle=False) as f:
|
| 265 |
-
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
| 266 |
-
|
| 267 |
-
def log_mel_spectrogram(audio, n_mels = 80, padding = 0, device = None):
|
| 268 |
-
if not torch.is_tensor(audio):
|
| 269 |
-
if isinstance(audio, str): audio = load_audio(logging.getLogger(__name__), audio, sample_rate=SAMPLE_RATE).astype(np.float32)
|
| 270 |
-
audio = torch.from_numpy(audio)
|
| 271 |
-
|
| 272 |
-
if device is not None: audio = audio.to(device)
|
| 273 |
-
if padding > 0: audio = F.pad(audio, (0, padding))
|
| 274 |
-
|
| 275 |
-
log_spec = torch.clamp(mel_filters(audio.device, n_mels) @ torch.stft(audio, N_FFT, HOP_LENGTH, window=torch.hann_window(N_FFT).to(audio.device), return_complex=True)[..., :-1].abs() ** 2, min=1e-10).log10()
|
| 276 |
-
return (torch.maximum(log_spec, log_spec.max() - 8.0) + 4.0) / 4.0
|
| 277 |
-
|
| 278 |
-
def pad_or_trim(array, length = N_SAMPLES, *, axis = -1):
|
| 279 |
-
if torch.is_tensor(array):
|
| 280 |
-
if array.shape[axis] > length: array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
|
| 281 |
-
|
| 282 |
-
if array.shape[axis] < length:
|
| 283 |
-
pad_widths = [(0, 0)] * array.ndim
|
| 284 |
-
pad_widths[axis] = (0, length - array.shape[axis])
|
| 285 |
-
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
|
| 286 |
-
else:
|
| 287 |
-
if array.shape[axis] > length: array = array.take(indices=range(length), axis=axis)
|
| 288 |
-
|
| 289 |
-
if array.shape[axis] < length:
|
| 290 |
-
pad_widths = [(0, 0)] * array.ndim
|
| 291 |
-
pad_widths[axis] = (0, length - array.shape[axis])
|
| 292 |
-
array = np.pad(array, pad_widths)
|
| 293 |
-
|
| 294 |
-
return array
|
| 295 |
-
|
| 296 |
-
def get_end(segments):
|
| 297 |
-
return next((w["end"] for s in reversed(segments) for w in reversed(s["words"])), segments[-1]["end"] if segments else None)
|
| 298 |
-
|
| 299 |
-
def transcribe_function(model, audio, *, verbose = None, temperature = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), compression_ratio_threshold = 2.4, logprob_threshold = -1.0, no_speech_threshold = 0.6, condition_on_previous_text = True, initial_prompt = None, carry_initial_prompt = False, word_timestamps = False, prepend_punctuations = "\"'“¿([{-", append_punctuations = "\"'.。,,!!??::”)]}、", clip_timestamps = "0", hallucination_silence_threshold = None, fp16 = False, **decode_options):
|
| 300 |
-
dtype = torch.float32
|
| 301 |
-
decode_options["fp16"] = fp16
|
| 302 |
-
|
| 303 |
-
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
| 304 |
-
content_frames = mel.shape[-1] - N_FRAMES
|
| 305 |
-
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
|
| 306 |
-
|
| 307 |
-
if decode_options.get("language", None) is None:
|
| 308 |
-
if not model.is_multilingual: decode_options["language"] = "vi"
|
| 309 |
-
else:
|
| 310 |
-
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
| 311 |
-
_, probs = model.detect_language(mel_segment)
|
| 312 |
-
decode_options["language"] = max(probs, key=probs.get)
|
| 313 |
-
|
| 314 |
-
if verbose is not None: print(f"{LANGUAGES[decode_options['language']].title()}")
|
| 315 |
-
|
| 316 |
-
language = decode_options["language"]
|
| 317 |
-
task = decode_options.get("task", "transcribe")
|
| 318 |
-
tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages, language=language, task=task)
|
| 319 |
-
|
| 320 |
-
if isinstance(clip_timestamps, str): clip_timestamps = [float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])]
|
| 321 |
-
seek_points = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
|
| 322 |
-
|
| 323 |
-
if len(seek_points) == 0: seek_points.append(0)
|
| 324 |
-
if len(seek_points) % 2 == 1: seek_points.append(content_frames)
|
| 325 |
-
|
| 326 |
-
seek_clips = list(zip(seek_points[::2], seek_points[1::2]))
|
| 327 |
-
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
| 328 |
-
|
| 329 |
-
def decode_with_fallback(segment):
|
| 330 |
-
temperatures = ([temperature] if isinstance(temperature, (int, float)) else temperature)
|
| 331 |
-
decode_result = None
|
| 332 |
-
|
| 333 |
-
for t in temperatures:
|
| 334 |
-
kwargs = {**decode_options}
|
| 335 |
-
|
| 336 |
-
if t > 0:
|
| 337 |
-
kwargs.pop("beam_size", None)
|
| 338 |
-
kwargs.pop("patience", None)
|
| 339 |
-
else: kwargs.pop("best_of", None)
|
| 340 |
-
|
| 341 |
-
decode_result = model.decode(segment, DecodingOptions(**kwargs, temperature=t))
|
| 342 |
-
needs_fallback = False
|
| 343 |
-
|
| 344 |
-
if (compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold): needs_fallback = True
|
| 345 |
-
if (logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold): needs_fallback = True
|
| 346 |
-
if (no_speech_threshold is not None and decode_result.no_speech_prob > no_speech_threshold and logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold): needs_fallback = False
|
| 347 |
-
if not needs_fallback: break
|
| 348 |
-
|
| 349 |
-
return decode_result
|
| 350 |
-
|
| 351 |
-
clip_idx = 0
|
| 352 |
-
seek = seek_clips[clip_idx][0]
|
| 353 |
-
|
| 354 |
-
input_stride = exact_div(N_FRAMES, model.dims.n_audio_ctx)
|
| 355 |
-
time_precision = (input_stride * HOP_LENGTH / SAMPLE_RATE)
|
| 356 |
-
|
| 357 |
-
all_tokens, all_segments = [], []
|
| 358 |
-
prompt_reset_since = 0
|
| 359 |
-
|
| 360 |
-
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
|
| 361 |
-
|
| 362 |
-
if initial_prompt is not None:
|
| 363 |
-
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
| 364 |
-
all_tokens.extend(initial_prompt_tokens)
|
| 365 |
-
remaining_prompt_length -= len(initial_prompt_tokens)
|
| 366 |
-
else: initial_prompt_tokens = []
|
| 367 |
-
|
| 368 |
-
def new_segment(*, start, end, tokens, result):
|
| 369 |
-
tokens = tokens.tolist()
|
| 370 |
-
return {"seek": seek, "start": start, "end": end, "text": tokenizer.decode([token for token in tokens if token < tokenizer.eot]), "tokens": tokens, "temperature": result.temperature, "avg_logprob": result.avg_logprob, "compression_ratio": result.compression_ratio, "no_speech_prob": result.no_speech_prob}
|
| 371 |
-
|
| 372 |
-
with tqdm.tqdm(total=content_frames, unit="frames", disable=verbose is not False) as pbar:
|
| 373 |
-
last_speech_timestamp = 0.0
|
| 374 |
-
while clip_idx < len(seek_clips):
|
| 375 |
-
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
| 376 |
-
if seek < seek_clip_start: seek = seek_clip_start
|
| 377 |
-
|
| 378 |
-
if seek >= seek_clip_end:
|
| 379 |
-
clip_idx += 1
|
| 380 |
-
if clip_idx < len(seek_clips): seek = seek_clips[clip_idx][0]
|
| 381 |
-
continue
|
| 382 |
-
|
| 383 |
-
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
| 384 |
-
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
|
| 385 |
-
|
| 386 |
-
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
|
| 387 |
-
mel_segment = mel[:, seek : seek + segment_size]
|
| 388 |
-
|
| 389 |
-
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
| 390 |
-
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
| 391 |
-
|
| 392 |
-
if carry_initial_prompt: decode_options["prompt"] = initial_prompt_tokens + all_tokens[max(len(initial_prompt_tokens), prompt_reset_since):][-remaining_prompt_length:]
|
| 393 |
-
else: decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
| 394 |
-
|
| 395 |
-
result = decode_with_fallback(mel_segment)
|
| 396 |
-
tokens = torch.tensor(result.tokens)
|
| 397 |
-
|
| 398 |
-
if no_speech_threshold is not None:
|
| 399 |
-
should_skip = result.no_speech_prob > no_speech_threshold
|
| 400 |
-
if (logprob_threshold is not None and result.avg_logprob > logprob_threshold):
|
| 401 |
-
should_skip = False
|
| 402 |
-
|
| 403 |
-
if should_skip:
|
| 404 |
-
seek += segment_size
|
| 405 |
-
continue
|
| 406 |
-
|
| 407 |
-
previous_seek = seek
|
| 408 |
-
current_segments = []
|
| 409 |
-
|
| 410 |
-
def word_anomaly_score(word):
|
| 411 |
-
probability = word.get("probability", 0.0)
|
| 412 |
-
duration = word["end"] - word["start"]
|
| 413 |
-
score = 0.0
|
| 414 |
-
|
| 415 |
-
if probability < 0.15: score += 1.0
|
| 416 |
-
if duration < 0.133: score += (0.133 - duration) * 15
|
| 417 |
-
if duration > 2.0: score += duration - 2.0
|
| 418 |
-
|
| 419 |
-
return score
|
| 420 |
-
|
| 421 |
-
def is_segment_anomaly(segment):
|
| 422 |
-
if segment is None or not segment["words"]: return False
|
| 423 |
-
|
| 424 |
-
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
| 425 |
-
words = words[:8]
|
| 426 |
-
|
| 427 |
-
score = sum(word_anomaly_score(w) for w in words)
|
| 428 |
-
|
| 429 |
-
return score >= 3 or score + 0.01 >= len(words)
|
| 430 |
-
|
| 431 |
-
def next_words_segment(segments):
|
| 432 |
-
return next((s for s in segments if s["words"]), None)
|
| 433 |
-
|
| 434 |
-
timestamp_tokens = tokens.ge(tokenizer.timestamp_begin)
|
| 435 |
-
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
| 436 |
-
|
| 437 |
-
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
| 438 |
-
consecutive.add_(1)
|
| 439 |
-
|
| 440 |
-
if len(consecutive) > 0:
|
| 441 |
-
slices = consecutive.tolist()
|
| 442 |
-
if single_timestamp_ending:
|
| 443 |
-
slices.append(len(tokens))
|
| 444 |
-
|
| 445 |
-
last_slice = 0
|
| 446 |
-
for current_slice in slices:
|
| 447 |
-
sliced_tokens = tokens[last_slice:current_slice]
|
| 448 |
-
current_segments.append(new_segment(start=time_offset + (sliced_tokens[0].item() - tokenizer.timestamp_begin) * time_precision, end=time_offset + (sliced_tokens[-1].item() - tokenizer.timestamp_begin) * time_precision, tokens=sliced_tokens, result=result))
|
| 449 |
-
last_slice = current_slice
|
| 450 |
-
|
| 451 |
-
if single_timestamp_ending: seek += segment_size
|
| 452 |
-
else: seek += (tokens[last_slice - 1].item() - tokenizer.timestamp_begin) * input_stride
|
| 453 |
-
else:
|
| 454 |
-
duration = segment_duration
|
| 455 |
-
|
| 456 |
-
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
| 457 |
-
if (len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin): duration = (timestamps[-1].item() - tokenizer.timestamp_begin) * time_precision
|
| 458 |
-
|
| 459 |
-
current_segments.append(new_segment(start=time_offset, end=time_offset + duration, tokens=tokens, result=result))
|
| 460 |
-
seek += segment_size
|
| 461 |
-
|
| 462 |
-
if word_timestamps:
|
| 463 |
-
add_word_timestamps(segments=current_segments, model=model, tokenizer=tokenizer, mel=mel_segment, num_frames=segment_size, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, last_speech_timestamp=last_speech_timestamp)
|
| 464 |
-
|
| 465 |
-
if not single_timestamp_ending:
|
| 466 |
-
last_word_end = get_end(current_segments)
|
| 467 |
-
if last_word_end is not None and last_word_end > time_offset: seek = round(last_word_end * FRAMES_PER_SECOND)
|
| 468 |
-
|
| 469 |
-
if hallucination_silence_threshold is not None:
|
| 470 |
-
threshold = hallucination_silence_threshold
|
| 471 |
-
|
| 472 |
-
if not single_timestamp_ending:
|
| 473 |
-
last_word_end = get_end(current_segments)
|
| 474 |
-
if last_word_end is not None and last_word_end > time_offset: seek = round(last_word_end * FRAMES_PER_SECOND) if (window_end_time - last_word_end) > threshold else (previous_seek + segment_size)
|
| 475 |
-
|
| 476 |
-
first_segment = next_words_segment(current_segments)
|
| 477 |
-
|
| 478 |
-
if first_segment is not None and is_segment_anomaly(first_segment):
|
| 479 |
-
gap = first_segment["start"] - time_offset
|
| 480 |
-
|
| 481 |
-
if gap > threshold:
|
| 482 |
-
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
|
| 483 |
-
continue
|
| 484 |
-
|
| 485 |
-
hal_last_end = last_speech_timestamp
|
| 486 |
-
|
| 487 |
-
for si in range(len(current_segments)):
|
| 488 |
-
segment = current_segments[si]
|
| 489 |
-
if not segment["words"]: continue
|
| 490 |
-
|
| 491 |
-
if is_segment_anomaly(segment):
|
| 492 |
-
next_segment = next_words_segment(current_segments[si + 1 :])
|
| 493 |
-
hal_next_start = next_segment["words"][0]["start"] if next_segment is not None else (time_offset + segment_duration)
|
| 494 |
-
|
| 495 |
-
if (segment["start"] - hal_last_end > threshold or segment["start"] < threshold or segment["start"] - time_offset < 2.0) and (hal_next_start - segment["end"] > threshold or is_segment_anomaly(next_segment) or window_end_time - segment["end"] < 2.0):
|
| 496 |
-
seek = round(max(time_offset + 1, segment["start"]) * FRAMES_PER_SECOND)
|
| 497 |
-
if content_duration - segment["end"] < threshold: seek = content_frames
|
| 498 |
-
|
| 499 |
-
current_segments[si:] = []
|
| 500 |
-
break
|
| 501 |
-
|
| 502 |
-
hal_last_end = segment["end"]
|
| 503 |
-
|
| 504 |
-
last_word_end = get_end(current_segments)
|
| 505 |
-
if last_word_end is not None: last_speech_timestamp = last_word_end
|
| 506 |
-
|
| 507 |
-
for _, segment in enumerate(current_segments):
|
| 508 |
-
if segment["start"] == segment["end"] or segment["text"].strip() == "":
|
| 509 |
-
segment["text"] = ""
|
| 510 |
-
segment["tokens"] = []
|
| 511 |
-
segment["words"] = []
|
| 512 |
-
|
| 513 |
-
all_segments.extend([{"id": i, **segment} for i, segment in enumerate(current_segments, start=len(all_segments))])
|
| 514 |
-
all_tokens.extend([token for segment in current_segments for token in segment["tokens"]])
|
| 515 |
-
|
| 516 |
-
if not condition_on_previous_text or result.temperature > 0.5: prompt_reset_since = len(all_tokens)
|
| 517 |
-
pbar.update(min(content_frames, seek) - previous_seek)
|
| 518 |
-
|
| 519 |
-
return dict(text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), segments=all_segments, language=language)
|
| 520 |
-
|
| 521 |
-
def compression_ratio(text):
|
| 522 |
-
text_bytes = text.encode("utf-8")
|
| 523 |
-
return len(text_bytes) / len(zlib.compress(text_bytes))
|
| 524 |
-
|
| 525 |
-
def sinusoids(length, channels, max_timescale=10000):
|
| 526 |
-
assert channels % 2 == 0
|
| 527 |
-
|
| 528 |
-
scaled_time = torch.arange(length)[:, np.newaxis] * torch.exp(-(np.log(max_timescale) / (channels // 2 - 1)) * torch.arange(channels // 2))[np.newaxis, :]
|
| 529 |
-
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
| 530 |
-
|
| 531 |
-
@torch.no_grad()
|
| 532 |
-
def detect_language_function(model, mel, tokenizer = None):
|
| 533 |
-
if tokenizer is None: tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
|
| 534 |
-
if (tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence): raise ValueError
|
| 535 |
-
|
| 536 |
-
single = mel.ndim == 2
|
| 537 |
-
|
| 538 |
-
if single: mel = mel.unsqueeze(0)
|
| 539 |
-
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): mel = model.encoder(mel)
|
| 540 |
-
|
| 541 |
-
n_audio = mel.shape[0]
|
| 542 |
-
logits = model.logits(torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device), mel)[:, 0]
|
| 543 |
-
|
| 544 |
-
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
| 545 |
-
mask[list(tokenizer.all_language_tokens)] = False
|
| 546 |
-
|
| 547 |
-
logits[:, mask] = -np.inf
|
| 548 |
-
|
| 549 |
-
language_tokens = logits.argmax(dim=-1)
|
| 550 |
-
language_probs = [{c: logits.softmax(dim=-1).cpu()[i, j].item() for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)} for i in range(n_audio)]
|
| 551 |
-
|
| 552 |
-
if single:
|
| 553 |
-
language_tokens = language_tokens[0]
|
| 554 |
-
language_probs = language_probs[0]
|
| 555 |
-
|
| 556 |
-
return language_tokens, language_probs
|
| 557 |
-
|
| 558 |
-
@lru_cache(maxsize=None)
|
| 559 |
-
def get_tokenizer(multilingual, *, num_languages = 99, language = None, task = None):
|
| 560 |
-
if language is not None:
|
| 561 |
-
language = language.lower()
|
| 562 |
-
if language not in LANGUAGES:
|
| 563 |
-
if language in TO_LANGUAGE_CODE: language = TO_LANGUAGE_CODE[language]
|
| 564 |
-
else: raise ValueError
|
| 565 |
-
|
| 566 |
-
if multilingual:
|
| 567 |
-
encoding_name = "multilingual"
|
| 568 |
-
language = language or "en"
|
| 569 |
-
task = task or "transcribe"
|
| 570 |
-
else:
|
| 571 |
-
encoding_name = "gpt2"
|
| 572 |
-
language = None
|
| 573 |
-
task = None
|
| 574 |
-
|
| 575 |
-
return Tokenizer(encoding_name=encoding_name, num_languages=num_languages, language=language, task=task)
|
| 576 |
-
|
| 577 |
-
@lru_cache(maxsize=None)
|
| 578 |
-
def get_encoding(name = "gpt2", num_languages = 99):
|
| 579 |
-
vocab_path = os.path.join("assets", "models", "speaker_diarization", "assets", f"{name}.tiktoken")
|
| 580 |
-
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in open(vocab_path) if line)}
|
| 581 |
-
|
| 582 |
-
n_vocab = len(ranks)
|
| 583 |
-
special_tokens = {}
|
| 584 |
-
|
| 585 |
-
specials = ["<|endoftext|>", "<|startoftranscript|>", *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], "<|translate|>", "<|transcribe|>", "<|startoflm|>", "<|startofprev|>", "<|nospeech|>", "<|notimestamps|>", *[f"<|{i * 0.02:.2f}|>" for i in range(1501)]]
|
| 586 |
-
|
| 587 |
-
for token in specials:
|
| 588 |
-
special_tokens[token] = n_vocab
|
| 589 |
-
n_vocab += 1
|
| 590 |
-
|
| 591 |
-
return tiktoken.Encoding(name=os.path.basename(vocab_path), explicit_n_vocab=n_vocab, pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", mergeable_ranks=ranks, special_tokens=special_tokens)
|
| 592 |
-
|
| 593 |
-
class DecodingOptions:
|
| 594 |
-
def __init__(self, task = "transcribe", language = None, temperature = 0.0, sample_len = None, best_of = None, beam_size = None, patience = None, length_penalty = None, prompt = None, prefix = None, suppress_tokens = "-1", suppress_blank = True, without_timestamps = False, max_initial_timestamp = 1.0, fp16 = False):
|
| 595 |
-
self.task = task
|
| 596 |
-
self.language = language
|
| 597 |
-
self.temperature = temperature
|
| 598 |
-
self.sample_len = sample_len
|
| 599 |
-
self.best_of = best_of
|
| 600 |
-
self.beam_size = beam_size
|
| 601 |
-
self.patience = patience
|
| 602 |
-
self.length_penalty = length_penalty
|
| 603 |
-
self.prompt = prompt
|
| 604 |
-
self.prefix = prefix
|
| 605 |
-
self.suppress_tokens = suppress_tokens
|
| 606 |
-
self.suppress_blank = suppress_blank
|
| 607 |
-
self.without_timestamps = without_timestamps
|
| 608 |
-
self.max_initial_timestamp = max_initial_timestamp
|
| 609 |
-
self.fp16 = fp16
|
| 610 |
-
|
| 611 |
-
@torch.no_grad()
|
| 612 |
-
def decode_function(model, mel, options = DecodingOptions(), **kwargs):
|
| 613 |
-
if single := mel.ndim == 2: mel = mel.unsqueeze(0)
|
| 614 |
-
if kwargs: options = replace(options, **kwargs)
|
| 615 |
-
|
| 616 |
-
result = DecodingTask(model, options).run(mel)
|
| 617 |
-
return result[0] if single else result
|
| 618 |
-
|
| 619 |
-
@dataclass
|
| 620 |
-
class ModelDimensions:
|
| 621 |
-
n_mels: int
|
| 622 |
-
n_audio_ctx: int
|
| 623 |
-
n_audio_state: int
|
| 624 |
-
n_audio_head: int
|
| 625 |
-
n_audio_layer: int
|
| 626 |
-
n_vocab: int
|
| 627 |
-
n_text_ctx: int
|
| 628 |
-
n_text_state: int
|
| 629 |
-
n_text_head: int
|
| 630 |
-
n_text_layer: int
|
| 631 |
-
|
| 632 |
-
class LayerNorm(nn.LayerNorm):
|
| 633 |
-
def forward(self, x):
|
| 634 |
-
return super().forward(x.float()).type(x.dtype)
|
| 635 |
-
|
| 636 |
-
class Linear(nn.Linear):
|
| 637 |
-
def forward(self, x):
|
| 638 |
-
return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))
|
| 639 |
-
|
| 640 |
-
class Conv1d(nn.Conv1d):
|
| 641 |
-
def _conv_forward(self, x, weight, bias):
|
| 642 |
-
return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
|
| 643 |
-
|
| 644 |
-
class TextDecoder(nn.Module):
|
| 645 |
-
def __init__(self, n_vocab, n_ctx, n_state, n_head, n_layer):
|
| 646 |
-
super().__init__()
|
| 647 |
-
|
| 648 |
-
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
| 649 |
-
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
| 650 |
-
|
| 651 |
-
self.blocks = nn.ModuleList([ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)])
|
| 652 |
-
self.ln = LayerNorm(n_state)
|
| 653 |
-
self.register_buffer("mask", torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1), persistent=False)
|
| 654 |
-
|
| 655 |
-
def forward(self, x, xa, kv_cache = None):
|
| 656 |
-
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
| 657 |
-
x = (self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]).to(xa.dtype)
|
| 658 |
-
|
| 659 |
-
for block in self.blocks:
|
| 660 |
-
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
| 661 |
-
|
| 662 |
-
x = self.ln(x)
|
| 663 |
-
return (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
|
| 664 |
-
|
| 665 |
-
class AudioEncoder(nn.Module):
|
| 666 |
-
def __init__(self, n_mels, n_ctx, n_state, n_head, n_layer):
|
| 667 |
-
super().__init__()
|
| 668 |
-
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
| 669 |
-
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
| 670 |
-
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
| 671 |
-
|
| 672 |
-
self.blocks = nn.ModuleList([ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])
|
| 673 |
-
self.ln_post = LayerNorm(n_state)
|
| 674 |
-
|
| 675 |
-
def forward(self, x):
|
| 676 |
-
x = F.gelu(self.conv2(F.gelu(self.conv1(x)))).permute(0, 2, 1)
|
| 677 |
-
|
| 678 |
-
assert x.shape[1:] == self.positional_embedding.shape
|
| 679 |
-
x = (x + self.positional_embedding).to(x.dtype)
|
| 680 |
-
|
| 681 |
-
for block in self.blocks:
|
| 682 |
-
x = block(x)
|
| 683 |
-
|
| 684 |
-
return self.ln_post(x)
|
| 685 |
-
|
| 686 |
-
class Whisper(nn.Module):
|
| 687 |
-
def __init__(self, dims):
|
| 688 |
-
super().__init__()
|
| 689 |
-
self.dims = dims
|
| 690 |
-
self.encoder = AudioEncoder(self.dims.n_mels, self.dims.n_audio_ctx, self.dims.n_audio_state, self.dims.n_audio_head, self.dims.n_audio_layer)
|
| 691 |
-
self.decoder = TextDecoder(self.dims.n_vocab, self.dims.n_text_ctx, self.dims.n_text_state, self.dims.n_text_head, self.dims.n_text_layer)
|
| 692 |
-
|
| 693 |
-
all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
|
| 694 |
-
all_heads[self.dims.n_text_layer // 2 :] = True
|
| 695 |
-
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
|
| 696 |
-
|
| 697 |
-
def set_alignment_heads(self, dump):
|
| 698 |
-
self.register_buffer("alignment_heads", torch.from_numpy(np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()).reshape(self.dims.n_text_layer, self.dims.n_text_head).to_sparse(), persistent=False)
|
| 699 |
-
|
| 700 |
-
def embed_audio(self, mel):
|
| 701 |
-
return self.encoder(mel)
|
| 702 |
-
|
| 703 |
-
def logits(self, tokens, audio_features):
|
| 704 |
-
return self.decoder(tokens, audio_features)
|
| 705 |
-
|
| 706 |
-
def forward(self, mel, tokens):
|
| 707 |
-
return self.decoder(tokens, self.encoder(mel))
|
| 708 |
-
|
| 709 |
-
@property
|
| 710 |
-
def device(self):
|
| 711 |
-
return next(self.parameters()).device
|
| 712 |
-
|
| 713 |
-
@property
|
| 714 |
-
def is_multilingual(self):
|
| 715 |
-
return self.dims.n_vocab >= 51865
|
| 716 |
-
|
| 717 |
-
@property
|
| 718 |
-
def num_languages(self):
|
| 719 |
-
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
| 720 |
-
|
| 721 |
-
def install_kv_cache_hooks(self, cache = None):
|
| 722 |
-
cache = {**cache} if cache is not None else {}
|
| 723 |
-
hooks = []
|
| 724 |
-
|
| 725 |
-
def save_to_cache(module, _, output):
|
| 726 |
-
cache[module] = output if module not in cache or output.shape[1] > self.dims.n_text_ctx else torch.cat([cache[module], output], dim=1).detach()
|
| 727 |
-
return cache[module]
|
| 728 |
-
|
| 729 |
-
def install_hooks(layer: nn.Module):
|
| 730 |
-
if isinstance(layer, MultiHeadAttention):
|
| 731 |
-
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
| 732 |
-
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
| 733 |
-
|
| 734 |
-
self.decoder.apply(install_hooks)
|
| 735 |
-
return cache, hooks
|
| 736 |
-
|
| 737 |
-
detect_language = detect_language_function
|
| 738 |
-
transcribe = transcribe_function
|
| 739 |
-
decode = decode_function
|
| 740 |
-
|
| 741 |
-
class ResidualAttentionBlock(nn.Module):
|
| 742 |
-
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
| 743 |
-
super().__init__()
|
| 744 |
-
|
| 745 |
-
self.attn = MultiHeadAttention(n_state, n_head)
|
| 746 |
-
self.attn_ln = LayerNorm(n_state)
|
| 747 |
-
|
| 748 |
-
self.cross_attn = (MultiHeadAttention(n_state, n_head) if cross_attention else None)
|
| 749 |
-
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
| 750 |
-
|
| 751 |
-
n_mlp = n_state * 4
|
| 752 |
-
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
|
| 753 |
-
self.mlp_ln = LayerNorm(n_state)
|
| 754 |
-
|
| 755 |
-
def forward(self, x, xa = None, mask = None, kv_cache = None):
|
| 756 |
-
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
| 757 |
-
if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
| 758 |
-
|
| 759 |
-
return x + self.mlp(self.mlp_ln(x))
|
| 760 |
-
|
| 761 |
-
class MultiHeadAttention(nn.Module):
|
| 762 |
-
def __init__(self, n_state, n_head):
|
| 763 |
-
super().__init__()
|
| 764 |
-
self.n_head = n_head
|
| 765 |
-
self.query = Linear(n_state, n_state)
|
| 766 |
-
self.key = Linear(n_state, n_state, bias=False)
|
| 767 |
-
self.value = Linear(n_state, n_state)
|
| 768 |
-
self.out = Linear(n_state, n_state)
|
| 769 |
-
|
| 770 |
-
def forward(self, x, xa = None, mask = None, kv_cache = None):
|
| 771 |
-
k, v = (self.key(x if xa is None else xa), self.value(x if xa is None else xa)) if kv_cache is None or xa is None or self.key not in kv_cache else (kv_cache[self.key], kv_cache[self.value])
|
| 772 |
-
wv, qk = self.qkv_attention(self.query(x), k, v, mask)
|
| 773 |
-
|
| 774 |
-
return self.out(wv), qk
|
| 775 |
-
|
| 776 |
-
def qkv_attention(self, q, k, v, mask = None):
|
| 777 |
-
_, n_ctx, _ = q.shape
|
| 778 |
-
|
| 779 |
-
q, k, v = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3), k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3), v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
| 780 |
-
return scaled_dot_product_attention(q, k, v, is_causal=mask is not None and n_ctx > 1).permute(0, 2, 1, 3).flatten(start_dim=2), None
|
| 781 |
-
|
| 782 |
-
class LogitFilter:
|
| 783 |
-
def apply(self, logits, tokens):
|
| 784 |
-
pass
|
| 785 |
-
|
| 786 |
-
class SuppressBlank(LogitFilter):
|
| 787 |
-
def __init__(self, tokenizer, sample_begin):
|
| 788 |
-
self.tokenizer = tokenizer
|
| 789 |
-
self.sample_begin = sample_begin
|
| 790 |
-
|
| 791 |
-
def apply(self, logits, tokens):
|
| 792 |
-
if tokens.shape[1] == self.sample_begin: logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
| 793 |
-
|
| 794 |
-
class SuppressTokens(LogitFilter):
|
| 795 |
-
def __init__(self, suppress_tokens):
|
| 796 |
-
self.suppress_tokens = list(suppress_tokens)
|
| 797 |
-
|
| 798 |
-
def apply(self, logits, tokens):
|
| 799 |
-
logits[:, self.suppress_tokens] = -np.inf
|
| 800 |
-
|
| 801 |
-
class Inference:
|
| 802 |
-
def logits(self, tokens, audio_features):
|
| 803 |
-
pass
|
| 804 |
-
|
| 805 |
-
def rearrange_kv_cache(self, source_indices):
|
| 806 |
-
pass
|
| 807 |
-
|
| 808 |
-
def cleanup_caching(self):
|
| 809 |
-
pass
|
| 810 |
-
|
| 811 |
-
class PyTorchInference(Inference):
|
| 812 |
-
def __init__(self, model, initial_token_length):
|
| 813 |
-
self.model = model
|
| 814 |
-
self.initial_token_length = initial_token_length
|
| 815 |
-
self.kv_cache = {}
|
| 816 |
-
self.hooks = []
|
| 817 |
-
|
| 818 |
-
self.kv_modules = [block.attn.key for block in self.model.decoder.blocks] + [block.attn.value for block in self.model.decoder.blocks]
|
| 819 |
-
|
| 820 |
-
def logits(self, tokens, audio_features):
|
| 821 |
-
if not self.kv_cache: self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
| 822 |
-
if tokens.shape[-1] > self.initial_token_length: tokens = tokens[:, -1:]
|
| 823 |
-
|
| 824 |
-
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
| 825 |
-
|
| 826 |
-
def cleanup_caching(self):
|
| 827 |
-
for hook in self.hooks:
|
| 828 |
-
hook.remove()
|
| 829 |
-
|
| 830 |
-
self.kv_cache = {}
|
| 831 |
-
self.hooks = []
|
| 832 |
-
|
| 833 |
-
def rearrange_kv_cache(self, source_indices):
|
| 834 |
-
if source_indices != list(range(len(source_indices))):
|
| 835 |
-
for module in self.kv_modules:
|
| 836 |
-
self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
|
| 837 |
-
|
| 838 |
-
class SequenceRanker:
|
| 839 |
-
def rank(self, tokens, sum_logprobs):
|
| 840 |
-
pass
|
| 841 |
-
|
| 842 |
-
class MaximumLikelihoodRanker(SequenceRanker):
|
| 843 |
-
def __init__(self, length_penalty):
|
| 844 |
-
self.length_penalty = length_penalty
|
| 845 |
-
|
| 846 |
-
def rank(self, tokens, sum_logprobs):
|
| 847 |
-
def scores(logprobs, lengths):
|
| 848 |
-
result = []
|
| 849 |
-
for logprob, length in zip(logprobs, lengths):
|
| 850 |
-
result.append(logprob / (length if self.length_penalty is None else ((5 + length) / 6) ** self.length_penalty))
|
| 851 |
-
return result
|
| 852 |
-
|
| 853 |
-
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, [[len(t) for t in s] for s in tokens])]
|
| 854 |
-
|
| 855 |
-
class TokenDecoder:
|
| 856 |
-
def reset(self):
|
| 857 |
-
pass
|
| 858 |
-
|
| 859 |
-
def update(self, tokens, logits, sum_logprobs):
|
| 860 |
-
pass
|
| 861 |
-
|
| 862 |
-
def finalize(self, tokens, sum_logprobs):
|
| 863 |
-
pass
|
| 864 |
-
|
| 865 |
-
|
| 866 |
-
class GreedyDecoder(TokenDecoder):
|
| 867 |
-
def __init__(self, temperature, eot):
|
| 868 |
-
self.temperature = temperature
|
| 869 |
-
self.eot = eot
|
| 870 |
-
|
| 871 |
-
def update(self, tokens, logits, sum_logprobs):
|
| 872 |
-
next_tokens = logits.argmax(dim=-1) if self.temperature == 0 else Categorical(logits=logits / self.temperature).sample()
|
| 873 |
-
|
| 874 |
-
logprobs = F.log_softmax(logits.float(), dim=-1)
|
| 875 |
-
sum_logprobs += logprobs[torch.arange(logprobs.shape[0]), next_tokens] * (tokens[:, -1] != self.eot)
|
| 876 |
-
|
| 877 |
-
next_tokens[tokens[:, -1] == self.eot] = self.eot
|
| 878 |
-
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
|
| 879 |
-
|
| 880 |
-
return tokens, (tokens[:, -1] == self.eot).all()
|
| 881 |
-
|
| 882 |
-
def finalize(self, tokens, sum_logprobs):
|
| 883 |
-
return F.pad(tokens, (0, 1), value=self.eot), sum_logprobs.tolist()
|
| 884 |
-
|
| 885 |
-
class BeamSearchDecoder(TokenDecoder):
|
| 886 |
-
def __init__(self, beam_size, eot, inference, patience = None):
|
| 887 |
-
self.beam_size = beam_size
|
| 888 |
-
self.eot = eot
|
| 889 |
-
self.inference = inference
|
| 890 |
-
self.patience = patience or 1.0
|
| 891 |
-
self.max_candidates = round(beam_size * self.patience)
|
| 892 |
-
self.finished_sequences = None
|
| 893 |
-
        assert (self.max_candidates > 0)

    def reset(self):
        self.finished_sequences = None

    def update(self, tokens, logits, sum_logprobs):
        if tokens.shape[0] % self.beam_size != 0: raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None: self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []

        for i in range(n_audio):
            scores, sources, finished = {}, {}, {}

            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = (sum_logprobs[idx] + logprob).item()
                    sources[sequence] = idx

            saved = 0

            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot: finished[sequence] = scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size: break

            finished_sequences.append(finished)

        self.inference.rearrange_kv_cache(source_indices)
        assert len(self.finished_sequences) == len(finished_sequences)

        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates: break
                previously_finished[seq] = newly_finished[seq]

        return torch.tensor(next_tokens, device=tokens.device), all(len(sequences) >= self.max_candidates for sequences in self.finished_sequences)

    def finalize(self, preceding_tokens, sum_logprobs):
        sum_logprobs = sum_logprobs.cpu()

        for i, sequences in enumerate(self.finished_sequences):
            if (len(sequences) < self.beam_size):
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size: break

        return [[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences], [list(sequences.values()) for sequences in self.finished_sequences]

class ApplyTimestampRules(LogitFilter):
    def __init__(self, tokenizer, sample_begin, max_initial_timestamp_index):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits, tokens):
        if self.tokenizer.no_timestamps is not None: logits[:, self.tokenizer.no_timestamps] = -np.inf

        for k in range(tokens.shape[0]):
            sampled_tokens = tokens[k, self.sample_begin :]
            seq = [t for t in sampled_tokens.tolist()]

            last_was_timestamp = (len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin)
            penultimate_was_timestamp = (len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin)

            if last_was_timestamp:
                if penultimate_was_timestamp: logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else: logits[k, : self.tokenizer.eot] = -np.inf

            timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]

            if timestamps.numel() > 0: logits[k, self.tokenizer.timestamp_begin : timestamps[-1] if last_was_timestamp and not penultimate_was_timestamp else (timestamps[-1] + 1)] = -np.inf

        if tokens.shape[1] == self.sample_begin:
            logits[:, : self.tokenizer.timestamp_begin] = -np.inf

            if self.max_initial_timestamp_index is not None:
                last_allowed = (self.tokenizer.timestamp_begin + self.max_initial_timestamp_index)
                logits[:, last_allowed + 1 :] = -np.inf

        logprobs = F.log_softmax(logits.float(), dim=-1)
        for k in range(tokens.shape[0]):
            if logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1) > logprobs[k, : self.tokenizer.timestamp_begin].max(): logits[k, : self.tokenizer.timestamp_begin] = -np.inf

class DecodingTask:
    def __init__(self, model, options):
        self.model = model

        language = options.language or "en"
        tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages, language=language, task=options.task)

        self.tokenizer = tokenizer
        self.options = self._verify_options(options)

        self.n_group = options.beam_size or options.best_of or 1
        self.n_ctx = model.dims.n_text_ctx
        self.sample_len = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence = tokenizer.sot_sequence
        if self.options.without_timestamps: self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens = self._get_initial_tokens()
        self.sample_begin = len(self.initial_tokens)
        self.sot_index = self.initial_tokens.index(tokenizer.sot)
        self.inference = PyTorchInference(model, len(self.initial_tokens))
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
        self.decoder = BeamSearchDecoder(options.beam_size, tokenizer.eot, self.inference, options.patience) if options.beam_size is not None else GreedyDecoder(options.temperature, tokenizer.eot)

        self.logit_filters = []

        if self.options.suppress_blank: self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens: self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))

        if not options.without_timestamps:
            max_initial_timestamp_index = None
            if options.max_initial_timestamp: max_initial_timestamp_index = round(self.options.max_initial_timestamp / (CHUNK_LENGTH / model.dims.n_audio_ctx))
            self.logit_filters.append(ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index))

    def _verify_options(self, options):
        if options.beam_size is not None and options.best_of is not None: raise ValueError
        if options.temperature == 0 and options.best_of is not None: raise ValueError
        if options.patience is not None and options.beam_size is None: raise ValueError
        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1): raise ValueError

        return options

    def _get_initial_tokens(self):
        tokens = list(self.sot_sequence)

        if prefix := self.options.prefix:
            prefix_tokens = (self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix)
            if self.sample_len is not None: prefix_tokens = prefix_tokens[-(self.n_ctx // 2 - self.sample_len):]
            tokens = tokens + prefix_tokens

        if prompt := self.options.prompt: tokens = ([self.tokenizer.sot_prev] + (self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt)[-(self.n_ctx // 2 - 1) :] + tokens)

        return tuple(tokens)

    def _get_suppress_tokens(self):
        suppress_tokens = self.options.suppress_tokens
        if isinstance(suppress_tokens, str): suppress_tokens = [int(t) for t in suppress_tokens.split(",")]

        if -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        elif suppress_tokens is None or len(suppress_tokens) == 0: suppress_tokens = []
        else: assert isinstance(suppress_tokens, list)

        suppress_tokens.extend([self.tokenizer.transcribe, self.tokenizer.translate, self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm])

        if self.tokenizer.no_speech is not None: suppress_tokens.append(self.tokenizer.no_speech)
        return tuple(sorted(set(suppress_tokens)))

    def _get_audio_features(self, mel):
        if self.options.fp16: mel = mel.half()

        audio_features = mel if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state) else self.model.encoder(mel)
        if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32): return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")

        return audio_features

    def _detect_language(self, audio_features, tokens):
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
            languages = [max(probs, key=probs.get) for probs in lang_probs]

            if self.options.language is None: tokens[:, self.sot_index + 1] = lang_tokens

        return languages, lang_probs

    def _main_loop(self, audio_features, tokens):
        n_batch = tokens.shape[0]
        sum_logprobs = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if (i == 0 and self.tokenizer.no_speech is not None):
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                logits = logits[:, -1]
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
                if completed or tokens.shape[-1] > self.n_ctx: break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs

    @torch.no_grad()
    def run(self, mel):
        self.decoder.reset()
        tokenizer = self.tokenizer
        n_audio = mel.shape[0]

        audio_features = self._get_audio_features(mel)
        tokens = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id": return [DecodingResult(audio_features=features, language=language, language_probs=probs) for features, language, probs in zip(audio_features, languages, language_probs)]

        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]

        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens = [[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens]

        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens = [t[i].tolist() for i, t in zip(selected, tokens)]

        fields = ([tokenizer.decode(t).strip() for t in tokens], languages, tokens, audio_features, [lp / (len(t) + 1) for t, lp in zip(tokens, [lp[i] for i, lp in zip(selected, sum_logprobs)])], no_speech_probs)
        if len(set(map(len, fields))) != 1: raise RuntimeError

        return [DecodingResult(audio_features=features, language=language, tokens=tokens, text=text, avg_logprob=avg_logprob, no_speech_prob=no_speech_prob, temperature=self.options.temperature, compression_ratio=compression_ratio(text)) for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)]

class DecodingResult:
    def __init__(self, audio_features, language, language_probs = None, tokens = None, text = "", avg_logprob = np.nan, no_speech_prob = np.nan, temperature = np.nan, compression_ratio = np.nan):
        self.audio_features = audio_features
        self.language = language
        self.language_probs = language_probs if language_probs is not None else {}
        self.tokens = tokens if tokens is not None else []
        self.text = text
        self.avg_logprob = avg_logprob
        self.no_speech_prob = no_speech_prob
        self.temperature = temperature
        self.compression_ratio = compression_ratio

class Tokenizer:
    def __init__(self, encoding_name, num_languages = 2, language = None, task = None, sot_sequence = ()):
        self.encoding = get_encoding(name=encoding_name, num_languages=num_languages)
        self.num_languages = num_languages
        self.language = language
        self.task = task
        self.sot_sequence = sot_sequence
        self.special_tokens = {}

        for special in self.encoding.special_tokens_set:
            special_token = self.encoding.encode_single_token(special)
            self.special_tokens[special] = special_token

        sot = self.special_tokens["<|startoftranscript|>"]
        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        sot_sequence = [sot]

        if self.language is not None: sot_sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None: sot_sequence.append(self.special_tokens["<|transcribe|>"] if self.task == "transcribe" else self.special_tokens["<|translate|>"])

        self.sot_sequence = tuple(sot_sequence)

    def encode(self, text, **kwargs):
        return self.encoding.encode(text, **kwargs)

    def decode(self, token_ids, **kwargs):
        return self.encoding.decode([t for t in token_ids if t < self.timestamp_begin], **kwargs)

    def decode_with_timestamps(self, token_ids, **kwargs):
        return self.encoding.decode(token_ids, **kwargs)

    @cached_property
    def eot(self):
        return self.encoding.eot_token

    @cached_property
    def transcribe(self):
        return self.special_tokens["<|transcribe|>"]

    @cached_property
    def translate(self):
        return self.special_tokens["<|translate|>"]

    @cached_property
    def sot(self):
        return self.special_tokens["<|startoftranscript|>"]

    @cached_property
    def sot_lm(self):
        return self.special_tokens["<|startoflm|>"]

    @cached_property
    def sot_prev(self):
        return self.special_tokens["<|startofprev|>"]

    @cached_property
    def no_speech(self):
        return self.special_tokens["<|nospeech|>"]

    @cached_property
    def no_timestamps(self):
        return self.special_tokens["<|notimestamps|>"]

    @cached_property
    def timestamp_begin(self):
        return self.special_tokens["<|0.00|>"]

    @cached_property
    def language_token(self):
        if self.language is None: raise ValueError
        return self.to_language_token(self.language)

    def to_language_token(self, language):
        if token := self.special_tokens.get(f"<|{language}|>", None): return token
        raise KeyError

    @cached_property
    def all_language_tokens(self):
        result = []
        for token, token_id in self.special_tokens.items():
            if token.strip("<|>") in LANGUAGES: result.append(token_id)

        return tuple(result)[: self.num_languages]

    @cached_property
    def all_language_codes(self):
        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)

    @cached_property
    def sot_sequence_including_notimestamps(self):
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @cached_property
    def non_speech_tokens(self):
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += ("<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split())

        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [self.encoding.encode(symbol), self.encoding.encode(" " + symbol)]:
                if len(tokens) == 1 or symbol in miscellaneous: result.add(tokens[0])

        return tuple(sorted(result))

    def split_to_word_tokens(self, tokens):
        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: return self.split_tokens_on_unicode(tokens)
        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(self, tokens):
        replacement_char = "\ufffd"

        words, word_tokens, current_tokens = [], [], []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            if (replacement_char not in decoded or self.decode_with_timestamps(tokens)[unicode_offset + decoded.index(replacement_char)] == replacement_char):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens

    def split_tokens_on_spaces(self, tokens):
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words, word_tokens = [], []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            if (subword_tokens[0] >= self.eot) or (subword.startswith(" ")) or (subword.strip() in string.punctuation) or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens
main/library/utils.py
DELETED
@@ -1,240 +0,0 @@
import os
import re
import sys
import codecs
import librosa
import logging

import numpy as np
import soundfile as sf

from pydub import AudioSegment

sys.path.append(os.getcwd())

from main.tools import huggingface
from main.configs.config import Config

for l in ["httpx", "httpcore"]:
    logging.getLogger(l).setLevel(logging.ERROR)

translations = Config().translations


def check_predictors(method, f0_onnx=False):
    if f0_onnx and method not in ["harvest", "dio"]: method += "-onnx"

    def download(predictors):
        if not os.path.exists(os.path.join("assets", "models", "predictors", predictors)): huggingface.HF_download_file(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/cerqvpgbef/", "rot13") + predictors, os.path.join("assets", "models", "predictors", predictors))

    model_dict = {**dict.fromkeys(["rmvpe", "rmvpe-legacy"], "rmvpe.pt"), **dict.fromkeys(["rmvpe-onnx", "rmvpe-legacy-onnx"], "rmvpe.onnx"), **dict.fromkeys(["fcpe"], "fcpe.pt"), **dict.fromkeys(["fcpe-legacy"], "fcpe_legacy.pt"), **dict.fromkeys(["fcpe-onnx"], "fcpe.onnx"), **dict.fromkeys(["fcpe-legacy-onnx"], "fcpe_legacy.onnx"), **dict.fromkeys(["crepe-full", "mangio-crepe-full"], "crepe_full.pth"), **dict.fromkeys(["crepe-full-onnx", "mangio-crepe-full-onnx"], "crepe_full.onnx"), **dict.fromkeys(["crepe-large", "mangio-crepe-large"], "crepe_large.pth"), **dict.fromkeys(["crepe-large-onnx", "mangio-crepe-large-onnx"], "crepe_large.onnx"), **dict.fromkeys(["crepe-medium", "mangio-crepe-medium"], "crepe_medium.pth"), **dict.fromkeys(["crepe-medium-onnx", "mangio-crepe-medium-onnx"], "crepe_medium.onnx"), **dict.fromkeys(["crepe-small", "mangio-crepe-small"], "crepe_small.pth"), **dict.fromkeys(["crepe-small-onnx", "mangio-crepe-small-onnx"], "crepe_small.onnx"), **dict.fromkeys(["crepe-tiny", "mangio-crepe-tiny"], "crepe_tiny.pth"), **dict.fromkeys(["crepe-tiny-onnx", "mangio-crepe-tiny-onnx"], "crepe_tiny.onnx"), **dict.fromkeys(["harvest", "dio"], "world.pth")}

    if "hybrid" in method:
        methods_str = re.search("hybrid\[(.+)\]", method)
        if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]

        for method in methods:
            if method in model_dict: download(model_dict[method])
    elif method in model_dict: download(model_dict[method])

def check_embedders(hubert, embedders_mode="fairseq"):
    huggingface_url = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/rzorqqref/", "rot13")

    if hubert in ["contentvec_base", "hubert_base", "japanese_hubert_base", "korean_hubert_base", "chinese_hubert_base", "portuguese_hubert_base"]:
        if embedders_mode == "fairseq": hubert += ".pt"
        elif embedders_mode == "onnx": hubert += ".onnx"

    model_path = os.path.join("assets", "models", "embedders", hubert)

    if embedders_mode == "fairseq":
        if not os.path.exists(model_path): huggingface.HF_download_file("".join([huggingface_url, "fairseq/", hubert]), model_path)
    elif embedders_mode == "onnx":
        if not os.path.exists(model_path): huggingface.HF_download_file("".join([huggingface_url, "onnx/", hubert]), model_path)
    elif embedders_mode == "transformers":
        bin_file = os.path.join(model_path, "model.safetensors")
        config_file = os.path.join(model_path, "config.json")

        os.makedirs(model_path, exist_ok=True)

        if not os.path.exists(bin_file): huggingface.HF_download_file("".join([huggingface_url, "transformers/", hubert, "/model.safetensors"]), bin_file)
        if not os.path.exists(config_file): huggingface.HF_download_file("".join([huggingface_url, "transformers/", hubert, "/config.json"]), config_file)
    else: raise ValueError(translations["option_not_valid"])

def check_spk_diarization(model_size):
    whisper_model = os.path.join("assets", "models", "speaker_diarization", "models", f"{model_size}.pt")
    if not os.path.exists(whisper_model): huggingface.HF_download_file("".join([codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/fcrnxre_qvnevmngvba/", "rot13"), model_size, ".pt"]), whisper_model)

    speechbrain_path = os.path.join("assets", "models", "speaker_diarization", "models", "speechbrain")
    if not os.path.exists(speechbrain_path): os.makedirs(speechbrain_path, exist_ok=True)

    for f in ["classifier.ckpt", "config.json", "embedding_model.ckpt", "hyperparams.yaml", "mean_var_norm_emb.ckpt"]:
        speechbrain_model = os.path.join(speechbrain_path, f)
        if not os.path.exists(speechbrain_model): huggingface.HF_download_file(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/fcrnxre_qvnevmngvba/fcrrpuoenva/", "rot13") + f, speechbrain_model)

def check_audioldm2(model):
    for f in ["feature_extractor", "language_model", "projection_model", "scheduler", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "unet", "vae", "vocoder"]:
        folder_path = os.path.join("assets", "models", "audioldm2", model, f)
        if not os.path.exists(folder_path): os.makedirs(folder_path, exist_ok=True)

    for f in ["feature_extractor/preprocessor_config.json","language_model/config.json","language_model/model.safetensors","model_index.json","projection_model/config.json","projection_model/diffusion_pytorch_model.safetensors","scheduler/scheduler_config.json","text_encoder/config.json","text_encoder/model.safetensors","text_encoder_2/config.json","text_encoder_2/model.safetensors","tokenizer/merges.txt","tokenizer/special_tokens_map.json","tokenizer/tokenizer.json","tokenizer/tokenizer_config.json","tokenizer/vocab.json","tokenizer_2/special_tokens_map.json","tokenizer_2/spiece.model","tokenizer_2/tokenizer.json","tokenizer_2/tokenizer_config.json","unet/config.json","unet/diffusion_pytorch_model.safetensors","vae/config.json","vae/diffusion_pytorch_model.safetensors","vocoder/config.json","vocoder/model.safetensors"]:
        model_path = os.path.join("assets", "models", "audioldm2", model, f)
        if not os.path.exists(model_path): huggingface.HF_download_file("".join([codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/nhqvbyqz/", "rot13"), model, "/", f]), model_path)

def load_audio(logger, file, sample_rate=16000, formant_shifting=False, formant_qfrency=0.8, formant_timbre=0.8):
    try:
        file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        if not os.path.isfile(file): raise FileNotFoundError(translations["not_found"].format(name=file))

        try:
            logger.debug(translations['read_sf'])
            audio, sr = sf.read(file, dtype=np.float32)
        except:
            logger.debug(translations['read_librosa'])
            audio, sr = librosa.load(file, sr=None)

        if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
        if sr != sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate, res_type="soxr_vhq")

        if formant_shifting:
            from main.library.algorithm.stftpitchshift import StftPitchShift

            pitchshifter = StftPitchShift(1024, 32, sample_rate)
            audio = pitchshifter.shiftpitch(audio, factors=1, quefrency=formant_qfrency * 1e-3, distortion=formant_timbre)
    except Exception as e:
        raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")

    return audio.flatten()

def pydub_convert(audio):
    samples = np.frombuffer(audio.raw_data, dtype=np.int16)
    if samples.dtype != np.int16: samples = (samples * 32767).astype(np.int16)
    return AudioSegment(samples.tobytes(), frame_rate=audio.frame_rate, sample_width=samples.dtype.itemsize, channels=audio.channels)

def pydub_load(input_path):
    try:
        if input_path.endswith(".wav"): audio = AudioSegment.from_wav(input_path)
        elif input_path.endswith(".mp3"): audio = AudioSegment.from_mp3(input_path)
        elif input_path.endswith(".ogg"): audio = AudioSegment.from_ogg(input_path)
        else: audio = AudioSegment.from_file(input_path)
    except:
        audio = AudioSegment.from_file(input_path)

    return audio

def load_embedders_model(embedder_model, embedders_mode="fairseq", providers=None):
    if embedders_mode == "fairseq": embedder_model += ".pt"
    elif embedders_mode == "onnx": embedder_model += ".onnx"

    embedder_model_path = os.path.join("assets", "models", "embedders", embedder_model)
    if not os.path.exists(embedder_model_path): raise FileNotFoundError(f"{translations['not_found'].format(name=translations['model'])}: {embedder_model}")

    try:
        if embedders_mode == "fairseq":
            from main.library.architectures import fairseq

            models, saved_cfg, _ = fairseq.load_model(embedder_model_path)
            embed_suffix = ".pt"
            hubert_model = models[0]
        elif embedders_mode == "onnx":
            import onnxruntime

            sess_options = onnxruntime.SessionOptions()
            sess_options.log_severity_level = 3
            embed_suffix, saved_cfg = ".onnx", None
            hubert_model = onnxruntime.InferenceSession(embedder_model_path, sess_options=sess_options, providers=providers)
        elif embedders_mode == "transformers":
            from torch import nn
            from transformers import HubertModel

            class HubertModelWithFinalProj(HubertModel):
                def __init__(self, config):
                    super().__init__(config)
                    self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)

            embed_suffix, saved_cfg = ".safetensors", None
            hubert_model = HubertModelWithFinalProj.from_pretrained(embedder_model_path)
        else: raise ValueError(translations["option_not_valid"])
    except Exception as e:
        raise RuntimeError(translations["read_model_error"].format(e=e))

    return hubert_model, saved_cfg, embed_suffix

def cut(audio, sr, db_thresh=-60, min_interval=250):
    from main.inference.preprocess import Slicer, get_rms

    class Slicer2(Slicer):
        def slice2(self, waveform):
            samples = waveform.mean(axis=0) if len(waveform.shape) > 1 else waveform

            if samples.shape[0] <= self.min_length: return [(waveform, 0, samples.shape[0])]
            rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)

            sil_tags = []
            silence_start, clip_start = None, 0

            for i, rms in enumerate(rms_list):
                if rms < self.threshold:
                    if silence_start is None: silence_start = i
                    continue

                if silence_start is None: continue

                is_leading_silence = silence_start == 0 and i > self.max_sil_kept
                need_slice_middle = (i - silence_start >= self.min_interval and i - clip_start >= self.min_length)

                if not is_leading_silence and not need_slice_middle:
                    silence_start = None
                    continue

                if i - silence_start <= self.max_sil_kept:
                    pos = rms_list[silence_start : i + 1].argmin() + silence_start
                    sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
                    clip_start = pos
                elif i - silence_start <= self.max_sil_kept * 2:
                    pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
                    pos += i - self.max_sil_kept

                    pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)

                    if silence_start == 0:
                        sil_tags.append((0, pos_r))
                        clip_start = pos_r
                    else:
                        sil_tags.append((min((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos), max(pos_r, pos)))
                        clip_start = max(pos_r, pos)
                else:
                    pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
                    sil_tags.append((0, pos_r) if silence_start == 0 else ((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos_r))
                    clip_start = pos_r

                silence_start = None

            total_frames = rms_list.shape[0]
            if (silence_start is not None and total_frames - silence_start >= self.min_interval): sil_tags.append((rms_list[silence_start : min(total_frames, silence_start + self.max_sil_kept) + 1].argmin() + silence_start, total_frames + 1))

            if not sil_tags: return [(waveform, 0, samples.shape[-1])]
            else:
                chunks = []
                if sil_tags[0][0] > 0: chunks.append((self._apply_slice(waveform, 0, sil_tags[0][0]), 0, sil_tags[0][0] * self.hop_size))

                for i in range(len(sil_tags) - 1):
                    chunks.append((self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]), sil_tags[i][1] * self.hop_size, sil_tags[i + 1][0] * self.hop_size))

                if sil_tags[-1][1] < total_frames: chunks.append((self._apply_slice(waveform, sil_tags[-1][1], total_frames), sil_tags[-1][1] * self.hop_size, samples.shape[-1]))
                return chunks

    slicer = Slicer2(sr=sr, threshold=db_thresh, min_interval=min_interval)
    return slicer.slice2(audio)

def restore(segments, total_len, dtype=np.float32):
    out = []
    last_end = 0

    for start, end, processed_seg in segments:
        if start > last_end: out.append(np.zeros(start - last_end, dtype=dtype))

        out.append(processed_seg)
        last_end = end

    if last_end < total_len: out.append(np.zeros(total_len - last_end, dtype=dtype))
    return np.concatenate(out, axis=-1)