kevinwang676 committed on
Commit b8292ba · verified · 1 Parent(s): c2d2c2b

Delete GPT-SoVITS-models

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. GPT-SoVITS-models/.gitattributes +0 -44
  2. GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/del-checkpoint.sh +0 -12
  3. GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/webui-checkpoint.py +0 -719
  4. GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/启动webui-checkpoint.sh +0 -2
  5. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/.ipynb_checkpoints/inference_webui-checkpoint.py +0 -270
  6. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/__init__.py +0 -0
  7. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/__init__.py +0 -0
  8. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/bucket_sampler.py +0 -157
  9. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/data_module.py +0 -66
  10. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/dataset.py +0 -302
  11. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/__init__.py +0 -0
  12. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/BEATs.py +0 -179
  13. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/README.md +0 -127
  14. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/Tokenizers.py +0 -172
  15. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/__init__.py +0 -2
  16. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/backbone.py +0 -791
  17. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/config.py +0 -19
  18. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/modules.py +0 -220
  19. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/ontology.json +0 -0
  20. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/quantizer.py +0 -235
  21. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_beats_librilight.py +0 -321
  22. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_phones.py +0 -232
  23. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_phones_librilight.py +0 -198
  24. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_txt_librilight.py +0 -255
  25. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/split_train_val.py +0 -35
  26. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/t2s.py +0 -197
  27. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/test.py +0 -139
  28. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/text.txt +0 -10
  29. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/train.py +0 -103
  30. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/train_librilight_6k.py +0 -170
  31. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/__init__.py +0 -0
  32. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/t2s_lightning_module.py +0 -128
  33. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/t2s_model.py +0 -298
  34. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/utils.py +0 -164
  35. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/__init__.py +0 -0
  36. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/activation.py +0 -397
  37. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/embedding.py +0 -78
  38. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/lr_schedulers.py +0 -85
  39. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/optim.py +0 -622
  40. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/patched_mha_with_cache.py +0 -388
  41. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/scaling.py +0 -319
  42. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/transformer.py +0 -347
  43. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/text_processing/__init__.py +0 -0
  44. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/text_processing/phonemizer.py +0 -80
  45. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/text_processing/symbols.py +0 -9
  46. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/utils/__init__.py +0 -37
  47. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/utils/initialize.py +0 -38
  48. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/utils/io.py +0 -32
  49. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/configs/s1.yaml +0 -31
  50. GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/configs/s1big.yaml +0 -31
GPT-SoVITS-models/.gitattributes DELETED
@@ -1,44 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
- GPT-SoVITS/TEMP/gradio/2bbf387613664982acc3847e4b4970fc6bf09120/audio.wav filter=lfs diff=lfs merge=lfs -text
- GPT-SoVITS/TEMP/gradio/4e25df8e5470697bd435cc94559e1c34f09bab16/audio.wav filter=lfs diff=lfs merge=lfs -text
- GPT-SoVITS/TEMP/gradio/6b579cccde8715941d9b9b06a1b9787ce0fdb4db/audio.wav filter=lfs diff=lfs merge=lfs -text
- GPT-SoVITS/TEMP/gradio/873c1f03462a87c00222fd2422a8b328244f45da/audio.wav filter=lfs diff=lfs merge=lfs -text
- GPT-SoVITS/TEMP/gradio/d2c38e2d7f131cfc51fe07c541177b0f5a061cc3/audio.wav filter=lfs diff=lfs merge=lfs -text
- GPT-SoVITS/TEMP/gradio/e6f05e0d768171ac3b7355d968cb1badf9d84864/wyxy_101-0-100.wav filter=lfs diff=lfs merge=lfs -text
- GPT-SoVITS/TEMP/gradio/e6f05e0d768171ac3b7355d968cb1badf9d84864/wyxy_101.wav filter=lfs diff=lfs merge=lfs -text
- GPT-SoVITS/TEMP/jieba.cache filter=lfs diff=lfs merge=lfs -text
- GPT-SoVITS/tools/damo_asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav filter=lfs diff=lfs merge=lfs -text
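Note: the entries deleted above are standard Git LFS rules; each pattern routes matching files through the lfs filter instead of plain Git. As a rough illustration only (not part of this repository, and fnmatch only approximates real .gitattributes matching), one could check which paths such rules would send to LFS:

from fnmatch import fnmatch

# A hypothetical subset of the deleted patterns (basename-style globs only).
LFS_PATTERNS = ["*.ckpt", "*.pth", "*.bin", "*.safetensors", "*.zip", "*tfevents*"]

def routed_to_lfs(path: str) -> bool:
    # Rough check: does any LFS pattern match the file's basename?
    name = path.rsplit("/", 1)[-1]
    return any(fnmatch(name, pat) for pat in LFS_PATTERNS)

for p in ["GPT_SoVITS/pretrained_models/s2G488k.pth", "webui.py"]:
    print(p, "->", "LFS" if routed_to_lfs(p) else "plain git")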
 
GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/del-checkpoint.sh DELETED
@@ -1,12 +0,0 @@
- #!/bin/bash
- cd /root/autodl-tmp/workdir/GPT-SoVITS
- rm -rf GPT_weights/*
- rm -rf SoVITS_weights/*
-
- rm -rf input/*
- rm -rf output/asr_opt/*
- rm -rf output/slicer_opt/*
- rm -rf output/uvr5_opt/*
- rm -rf logs/*
-
- echo 初始化完成
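Note: del.sh above is a hard reset that empties the weight, input, output, and log directories under the work directory. A hedged Python sketch of the same cleanup (paths copied from the script; not part of the repository) that skips missing directories instead of relying on rm -rf:

import shutil
from pathlib import Path

# Directories emptied by the deleted del.sh (taken from the script above).
BASE = Path("/root/autodl-tmp/workdir/GPT-SoVITS")
TARGETS = ["GPT_weights", "SoVITS_weights", "input",
           "output/asr_opt", "output/slicer_opt", "output/uvr5_opt", "logs"]

def reset_workdir(base: Path = BASE) -> None:
    # Delete the contents of each target directory but keep the directory itself.
    for rel in TARGETS:
        d = base / rel
        if not d.is_dir():
            continue
        for child in d.iterdir():
            if child.is_dir():
                shutil.rmtree(child)
            else:
                child.unlink()
    print("初始化完成 (reset complete)")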
 
GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/webui-checkpoint.py DELETED
@@ -1,719 +0,0 @@
1
- import json,yaml,warnings,torch
2
- warnings.filterwarnings("ignore")
3
- torch.manual_seed(233333)
4
- import os,pdb,sys
5
- now_dir = os.getcwd()
6
- tmp = os.path.join(now_dir, "TEMP")
7
- os.makedirs(tmp, exist_ok=True)
8
- os.environ["TEMP"] = tmp
9
- import site
10
- site_packages_root="%s/root/miniconda3/lib/python3.10/site-packages"%now_dir
11
- for path in site.getsitepackages():
12
- if("site-packages"in path):site_packages_root=path
13
- os.environ["OPENBLAS_NUM_THREADS"] = "4"
14
- os.environ["no_proxy"] = "localhost, 127.0.0.1, ::1"
15
- with open("%s/users.pth"%(site_packages_root),"w")as f:
16
- f.write("%s\n%s/tools\n%s/tools/damo_asr\n%s/GPT_SoVITS\n%s/tools/uvr5"%(now_dir,now_dir,now_dir,now_dir,now_dir))
17
- import traceback
18
- sys.path.append(now_dir)
19
- import shutil
20
- import pdb
21
- import gradio as gr
22
- from subprocess import Popen
23
- import signal
24
- from config import python_exec,infer_device,is_half,exp_root
25
- from i18n.i18n import I18nAuto
26
- i18n = I18nAuto()
27
- from scipy.io import wavfile
28
- from tools.my_utils import load_audio
29
- from multiprocessing import cpu_count
30
- n_cpu=cpu_count()
31
-
32
- # 判断是否有能用来训练和加速推理的N卡
33
- ngpu = torch.cuda.device_count()
34
- gpu_infos = []
35
- mem = []
36
- if_gpu_ok = False
37
-
38
- if torch.cuda.is_available() or ngpu != 0:
39
- for i in range(ngpu):
40
- gpu_name = torch.cuda.get_device_name(i)
41
- if any(value in gpu_name.upper()for value in ["10","16","20","30","40","A2","A3","A4","P4","A50","500","A60","70","80","90","M4","T4","TITAN","L"]):
42
- # A10#A100#V100#A40#P40#M40#K80#A4500
43
- if_gpu_ok = True # 至少有一张能用的N卡
44
- gpu_infos.append("%s\t%s" % (i, gpu_name))
45
- mem.append(int(torch.cuda.get_device_properties(i).total_memory/ 1024/ 1024/ 1024+ 0.4))
46
- if if_gpu_ok and len(gpu_infos) > 0:
47
- gpu_info = "\n".join(gpu_infos)
48
- default_batch_size = min(mem) // 2
49
- else:
50
- gpu_info = i18n("很遗憾您这没有能用的显卡来支持您训练")
51
- default_batch_size = 1
52
- gpus = "-".join([i[0] for i in gpu_infos])
53
-
54
- pretrained_sovits_name="GPT_SoVITS/pretrained_models/s2G488k.pth"
55
- pretrained_gpt_name="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
56
- def get_weights_names():
57
- SoVITS_names = [pretrained_sovits_name]
58
- for name in os.listdir(SoVITS_weight_root):
59
- if name.endswith(".pth"):SoVITS_names.append(name)
60
- GPT_names = [pretrained_gpt_name]
61
- for name in os.listdir(GPT_weight_root):
62
- if name.endswith(".ckpt"): GPT_names.append(name)
63
- return SoVITS_names,GPT_names
64
- SoVITS_weight_root="SoVITS_weights"
65
- GPT_weight_root="GPT_weights"
66
- SoVITS_names,GPT_names = get_weights_names()
67
-
68
- def change_choices():
69
- SoVITS_names, GPT_names = get_weights_names()
70
- return {"choices": sorted(SoVITS_names), "__type__": "update"}, {"choices": sorted(GPT_names), "__type__": "update"}
71
-
72
- p_label=None
73
- p_uvr5=None
74
- p_asr=None
75
- p_tts_inference=None
76
-
77
- def kill_process(pid):
78
- os.system("taskkill /t /f /pid %s" % pid) # todo:识别linux用kill -9
79
- # os.kill(p_label.pid,19)#主进程#控制台进程#python子进程###不好使,连主进程的webui一起关了,辣鸡
80
-
81
- def change_label(if_label,path_list):
82
- global p_label
83
- if(if_label==True and p_label==None):
84
- cmd = '"%s" tools/subfix_webui.py --load_list "%s"'%(python_exec,path_list)
85
- yield "打标工具WebUI已开启"
86
- print(cmd)
87
- p_label = Popen(cmd, shell=True)
88
- elif(if_label==False and p_label!=None):
89
- kill_process(p_label.pid)
90
- p_label=None
91
- yield "打标工具WebUI已关闭"
92
-
93
- def change_uvr5(if_uvr5):
94
- global p_uvr5
95
- if(if_uvr5==True and p_uvr5==None):
96
- cmd = '"%s" tools/uvr5/webui.py "%s" %s'%(python_exec,infer_device,is_half)
97
- yield "UVR5已开启"
98
- print(cmd)
99
- p_uvr5 = Popen(cmd, shell=True)
100
- elif(if_uvr5==False and p_uvr5!=None):
101
- kill_process(p_uvr5.pid)
102
- p_uvr5=None
103
- yield "UVR5已关闭"
104
-
105
- def change_tts_inference(if_tts,bert_path,cnhubert_base_path,gpu_number,gpt_path,sovits_path):
106
- global p_tts_inference
107
- if(if_tts==True and p_tts_inference==None):
108
- os.environ["gpt_path"]=gpt_path if "/" in gpt_path else "%s/%s"%(GPT_weight_root,gpt_path)
109
- os.environ["sovits_path"]=sovits_path if "/"in sovits_path else "%s/%s"%(SoVITS_weight_root,sovits_path)
110
- os.environ["cnhubert_base_path"]=cnhubert_base_path
111
- os.environ["bert_path"]=bert_path
112
- os.environ["_CUDA_VISIBLE_DEVICES"]=gpu_number
113
- os.environ["is_half"]=str(is_half)
114
- cmd = '"%s" GPT_SoVITS/inference_webui.py'%(python_exec)
115
- yield "TTS推理进程已开启"
116
- print(cmd)
117
- p_tts_inference = Popen(cmd, shell=True)
118
- elif(if_tts==False and p_tts_inference!=None):
119
- kill_process(p_tts_inference.pid)
120
- p_tts_inference=None
121
- yield "TTS推理进程已关闭"
122
-
123
-
124
- def open_asr(asr_inp_dir):
125
- global p_asr
126
- if(p_asr==None):
127
- cmd = '"%s" tools/damo_asr/cmd-asr.py "%s"'%(python_exec,asr_inp_dir)
128
- yield "ASR任务开启:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
129
- print(cmd)
130
- p_asr = Popen(cmd, shell=True)
131
- p_asr.wait()
132
- p_asr=None
133
- yield "ASR任务完成",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
134
- else:
135
- yield "已有正在进行的ASR任务,需先终止才能开启下一次任务",{"__type__":"update","visible":False},{"__type__":"update","visible":True}
136
-
137
- def close_asr():
138
- global p_asr
139
- if(p_asr!=None):
140
- kill_process(p_asr.pid)
141
- p_asr=None
142
- return "已终止ASR进程",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
143
-
144
- '''
145
- button1Ba_open.click(open1Ba, [batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D], [info1Bb,button1Ba_open,button1Ba_close])
146
- button1Ba_close.click(close1Ba, [], [info1Bb,button1Ba_open,button1Ba_close])
147
- '''
148
- p_train_SoVITS=None
149
- def open1Ba(batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D):
150
- global p_train_SoVITS
151
- if(p_train_SoVITS==None):
152
- with open("GPT_SoVITS/configs/s2.json")as f:
153
- data=f.read()
154
- data=json.loads(data)
155
- s2_dir="%s/%s"%(exp_root,exp_name)
156
- os.makedirs("%s/logs_s2"%(s2_dir),exist_ok=True)
157
- data["train"]["batch_size"]=batch_size
158
- data["train"]["epochs"]=total_epoch
159
- data["train"]["text_low_lr_rate"]=text_low_lr_rate
160
- data["train"]["pretrained_s2G"]=pretrained_s2G
161
- data["train"]["pretrained_s2D"]=pretrained_s2D
162
- data["train"]["if_save_latest"]=if_save_latest
163
- data["train"]["if_save_every_weights"]=if_save_every_weights
164
- data["train"]["save_every_epoch"]=save_every_epoch
165
- data["train"]["gpu_numbers"]=gpu_numbers1Ba
166
- data["data"]["exp_dir"]=data["s2_ckpt_dir"]=s2_dir
167
- data["save_weight_dir"]=SoVITS_weight_root
168
- data["name"]=exp_name
169
- tmp_config_path="TEMP/tmp_s2.json"
170
- with open(tmp_config_path,"w")as f:f.write(json.dumps(data))
171
-
172
- cmd = '"%s" GPT_SoVITS/s2_train.py --config "%s"'%(python_exec,tmp_config_path)
173
- yield "SoVITS训练开始:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
174
- print(cmd)
175
- p_train_SoVITS = Popen(cmd, shell=True)
176
- p_train_SoVITS.wait()
177
- p_train_SoVITS=None
178
- yield "SoVITS训练完成",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
179
- else:
180
- yield "已有正在进行的SoVITS训练任务,需先终止才能开启下一次任务",{"__type__":"update","visible":False},{"__type__":"update","visible":True}
181
-
182
- def close1Ba():
183
- global p_train_SoVITS
184
- if(p_train_SoVITS!=None):
185
- kill_process(p_train_SoVITS.pid)
186
- p_train_SoVITS=None
187
- return "已终止SoVITS训练",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
188
-
189
- p_train_GPT=None
190
- def open1Bb(batch_size,total_epoch,exp_name,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers,pretrained_s1):
191
- global p_train_GPT
192
- if(p_train_GPT==None):
193
- with open("GPT_SoVITS/configs/s1longer.yaml")as f:
194
- data=f.read()
195
- data=yaml.load(data, Loader=yaml.FullLoader)
196
- s1_dir="%s/%s"%(exp_root,exp_name)
197
- os.makedirs("%s/logs_s1"%(s1_dir),exist_ok=True)
198
- data["train"]["batch_size"]=batch_size
199
- data["train"]["epochs"]=total_epoch
200
- data["pretrained_s1"]=pretrained_s1
201
- data["train"]["save_every_n_epoch"]=save_every_epoch
202
- data["train"]["if_save_every_weights"]=if_save_every_weights
203
- data["train"]["if_save_latest"]=if_save_latest
204
- data["train"]["half_weights_save_dir"]=GPT_weight_root
205
- data["train"]["exp_name"]=exp_name
206
- data["train_semantic_path"]="%s/6-name2semantic.tsv"%s1_dir
207
- data["train_phoneme_path"]="%s/2-name2text.txt"%s1_dir
208
- data["output_dir"]="%s/logs_s1"%s1_dir
209
-
210
- os.environ["_CUDA_VISIBLE_DEVICES"]=gpu_numbers.replace("-",",")
211
- os.environ["hz"]="25hz"
212
- tmp_config_path="TEMP/tmp_s1.yaml"
213
- with open(tmp_config_path, "w") as f:f.write(yaml.dump(data, default_flow_style=False))
214
- # cmd = '"%s" GPT_SoVITS/s1_train.py --config_file "%s" --train_semantic_path "%s/6-name2semantic.tsv" --train_phoneme_path "%s/2-name2text.txt" --output_dir "%s/logs_s1"'%(python_exec,tmp_config_path,s1_dir,s1_dir,s1_dir)
215
- cmd = '"%s" GPT_SoVITS/s1_train.py --config_file "%s" '%(python_exec,tmp_config_path)
216
- yield "GPT训练开始:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
217
- print(cmd)
218
- p_train_GPT = Popen(cmd, shell=True)
219
- p_train_GPT.wait()
220
- p_train_GPT=None
221
- yield "GPT训练完成",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
222
- else:
223
- yield "已有正在进行的GPT训练任务,需先终止才能开启下一次任务",{"__type__":"update","visible":False},{"__type__":"update","visible":True}
224
-
225
- def close1Bb():
226
- global p_train_GPT
227
- if(p_train_GPT!=None):
228
- kill_process(p_train_GPT.pid)
229
- p_train_GPT=None
230
- return "已终止GPT训练",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
231
-
232
- ps_slice=[]
233
- def open_slice(inp,opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_max,alpha,n_parts):
234
- global ps_slice
235
- if(os.path.exists(inp)==False):
236
- yield "输入路径不存在",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
237
- return
238
- if os.path.isfile(inp):n_parts=1
239
- elif os.path.isdir(inp):pass
240
- else:
241
- yield "输入路径存在但既不是文件也不是文件夹",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
242
- return
243
- if (ps_slice == []):
244
- for i_part in range(n_parts):
245
- cmd = '"%s" tools/slice_audio.py "%s" "%s" %s %s %s %s %s %s %s %s %s''' % (python_exec,inp, opt_root, threshold, min_length, min_interval, hop_size, max_sil_kept, _max, alpha, i_part, n_parts)
246
- print(cmd)
247
- p = Popen(cmd, shell=True)
248
- ps_slice.append(p)
249
- yield "切割执行中", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
250
- for p in ps_slice:
251
- p.wait()
252
- ps_slice=[]
253
- yield "切割结束",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
254
- else:
255
- yield "已有正在进行的切割任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
256
-
257
- def close_slice():
258
- global ps_slice
259
- if (ps_slice != []):
260
- for p_slice in ps_slice:
261
- try:
262
- kill_process(p_slice.pid)
263
- except:
264
- traceback.print_exc()
265
- ps_slice=[]
266
- return "已终止所有切割进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
267
-
268
- '''
269
- inp_text= os.environ.get("inp_text")
270
- inp_wav_dir= os.environ.get("inp_wav_dir")
271
- exp_name= os.environ.get("exp_name")
272
- i_part= os.environ.get("i_part")
273
- all_parts= os.environ.get("all_parts")
274
- os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
275
- opt_dir= os.environ.get("opt_dir")#"/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
276
- bert_pretrained_dir= os.environ.get("bert_pretrained_dir")#"/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large"
277
- '''
278
- ps1a=[]
279
- def open1a(inp_text,inp_wav_dir,exp_name,gpu_numbers,bert_pretrained_dir):
280
- global ps1a
281
- if (ps1a == []):
282
- config={
283
- "inp_text":inp_text,
284
- "inp_wav_dir":inp_wav_dir,
285
- "exp_name":exp_name,
286
- "opt_dir":"%s/%s"%(exp_root,exp_name),
287
- "bert_pretrained_dir":bert_pretrained_dir,
288
- }
289
- gpu_names=gpu_numbers.split("-")
290
- all_parts=len(gpu_names)
291
- for i_part in range(all_parts):
292
- config.update(
293
- {
294
- "i_part": str(i_part),
295
- "all_parts": str(all_parts),
296
- "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
297
- "is_half": str(is_half)
298
- }
299
- )
300
- os.environ.update(config)
301
- cmd = '"%s" GPT_SoVITS/prepare_datasets/1-get-text.py'%python_exec
302
- print(cmd)
303
- p = Popen(cmd, shell=True)
304
- ps1a.append(p)
305
- yield "文本进程执行中", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
306
- for p in ps1a:
307
- p.wait()
308
- ps1a=[]
309
- yield "文本进程结束",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
310
- else:
311
- yield "已有正在进行的文本任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
312
-
313
- def close1a():
314
- global ps1a
315
- if (ps1a != []):
316
- for p1a in ps1a:
317
- try:
318
- kill_process(p1a.pid)
319
- except:
320
- traceback.print_exc()
321
- ps1a=[]
322
- return "已终止所有1a进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
323
- '''
324
- inp_text= os.environ.get("inp_text")
325
- inp_wav_dir= os.environ.get("inp_wav_dir")
326
- exp_name= os.environ.get("exp_name")
327
- i_part= os.environ.get("i_part")
328
- all_parts= os.environ.get("all_parts")
329
- os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
330
- opt_dir= os.environ.get("opt_dir")
331
- cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
332
- '''
333
- ps1b=[]
334
- def open1b(inp_text,inp_wav_dir,exp_name,gpu_numbers,ssl_pretrained_dir):
335
- global ps1b
336
- if (ps1b == []):
337
- config={
338
- "inp_text":inp_text,
339
- "inp_wav_dir":inp_wav_dir,
340
- "exp_name":exp_name,
341
- "opt_dir":"%s/%s"%(exp_root,exp_name),
342
- "cnhubert_base_dir":ssl_pretrained_dir,
343
- "is_half": str(is_half)
344
- }
345
- gpu_names=gpu_numbers.split("-")
346
- all_parts=len(gpu_names)
347
- for i_part in range(all_parts):
348
- config.update(
349
- {
350
- "i_part": str(i_part),
351
- "all_parts": str(all_parts),
352
- "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
353
- }
354
- )
355
- os.environ.update(config)
356
- cmd = '"%s" GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py'%python_exec
357
- print(cmd)
358
- p = Popen(cmd, shell=True)
359
- ps1b.append(p)
360
- yield "SSL提取进程执行中", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
361
- for p in ps1b:
362
- p.wait()
363
- ps1b=[]
364
- yield "SSL提取进程结束",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
365
- else:
366
- yield "已有正在进行的SSL提取任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
367
-
368
- def close1b():
369
- global ps1b
370
- if (ps1b != []):
371
- for p1b in ps1b:
372
- try:
373
- kill_process(p1b.pid)
374
- except:
375
- traceback.print_exc()
376
- ps1b=[]
377
- return "已终止所有1b进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
378
- '''
379
- inp_text= os.environ.get("inp_text")
380
- exp_name= os.environ.get("exp_name")
381
- i_part= os.environ.get("i_part")
382
- all_parts= os.environ.get("all_parts")
383
- os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
384
- opt_dir= os.environ.get("opt_dir")
385
- pretrained_s2G= os.environ.get("pretrained_s2G")
386
- '''
387
- ps1c=[]
388
- def open1c(inp_text,exp_name,gpu_numbers,pretrained_s2G_path):
389
- global ps1c
390
- if (ps1c == []):
391
- config={
392
- "inp_text":inp_text,
393
- "exp_name":exp_name,
394
- "opt_dir":"%s/%s"%(exp_root,exp_name),
395
- "pretrained_s2G":pretrained_s2G_path,
396
- "s2config_path":"GPT_SoVITS/configs/s2.json",
397
- "is_half": str(is_half)
398
- }
399
- gpu_names=gpu_numbers.split("-")
400
- all_parts=len(gpu_names)
401
- for i_part in range(all_parts):
402
- config.update(
403
- {
404
- "i_part": str(i_part),
405
- "all_parts": str(all_parts),
406
- "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
407
- }
408
- )
409
- os.environ.update(config)
410
- cmd = '"%s" GPT_SoVITS/prepare_datasets/3-get-semantic.py'%python_exec
411
- print(cmd)
412
- p = Popen(cmd, shell=True)
413
- ps1c.append(p)
414
- yield "语义token提取进程执行中", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
415
- for p in ps1c:
416
- p.wait()
417
- ps1c=[]
418
- yield "语义token提取进程结束",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
419
- else:
420
- yield "已有正在进行的语义token提取任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
421
-
422
- def close1c():
423
- global ps1c
424
- if (ps1c != []):
425
- for p1c in ps1c:
426
- try:
427
- kill_process(p1c.pid)
428
- except:
429
- traceback.print_exc()
430
- ps1c=[]
431
- return "已终止所有语义token进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
432
- #####inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,cnhubert_base_dir,pretrained_s2G
433
- ps1abc=[]
434
- def open1abc(inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,ssl_pretrained_dir,pretrained_s2G_path):
435
- global ps1abc
436
- if (ps1abc == []):
437
- opt_dir="%s/%s"%(exp_root,exp_name)
438
- try:
439
- #############################1a
440
- path_text="%s/2-name2text.txt" % opt_dir
441
- if(os.path.exists(path_text)==False):
442
- config={
443
- "inp_text":inp_text,
444
- "inp_wav_dir":inp_wav_dir,
445
- "exp_name":exp_name,
446
- "opt_dir":opt_dir,
447
- "bert_pretrained_dir":bert_pretrained_dir,
448
- "is_half": str(is_half)
449
- }
450
- gpu_names=gpu_numbers1a.split("-")
451
- all_parts=len(gpu_names)
452
- for i_part in range(all_parts):
453
- config.update(
454
- {
455
- "i_part": str(i_part),
456
- "all_parts": str(all_parts),
457
- "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
458
- }
459
- )
460
- os.environ.update(config)
461
- cmd = '"%s" GPT_SoVITS/prepare_datasets/1-get-text.py'%python_exec
462
- print(cmd)
463
- p = Popen(cmd, shell=True)
464
- ps1abc.append(p)
465
- yield "进度:1a-ing", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
466
- for p in ps1abc:p.wait()
467
-
468
- opt = []
469
- for i_part in range(all_parts):#txt_path="%s/2-name2text-%s.txt"%(opt_dir,i_part)
470
- txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
471
- with open(txt_path, "r",encoding="utf8") as f:
472
- opt += f.read().strip("\n").split("\n")
473
- os.remove(txt_path)
474
- with open(path_text, "w",encoding="utf8") as f:
475
- f.write("\n".join(opt) + "\n")
476
-
477
- yield "进度:1a-done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
478
- ps1abc=[]
479
- #############################1b
480
- config={
481
- "inp_text":inp_text,
482
- "inp_wav_dir":inp_wav_dir,
483
- "exp_name":exp_name,
484
- "opt_dir":opt_dir,
485
- "cnhubert_base_dir":ssl_pretrained_dir,
486
- }
487
- gpu_names=gpu_numbers1Ba.split("-")
488
- all_parts=len(gpu_names)
489
- for i_part in range(all_parts):
490
- config.update(
491
- {
492
- "i_part": str(i_part),
493
- "all_parts": str(all_parts),
494
- "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
495
- }
496
- )
497
- os.environ.update(config)
498
- cmd = '"%s" GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py'%python_exec
499
- print(cmd)
500
- p = Popen(cmd, shell=True)
501
- ps1abc.append(p)
502
- yield "进度:1a-done, 1b-ing", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
503
- for p in ps1abc:p.wait()
504
- yield "进度:1a1b-done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
505
- ps1abc=[]
506
- #############################1c
507
- path_semantic = "%s/6-name2semantic.tsv" % opt_dir
508
- if(os.path.exists(path_semantic)==False):
509
- config={
510
- "inp_text":inp_text,
511
- "exp_name":exp_name,
512
- "opt_dir":opt_dir,
513
- "pretrained_s2G":pretrained_s2G_path,
514
- "s2config_path":"GPT_SoVITS/configs/s2.json",
515
- }
516
- gpu_names=gpu_numbers1c.split("-")
517
- all_parts=len(gpu_names)
518
- for i_part in range(all_parts):
519
- config.update(
520
- {
521
- "i_part": str(i_part),
522
- "all_parts": str(all_parts),
523
- "_CUDA_VISIBLE_DEVICES": gpu_names[i_part],
524
- }
525
- )
526
- os.environ.update(config)
527
- cmd = '"%s" GPT_SoVITS/prepare_datasets/3-get-semantic.py'%python_exec
528
- print(cmd)
529
- p = Popen(cmd, shell=True)
530
- ps1abc.append(p)
531
- yield "进度:1a1b-done, 1cing", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
532
- for p in ps1abc:p.wait()
533
-
534
- opt = ["item_name semantic_audio"]
535
- for i_part in range(all_parts):
536
- semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
537
- with open(semantic_path, "r",encoding="utf8") as f:
538
- opt += f.read().strip("\n").split("\n")
539
- os.remove(semantic_path)
540
- with open(path_semantic, "w",encoding="utf8") as f:
541
- f.write("\n".join(opt) + "\n")
542
- yield "进度:all-done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
543
- ps1abc = []
544
- yield "一键三连进程结束", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
545
- except:
546
- traceback.print_exc()
547
- close1abc()
548
- yield "一键三连中途报错", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
549
- else:
550
- yield "已有正在进行的一键三连任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
551
-
552
- def close1abc():
553
- global ps1abc
554
- if (ps1abc != []):
555
- for p1abc in ps1abc:
556
- try:
557
- kill_process(p1abc.pid)
558
- except:
559
- traceback.print_exc()
560
- ps1abc=[]
561
- return "已终止所有一键三连进程", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
562
-
563
- with gr.Blocks(title="GPT-SoVITS WebUI") as app:
564
- gr.Markdown(
565
- value=
566
- "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
567
- )
568
- with gr.Tabs():
569
- with gr.TabItem("0-前置数据集获取工具"):#提前随机切片防止uvr5爆内存->uvr5->slicer->asr->打标
570
- gr.Markdown(value="0a-UVR5人声伴奏分离&去混响去延迟工具")
571
- with gr.Row():
572
- if_uvr5 = gr.Checkbox(label="是否开启UVR5-WebUI",show_label=True)
573
- uvr5_info = gr.Textbox(label="UVR5进程输出信息")
574
- gr.Markdown(value="0b-语音切分工具")
575
- with gr.Row():
576
- with gr.Row():
577
- slice_inp_path=gr.Textbox(label="音频自动切分输入路径,可文件可文件夹",value="")
578
- slice_opt_root=gr.Textbox(label="切分后的子音频的输出根目录",value="output/slicer_opt")
579
- threshold=gr.Textbox(label="threshold:音量小于这个值视作静音的备选切割点",value="-34")
580
- min_length=gr.Textbox(label="min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值",value="4000")
581
- min_interval=gr.Textbox(label="min_interval:最短切割间隔",value="300")
582
- hop_size=gr.Textbox(label="hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)",value="10")
583
- max_sil_kept=gr.Textbox(label="max_sil_kept:切完后静音最多留多长",value="500")
584
- with gr.Row():
585
- open_slicer_button=gr.Button("开启语音切割", variant="primary",visible=True)
586
- close_slicer_button=gr.Button("终止语音切割", variant="primary",visible=False)
587
- _max=gr.Slider(minimum=0,maximum=1,step=0.05,label="max:归一化后最大值多少",value=0.9,interactive=True)
588
- alpha=gr.Slider(minimum=0,maximum=1,step=0.05,label="alpha_mix:混多少比例归一化后音频进来",value=0.25,interactive=True)
589
- n_process=gr.Slider(minimum=1,maximum=n_cpu,step=1,label="切割使用的进程数",value=4,interactive=True)
590
- slicer_info = gr.Textbox(label="语音切割进程输出信息")
591
- gr.Markdown(value="0c-中文批量离线ASR工具")
592
- with gr.Row():
593
- open_asr_button = gr.Button("开启离线批量ASR", variant="primary",visible=True)
594
- close_asr_button = gr.Button("终止ASR进程", variant="primary",visible=False)
595
- asr_inp_dir = gr.Textbox(
596
- label="批量ASR(中文only)输入文件夹路径",
597
- value="D:\\RVC1006\\GPT-SoVITS\\raw\\xxx",
598
- interactive=True,
599
- )
600
- asr_info = gr.Textbox(label="ASR进程输出信息")
601
- gr.Markdown(value="0d-语音文本校对标注工具")
602
- with gr.Row():
603
- if_label = gr.Checkbox(label="是否开启打标WebUI",show_label=True)
604
- path_list = gr.Textbox(
605
- label="打标数据标注文件路径",
606
- value="D:\\RVC1006\\GPT-SoVITS\\raw\\xxx.list",
607
- interactive=True,
608
- )
609
- label_info = gr.Textbox(label="打标工具进程输出信息")
610
- if_label.change(change_label, [if_label,path_list], [label_info])
611
- if_uvr5.change(change_uvr5, [if_uvr5], [uvr5_info])
612
- open_asr_button.click(open_asr, [asr_inp_dir], [asr_info,open_asr_button,close_asr_button])
613
- close_asr_button.click(close_asr, [], [asr_info,open_asr_button,close_asr_button])
614
- open_slicer_button.click(open_slice, [slice_inp_path,slice_opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_max,alpha,n_process], [slicer_info,open_slicer_button,close_slicer_button])
615
- close_slicer_button.click(close_slice, [], [slicer_info,open_slicer_button,close_slicer_button])
616
- with gr.TabItem("1-GPT-SoVITS-TTS"):
617
- with gr.Row():
618
- exp_name = gr.Textbox(label="*实验/模型名", value="xxx", interactive=True)
619
- gpu_info = gr.Textbox(label="显卡信息", value=gpu_info, visible=True, interactive=False)
620
- pretrained_s2G = gr.Textbox(label="预训练的SoVITS-G模型路径", value="GPT_SoVITS/pretrained_models/s2G488k.pth", interactive=True)
621
- pretrained_s2D = gr.Textbox(label="预训练的SoVITS-D模型路径", value="GPT_SoVITS/pretrained_models/s2D488k.pth", interactive=True)
622
- pretrained_s1 = gr.Textbox(label="预训练的GPT模型路径", value="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", interactive=True)
623
- with gr.TabItem("1A-训练集格式化工具"):
624
- gr.Markdown(value="输出logs/实验名目录下应有23456开头的文件和文件夹")
625
- with gr.Row():
626
- inp_text = gr.Textbox(label="*文本标注文件",value=r"D:\RVC1006\GPT-SoVITS\raw\xxx.list",interactive=True)
627
- inp_wav_dir = gr.Textbox(label="*训练集音频文件目录",value=r"D:\RVC1006\GPT-SoVITS\raw\xxx",interactive=True)
628
- gr.Markdown(value="1Aa-文本内容")
629
- with gr.Row():
630
- gpu_numbers1a = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程",value="%s-%s"%(gpus,gpus),interactive=True)
631
- bert_pretrained_dir = gr.Textbox(label="预训练的中文BERT模型路径",value="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",interactive=False)
632
- button1a_open = gr.Button("开启文本获取", variant="primary",visible=True)
633
- button1a_close = gr.Button("终止文本获取进程", variant="primary",visible=False)
634
- info1a=gr.Textbox(label="文本进程输出信息")
635
- gr.Markdown(value="1Ab-SSL自监督特征提取")
636
- with gr.Row():
637
- gpu_numbers1Ba = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程",value="%s-%s"%(gpus,gpus),interactive=True)
638
- cnhubert_base_dir = gr.Textbox(label="预训练的SSL模型路径",value="GPT_SoVITS/pretrained_models/chinese-hubert-base",interactive=False)
639
- button1b_open = gr.Button("开启SSL提取", variant="primary",visible=True)
640
- button1b_close = gr.Button("终止SSL提取进程", variant="primary",visible=False)
641
- info1b=gr.Textbox(label="SSL进程输出信息")
642
- gr.Markdown(value="1Ac-语义token提取")
643
- with gr.Row():
644
- gpu_numbers1c = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程",value="%s-%s"%(gpus,gpus),interactive=True)
645
- button1c_open = gr.Button("开启语义token提取", variant="primary",visible=True)
646
- button1c_close = gr.Button("终止语义token提取进程", variant="primary",visible=False)
647
- info1c=gr.Textbox(label="语义token提取进程输出信息")
648
- gr.Markdown(value="1Aabc-训练集格式化一键三连")
649
- with gr.Row():
650
- button1abc_open = gr.Button("开启一键三连", variant="primary",visible=True)
651
- button1abc_close = gr.Button("终止一键三连", variant="primary",visible=False)
652
- info1abc=gr.Textbox(label="一键三连进程输出信息")
653
- button1a_open.click(open1a, [inp_text,inp_wav_dir,exp_name,gpu_numbers1a,bert_pretrained_dir], [info1a,button1a_open,button1a_close])
654
- button1a_close.click(close1a, [], [info1a,button1a_open,button1a_close])
655
- button1b_open.click(open1b, [inp_text,inp_wav_dir,exp_name,gpu_numbers1Ba,cnhubert_base_dir], [info1b,button1b_open,button1b_close])
656
- button1b_close.click(close1b, [], [info1b,button1b_open,button1b_close])
657
- button1c_open.click(open1c, [inp_text,exp_name,gpu_numbers1c,pretrained_s2G], [info1c,button1c_open,button1c_close])
658
- button1c_close.click(close1c, [], [info1c,button1c_open,button1c_close])
659
- button1abc_open.click(open1abc, [inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,cnhubert_base_dir,pretrained_s2G], [info1abc,button1abc_open,button1abc_close])
660
- button1abc_close.click(close1abc, [], [info1abc,button1abc_open,button1abc_close])
661
- with gr.TabItem("1B-微调训练"):
662
- gr.Markdown(value="1Ba-SoVITS训练。用于分享的模型文件输出在SoVITS_weights下。")
663
- with gr.Row():
664
- batch_size = gr.Slider(minimum=1,maximum=40,step=1,label=i18n("每张显卡的batch_size"),value=default_batch_size,interactive=True)
665
- total_epoch = gr.Slider(minimum=2,maximum=100,step=1,label=i18n("总训练轮数total_epoch,不建议太高"),value=10,interactive=True)
666
- text_low_lr_rate = gr.Slider(minimum=0.2,maximum=0.6,step=0.05,label="文本模块学习率权重",value=0.4,interactive=True)
667
- save_every_epoch = gr.Slider(minimum=1,maximum=50,step=1,label=i18n("保存频率save_every_epoch"),value=5,interactive=True)
668
- if_save_latest = gr.Checkbox(label=i18n("是否仅保存最新的ckpt文件以节省硬盘空间"), value=True, interactive=True, show_label=True)
669
- if_save_every_weights = gr.Checkbox(label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"), value=True, interactive=True, show_label=True)
670
- gpu_numbers1Ba = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程", value="%s" % (gpus), interactive=True)
671
- with gr.Row():
672
- button1Ba_open = gr.Button("开启SoVITS训练", variant="primary",visible=True)
673
- button1Ba_close = gr.Button("终止SoVITS训练", variant="primary",visible=False)
674
- info1Ba=gr.Textbox(label="SoVITS训练进程输出信息")
675
- gr.Markdown(value="1Bb-GPT训练。用于分享的模型文件输出在GPT_weights下。")
676
- with gr.Row():
677
- batch_size1Bb = gr.Slider(minimum=1,maximum=40,step=1,label=i18n("每张显卡的batch_size"),value=default_batch_size,interactive=True)
678
- total_epoch1Bb = gr.Slider(minimum=2,maximum=200,step=1,label=i18n("总训练轮数total_epoch"),value=15,interactive=True)
679
- if_save_latest1Bb = gr.Checkbox(label=i18n("是否仅保存最新的ckpt文件以节省硬盘空间"), value=True, interactive=True, show_label=True)
680
- if_save_every_weights1Bb = gr.Checkbox(label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"), value=True, interactive=True, show_label=True)
681
- save_every_epoch1Bb = gr.Slider(minimum=1,maximum=50,step=1,label=i18n("保存频率save_every_epoch"),value=5,interactive=True)
682
- gpu_numbers1Bb = gr.Textbox(label="GPU卡号以-分割,每个卡号一个进程", value="%s" % (gpus), interactive=True)
683
- with gr.Row():
684
- button1Bb_open = gr.Button("开启GPT训练", variant="primary",visible=True)
685
- button1Bb_close = gr.Button("终止GPT训练", variant="primary",visible=False)
686
- info1Bb=gr.Textbox(label="GPT训练进程输出信息")
687
- button1Ba_open.click(open1Ba, [batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D], [info1Ba,button1Ba_open,button1Ba_close])
688
- button1Ba_close.click(close1Ba, [], [info1Ba,button1Ba_open,button1Ba_close])
689
- button1Bb_open.click(open1Bb, [batch_size1Bb,total_epoch1Bb,exp_name,if_save_latest1Bb,if_save_every_weights1Bb,save_every_epoch1Bb,gpu_numbers1Bb,pretrained_s1], [info1Bb,button1Bb_open,button1Bb_close])
690
- button1Bb_close.click(close1Bb, [], [info1Bb,button1Bb_open,button1Bb_close])
691
- with gr.TabItem("1C-推理"):
692
- gr.Markdown(value="选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。")
693
- with gr.Row():
694
- GPT_dropdown = gr.Dropdown(label="*GPT模型列表", choices=sorted(GPT_names),value=pretrained_gpt_name)
695
- SoVITS_dropdown = gr.Dropdown(label="*SoVITS模型列表", choices=sorted(SoVITS_names),value=pretrained_sovits_name)
696
- gpu_number_1C=gr.Textbox(label="GPU卡号,只能填1个整数", value=gpus, interactive=True)
697
- refresh_button = gr.Button("刷新模型路径", variant="primary")
698
- refresh_button.click(fn=change_choices,inputs=[],outputs=[SoVITS_dropdown,GPT_dropdown])
699
- with gr.Row():
700
- if_tts = gr.Checkbox(label="是否开启TTS推理WebUI", show_label=True)
701
- tts_info = gr.Textbox(label="TTS推理WebUI进程输出信息")
702
- if_tts.change(change_tts_inference, [if_tts,bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown], [tts_info])
703
- with gr.TabItem("2-GPT-SoVITS-变声"):gr.Markdown(value="施工中,请静候佳音")
704
-
705
- '''
706
- os.environ["gpt_path"]=gpt_path
707
- os.environ["sovits_path"]=sovits_path#bert_pretrained_dir
708
- os.environ["cnhubert_base_path"]=cnhubert_base_path#cnhubert_base_dir
709
- os.environ["bert_path"]=bert_path
710
- os.environ["_CUDA_VISIBLE_DEVICES"]=gpu_number
711
- '''
712
-
713
- app.queue(concurrency_count=511, max_size=1022).launch(
714
- share=True,
715
- server_name="0.0.0.0",
716
- inbrowser=True,
717
- server_port=7890,
718
- quiet=True,
719
- )
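Note: the deleted webui.py terminates helper processes with a Windows-only taskkill call and leaves a TODO for Linux (kill -9). A generic cross-platform sketch, offered as an assumption about one possible approach rather than the project's actual fix, is shown below; it presumes each child is launched with Popen(..., start_new_session=True) so it owns its process group on POSIX:

import os
import signal
import subprocess
import sys

def kill_process_tree(pid: int) -> None:
    # Terminate a process and its children on both Windows and POSIX.
    if sys.platform == "win32":
        subprocess.run(["taskkill", "/t", "/f", "/pid", str(pid)], check=False)
    else:
        try:
            os.killpg(os.getpgid(pid), signal.SIGKILL)  # requires the child to lead its own process group
        except ProcessLookupError:
            pass  # process already gone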
 
GPT-SoVITS-models/GPT-SoVITS/.ipynb_checkpoints/启动webui-checkpoint.sh DELETED
@@ -1,2 +0,0 @@
- #!/bin/bash
- python /root/autodl-tmp/workdir/GPT-SoVITS/webui.py
 
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/.ipynb_checkpoints/inference_webui-checkpoint.py DELETED
@@ -1,270 +0,0 @@
1
- import os
2
- gpt_path=os.environ.get("gpt_path","pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
3
- sovits_path=os.environ.get("sovits_path","pretrained_models/s2G488k.pth")
4
- cnhubert_base_path=os.environ.get("cnhubert_base_path","pretrained_models/chinese-hubert-base")
5
- bert_path=os.environ.get("bert_path","pretrained_models/chinese-roberta-wwm-ext-large")
6
- if("_CUDA_VISIBLE_DEVICES"in os.environ):
7
- os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"]
8
- is_half=eval(os.environ.get("is_half","True"))
9
- import gradio as gr
10
- from transformers import AutoModelForMaskedLM, AutoTokenizer
11
- import sys,torch,numpy as np
12
- from pathlib import Path
13
- import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile
14
- # torch.backends.cuda.sdp_kernel("flash")
15
- # torch.backends.cuda.enable_flash_sdp(True)
16
- # torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0
17
- # torch.backends.cuda.enable_math_sdp(True)
18
- from random import shuffle
19
- from AR.utils import get_newest_ckpt
20
- from glob import glob
21
- from tqdm import tqdm
22
- from feature_extractor import cnhubert
23
- cnhubert.cnhubert_base_path=cnhubert_base_path
24
- from io import BytesIO
25
- from module.models import SynthesizerTrn
26
- from AR.models.t2s_lightning_module import Text2SemanticLightningModule
27
- from AR.utils.io import load_yaml_config
28
- from text import cleaned_text_to_sequence
29
- from text.cleaner import text_to_sequence, clean_text
30
- from time import time as ttime
31
- from module.mel_processing import spectrogram_torch
32
- from my_utils import load_audio
33
-
34
- device="cuda"
35
- tokenizer = AutoTokenizer.from_pretrained(bert_path)
36
- bert_model=AutoModelForMaskedLM.from_pretrained(bert_path)
37
- if(is_half==True):bert_model=bert_model.half().to(device)
38
- else:bert_model=bert_model.to(device)
39
- # bert_model=bert_model.to(device)
40
- def get_bert_feature(text, word2ph):
41
- with torch.no_grad():
42
- inputs = tokenizer(text, return_tensors="pt")
43
- for i in inputs:
44
- inputs[i] = inputs[i].to(device)#####输入是long不用管精度问题,精度随bert_model
45
- res = bert_model(**inputs, output_hidden_states=True)
46
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
47
- assert len(word2ph) == len(text)
48
- phone_level_feature = []
49
- for i in range(len(word2ph)):
50
- repeat_feature = res[i].repeat(word2ph[i], 1)
51
- phone_level_feature.append(repeat_feature)
52
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
53
- # if(is_half==True):phone_level_feature=phone_level_feature.half()
54
- return phone_level_feature.T
55
-
56
- n_semantic = 1024
57
- dict_s2=torch.load(sovits_path,map_location="cpu")
58
- hps=dict_s2["config"]
59
- class DictToAttrRecursive:
60
- def __init__(self, input_dict):
61
- for key, value in input_dict.items():
62
- if isinstance(value, dict):
63
- # 如果值是字典,递归调用构造函数
64
- setattr(self, key, DictToAttrRecursive(value))
65
- else:
66
- setattr(self, key, value)
67
-
68
- hps = DictToAttrRecursive(hps)
69
- hps.model.semantic_frame_rate="25hz"
70
- dict_s1=torch.load(gpt_path,map_location="cpu")
71
- config=dict_s1["config"]
72
- ssl_model=cnhubert.get_model()
73
- if(is_half==True):ssl_model=ssl_model.half().to(device)
74
- else:ssl_model=ssl_model.to(device)
75
-
76
- vq_model = SynthesizerTrn(
77
- hps.data.filter_length // 2 + 1,
78
- hps.train.segment_size // hps.data.hop_length,
79
- n_speakers=hps.data.n_speakers,
80
- **hps.model)
81
- if(is_half==True):vq_model=vq_model.half().to(device)
82
- else:vq_model=vq_model.to(device)
83
- vq_model.eval()
84
- print(vq_model.load_state_dict(dict_s2["weight"],strict=False))
85
- hz = 50
86
- max_sec = config['data']['max_sec']
87
- # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
88
- t2s_model = Text2SemanticLightningModule(config,"ojbk",is_train=False)
89
- t2s_model.load_state_dict(dict_s1["weight"])
90
- if(is_half==True):t2s_model=t2s_model.half()
91
- t2s_model=t2s_model.to(device)
92
- t2s_model.eval()
93
- total = sum([param.nelement() for param in t2s_model.parameters()])
94
- print("Number of parameter: %.2fM" % (total / 1e6))
95
- def get_spepc(hps, filename):
96
- audio=load_audio(filename,int(hps.data.sampling_rate))
97
- audio=torch.FloatTensor(audio)
98
- audio_norm = audio
99
- audio_norm = audio_norm.unsqueeze(0)
100
- spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
101
- return spec
102
-
103
- dict_language={
104
- "中文":"zh",
105
- "英文":"en",
106
- "日文":"ja"
107
- }
108
- def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language):
109
- t0 = ttime()
110
- prompt_text=prompt_text.strip("\n")
111
- prompt_language,text=prompt_language,text.strip("\n")
112
- with torch.no_grad():
113
- wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
114
- wav16k = torch.from_numpy(wav16k)
115
- if(is_half==True):wav16k=wav16k.half().to(device)
116
- else:wav16k=wav16k.to(device)
117
- ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
118
- codes = vq_model.extract_latent(ssl_content)
119
- prompt_semantic = codes[0, 0]
120
- t1 = ttime()
121
- prompt_language=dict_language[prompt_language]
122
- text_language=dict_language[text_language]
123
- phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
124
- phones1=cleaned_text_to_sequence(phones1)
125
- texts=text.split("\n")
126
- audio_opt = []
127
- zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32)
128
- for text in texts:
129
- phones2, word2ph2, norm_text2 = clean_text(text, text_language)
130
- phones2 = cleaned_text_to_sequence(phones2)
131
- if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1)
132
- else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
133
- if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2)
134
- else:bert2 = torch.zeros((1024, len(phones2))).to(bert1)
135
- bert = torch.cat([bert1, bert2], 1)
136
-
137
- all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
138
- bert = bert.to(device).unsqueeze(0)
139
- all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
140
- prompt = prompt_semantic.unsqueeze(0).to(device)
141
- t2 = ttime()
142
- with torch.no_grad():
143
- # pred_semantic = t2s_model.model.infer(
144
- pred_semantic,idx = t2s_model.model.infer_panel(
145
- all_phoneme_ids,
146
- all_phoneme_len,
147
- prompt,
148
- bert,
149
- # prompt_phone_len=ph_offset,
150
- top_k=config['inference']['top_k'],
151
- early_stop_num=hz * max_sec)
152
- t3 = ttime()
153
- # print(pred_semantic.shape,idx)
154
- pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
155
- refer = get_spepc(hps, ref_wav_path)#.to(device)
156
- if(is_half==True):refer=refer.half().to(device)
157
- else:refer=refer.to(device)
158
- # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
159
- audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分
160
- audio_opt.append(audio)
161
- audio_opt.append(zero_wav)
162
- t4 = ttime()
163
- print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
164
- yield hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16)
165
-
166
-
167
- splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}#不考虑省略号
168
- def split(todo_text):
169
- todo_text = todo_text.replace("……", "。").replace("——", ",")
170
- if (todo_text[-1] not in splits): todo_text += "。"
171
- i_split_head = i_split_tail = 0
172
- len_text = len(todo_text)
173
- todo_texts = []
174
- while (1):
175
- if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
176
- if (todo_text[i_split_head] in splits):
177
- i_split_head += 1
178
- todo_texts.append(todo_text[i_split_tail:i_split_head])
179
- i_split_tail = i_split_head
180
- else:
181
- i_split_head += 1
182
- return todo_texts
183
- def cut1(inp):
184
- inp=inp.strip("\n")
185
- inps=split(inp)
186
- split_idx=list(range(0,len(inps),5))
187
- split_idx[-1]=None
188
- if(len(split_idx)>1):
189
- opts=[]
190
- for idx in range(len(split_idx)-1):
191
- opts.append("".join(inps[split_idx[idx]:split_idx[idx+1]]))
192
- else:
193
- opts=[inp]
194
- return "\n".join(opts)
195
-
196
- def cut2(inp):
197
- inp=inp.strip("\n")
198
- inps=split(inp)
199
- if(len(inps)<2):return [inp]
200
- opts=[]
201
- summ=0
202
- tmp_str=""
203
- for i in range(len(inps)):
204
- summ+=len(inps[i])
205
- tmp_str+=inps[i]
206
- if(summ>50):
207
- summ=0
208
- opts.append(tmp_str)
209
- tmp_str=""
210
- if(tmp_str!=""):opts.append(tmp_str)
211
- if(len(opts[-1])<50):##如果最后一个太短了,和前一个合一起
212
- opts[-2]=opts[-2]+opts[-1]
213
- opts=opts[:-1]
214
- return "\n".join(opts)
215
-
216
- def cut3(inp):
217
- inp=inp.strip("\n")
218
- return "\n".join(["%s。"%item for item in inp.strip("。").split("。")])
219
-
220
- with gr.Blocks(title="GPT-SoVITS WebUI") as app:
221
- gr.Markdown(
222
- value=
223
- "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
224
- )
225
- # with gr.Tabs():
226
- # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
227
- with gr.Group():
228
- gr.Markdown(
229
- value=
230
- "*请上传并填写参考信息"
231
- )
232
- with gr.Row():
233
- inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
234
- prompt_text= gr.Textbox(label="参考音频的文本",value="")
235
- prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"])
236
- gr.Markdown(
237
- value=
238
- "*请填写需要合成的目标文本"
239
- )
240
- with gr.Row():
241
- text=gr.Textbox(label="需要合成的文本",value="")
242
- text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"])
243
- inference_button=gr.Button("合成语音", variant="primary")
244
- output = gr.Audio(label="输出的语音")
245
- inference_button.click(get_tts_wav, [inp_ref, prompt_text,prompt_language, text,text_language], [output])
246
-
247
- gr.Markdown(
248
- value=
249
- "文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
250
- )
251
- with gr.Row():
252
- text_inp=gr.Textbox(label="需要合成的切分前文本",value="")
253
- button1 = gr.Button("凑五句一切", variant="primary")
254
- button2 = gr.Button("凑50字一切", variant="primary")
255
- button3 = gr.Button("按中文句号。切", variant="primary")
256
- text_opt = gr.Textbox(label="切分后文本", value="")
257
- button1.click(cut1,[text_inp],[text_opt])
258
- button2.click(cut2,[text_inp],[text_opt])
259
- button3.click(cut3,[text_inp],[text_opt])
260
- gr.Markdown(
261
- value=
262
- "后续将支持混合语种编码文本输入。"
263
- )
264
-
265
- app.queue(concurrency_count=511, max_size=1022).launch(
266
- server_name="0.0.0.0",
267
- inbrowser=True,
268
- server_port=6006,
269
- quiet=True,
270
- )
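Note: the deleted inference_webui.py includes small text-splitting helpers (split, cut1, cut2, cut3) that break long input at CJK/Latin punctuation and regroup it before synthesis. The sketch below restates the core of split plus the roughly-50-character grouping of cut2 in simplified form (it omits the original's merge of a short trailing chunk); it is an illustration, not the deleted code verbatim:

# Break text at punctuation, then pack consecutive pieces into ~50-character chunks.
SPLITS = set("，。？！,.?!~:：—…")

def split_sentences(text: str) -> list[str]:
    pieces, start = [], 0
    for i, ch in enumerate(text):
        if ch in SPLITS:
            pieces.append(text[start:i + 1])
            start = i + 1
    if start < len(text):  # trailing text without closing punctuation
        pieces.append(text[start:])
    return pieces

def pack_chunks(text: str, limit: int = 50) -> list[str]:
    chunks, cur = [], ""
    for piece in split_sentences(text):
        cur += piece
        if len(cur) > limit:
            chunks.append(cur)
            cur = ""
    if cur:
        chunks.append(cur)
    return chunks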
 
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/__init__.py DELETED
File without changes
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/__init__.py DELETED
File without changes
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/bucket_sampler.py DELETED
@@ -1,157 +0,0 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/bucketsampler.py
2
- import itertools
3
- import math
4
- import random
5
- from random import shuffle
6
- from typing import Iterator
7
- from typing import Optional
8
- from typing import TypeVar
9
-
10
- import torch
11
- import torch.distributed as dist
12
- from torch.utils.data import Dataset
13
- from torch.utils.data import Sampler
14
-
15
- __all__ = [
16
- "DistributedBucketSampler",
17
- ]
18
-
19
- T_co = TypeVar('T_co', covariant=True)
20
-
21
-
22
- class DistributedBucketSampler(Sampler[T_co]):
23
- r"""
24
- sort the dataset wrt. input length
25
- divide samples into buckets
26
- sort within buckets
27
- divide buckets into batches
28
- sort batches
29
- """
30
-
31
- def __init__(self,
32
- dataset: Dataset,
33
- num_replicas: Optional[int]=None,
34
- rank: Optional[int]=None,
35
- shuffle: bool=True,
36
- seed: int=0,
37
- drop_last: bool=False,
38
- batch_size: int=32) -> None:
39
- if num_replicas is None:
40
- if not dist.is_available():
41
- raise RuntimeError(
42
- "Requires distributed package to be available")
43
- num_replicas = dist.get_world_size()
44
- if rank is None:
45
- if not dist.is_available():
46
- raise RuntimeError(
47
- "Requires distributed package to be available")
48
- rank = dist.get_rank()
49
- torch.cuda.set_device(rank)
50
- if rank >= num_replicas or rank < 0:
51
- raise ValueError("Invalid rank {}, rank should be in the interval"
52
- " [0, {}]".format(rank, num_replicas - 1))
53
- self.dataset = dataset
54
- self.num_replicas = num_replicas
55
- self.rank = rank
56
- self.epoch = 0
57
- self.drop_last = drop_last
58
- # If the dataset length is evenly divisible by # of replicas, then there
59
- # is no need to drop any data, since the dataset will be split equally.
60
- if self.drop_last and len(
61
- self.
62
- dataset) % self.num_replicas != 0: # type: ignore[arg-type]
63
- # Split to nearest available length that is evenly divisible.
64
- # This is to ensure each rank receives the same amount of data when
65
- # using this Sampler.
66
- self.num_samples = math.ceil(
67
- (len(self.dataset) - self.num_replicas) /
68
- self.num_replicas # type: ignore[arg-type]
69
- )
70
- else:
71
- self.num_samples = math.ceil(
72
- len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
73
- self.total_size = self.num_samples * self.num_replicas
74
- self.shuffle = shuffle
75
- self.seed = seed
76
- self.batch_size = batch_size
77
- self.id_with_length = self._get_sample_lengths()
78
- self.id_buckets = self.make_buckets(bucket_width=2.0)
79
-
80
- def _get_sample_lengths(self):
81
- id_with_lengths = []
82
- for i in range(len(self.dataset)):
83
- id_with_lengths.append((i, self.dataset.get_sample_length(i)))
84
- id_with_lengths.sort(key=lambda x: x[1])
85
- return id_with_lengths
86
-
87
- def make_buckets(self, bucket_width: float=2.0):
88
- buckets = []
89
- cur = []
90
- max_sec = bucket_width
91
- for id, sec in self.id_with_length:
92
- if sec < max_sec:
93
- cur.append(id)
94
- else:
95
- buckets.append(cur)
96
- cur = [id]
97
- max_sec += bucket_width
98
- if len(cur) > 0:
99
- buckets.append(cur)
100
- return buckets
101
-
102
- def __iter__(self) -> Iterator[T_co]:
103
- if self.shuffle:
104
- # deterministically shuffle based on epoch and seed
105
- g = torch.Generator()
106
- g.manual_seed(self.seed + self.epoch)
107
- random.seed(self.epoch + self.seed)
108
- shuffled_bucket = []
109
- for buc in self.id_buckets:
110
- buc_copy = buc.copy()
111
- shuffle(buc_copy)
112
- shuffled_bucket.append(buc_copy)
113
- grouped_batch_size = self.batch_size * self.num_replicas
114
- shuffled_bucket = list(itertools.chain(*shuffled_bucket))
115
- n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
116
- batches = [
117
- shuffled_bucket[b * grouped_batch_size:(b + 1) *
118
- grouped_batch_size] for b in range(n_batch)
119
- ]
120
- shuffle(batches)
121
- indices = list(itertools.chain(*batches))
122
- else:
123
- # type: ignore[arg-type]
124
- indices = list(range(len(self.dataset)))
125
-
126
- if not self.drop_last:
127
- # add extra samples to make it evenly divisible
128
- padding_size = self.total_size - len(indices)
129
- if padding_size <= len(indices):
130
- indices += indices[:padding_size]
131
- else:
132
- indices += (indices * math.ceil(padding_size /
133
- len(indices)))[:padding_size]
134
- else:
135
- # remove tail of data to make it evenly divisible.
136
- indices = indices[:self.total_size]
137
- assert len(indices) == self.total_size
138
-
139
- # subsample
140
- indices = indices[self.rank:self.total_size:self.num_replicas]
141
- assert len(indices) == self.num_samples
142
-
143
- return iter(indices)
144
-
145
- def __len__(self) -> int:
146
- return self.num_samples
147
-
148
- def set_epoch(self, epoch: int) -> None:
149
- r"""
150
- Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
151
- use a different random ordering for each epoch. Otherwise, the next iteration of this
152
- sampler will yield the same ordering.
153
-
154
- Args:
155
- epoch (int): Epoch number.
156
- """
157
- self.epoch = epoch
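As a rough standalone sketch of the bucketing idea the sampler above implements (sort clips by length, bucket in 2-second steps, slice into batches, then shuffle the batches so each batch holds similar-length clips), the snippet below does not import the deleted class; the toy durations and the `batch_size` of 8 are made up for illustration.

```python
import random

# toy (index, duration-in-seconds) pairs standing in for dataset.get_sample_length(i)
lengths = [(i, random.uniform(0.5, 12.0)) for i in range(100)]
lengths.sort(key=lambda x: x[1])  # sort the dataset w.r.t. input length

# group into buckets of width 2 s, mirroring make_buckets(bucket_width=2.0)
buckets, cur, max_sec = [], [], 2.0
for idx, sec in lengths:
    if sec < max_sec:
        cur.append(idx)
    else:
        buckets.append(cur)
        cur = [idx]
        max_sec += 2.0
if cur:
    buckets.append(cur)

# shuffle inside each bucket, flatten, slice into batches, shuffle the batches:
# every batch then holds clips of similar length, so padding waste stays small
batch_size = 8
for b in buckets:
    random.shuffle(b)
flat = [i for b in buckets for i in b]
batches = [flat[k:k + batch_size] for k in range(0, len(flat), batch_size)]
random.shuffle(batches)
print(batches[0])
```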
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/data_module.py DELETED
@@ -1,66 +0,0 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
2
- from pytorch_lightning import LightningDataModule
3
- from AR.data.bucket_sampler import DistributedBucketSampler
4
- from AR.data.dataset import Text2SemanticDataset
5
- from torch.utils.data import DataLoader
6
-
7
-
8
- class Text2SemanticDataModule(LightningDataModule):
9
- def __init__(self, config, train_semantic_path, train_phoneme_path,dev_semantic_path=None, dev_phoneme_path=None):
10
- super().__init__()
11
- self.config = config
12
- self.train_semantic_path = train_semantic_path
13
- self.train_phoneme_path = train_phoneme_path
14
- self.dev_semantic_path = dev_semantic_path
15
- self.dev_phoneme_path = dev_phoneme_path
16
- self.num_workers = self.config['data']['num_workers']
17
-
18
- def prepare_data(self):
19
- pass
20
-
21
- def setup(self, stage=None, output_logs=False):
22
- self._train_dataset = Text2SemanticDataset(
23
- phoneme_path=self.train_phoneme_path,
24
- semantic_path=self.train_semantic_path,
25
- max_sec=self.config['data']['max_sec'],
26
- pad_val=self.config['data']['pad_val'])
27
- self._dev_dataset = self._train_dataset
28
- # self._dev_dataset = Text2SemanticDataset(
29
- # phoneme_path=self.dev_phoneme_path,
30
- # semantic_path=self.dev_semantic_path,
31
- # max_sample=self.config['data']['max_eval_sample'],
32
- # max_sec=self.config['data']['max_sec'],
33
- # pad_val=self.config['data']['pad_val'])
34
-
35
- def train_dataloader(self):
36
- batch_size = self.config['train']['batch_size']
37
- sampler = DistributedBucketSampler(
38
- self._train_dataset, batch_size=batch_size)
39
- return DataLoader(
40
- self._train_dataset,
41
- batch_size=batch_size,
42
- sampler=sampler,
43
- collate_fn=self._train_dataset.collate,
44
- num_workers=self.num_workers,
45
- persistent_workers=True,
46
- prefetch_factor=16
47
- )
48
-
49
- def val_dataloader(self):
50
- return DataLoader(
51
- self._dev_dataset,
52
- batch_size=1,
53
- shuffle=False,
54
- collate_fn=self._train_dataset.collate,
55
- num_workers=max(self.num_workers,12),
56
- persistent_workers=True,
57
- prefetch_factor=16
58
- )
59
-
60
- # is this ever actually used?
61
- def test_dataloader(self):
62
- return DataLoader(
63
- self._dev_dataset,
64
- batch_size=1,
65
- shuffle=False,
66
- collate_fn=self._train_dataset.collate)
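For reference, the data module above only reads a handful of config keys. Below is a minimal sketch of that config shape: the numeric values are placeholders, only the key names come from the code, and the paths in the commented-out call follow the `2-name2text.txt` / `6-name2semantic.tsv` naming mentioned in `dataset.py`.

```python
# minimal config shape read by Text2SemanticDataModule; the values below are placeholders
config = {
    "data": {
        "num_workers": 4,   # DataLoader workers
        "max_sec": 54,      # longest clip kept, in seconds
        "pad_val": 1024,    # padding id for semantic tokens
    },
    "train": {
        "batch_size": 8,    # per-replica batch size handed to DistributedBucketSampler
    },
}

# hypothetical call; the real paths are produced by the data-prep stage of the pipeline
# dm = Text2SemanticDataModule(config,
#                              train_semantic_path="logs/exp/6-name2semantic.tsv",
#                              train_phoneme_path="logs/exp/2-name2text.txt")
```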
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/data/dataset.py DELETED
@@ -1,302 +0,0 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
2
- import pdb
3
- import sys
4
- # sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
5
- import traceback,os
6
- from typing import Dict
7
- from typing import List
8
-
9
- import numpy as np
10
- import pandas as pd
11
- import torch,json
12
- from torch.utils.data import DataLoader
13
- from torch.utils.data import Dataset
14
- from transformers import AutoTokenizer
15
-
16
- from text import cleaned_text_to_sequence
17
- # from config import exp_dir
18
-
19
- def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
20
- seq = sequences[0]
21
- ndim = seq.ndim
22
- if axis < 0:
23
- axis += ndim
24
- dtype = seq.dtype
25
- pad_value = dtype.type(pad_value)
26
- seq_lengths = [seq.shape[axis] for seq in sequences]
27
- max_length = np.max(seq_lengths)
28
-
29
- padded_sequences = []
30
- for seq, length in zip(sequences, seq_lengths):
31
- padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (
32
- ndim - axis - 1)
33
- padded_seq = np.pad(
34
- seq, padding, mode='constant', constant_values=pad_value)
35
- padded_sequences.append(padded_seq)
36
- batch = np.stack(padded_sequences)
37
- return batch
38
-
39
- class Text2SemanticDataset(Dataset):
40
- """dataset class for text tokens to semantic model training."""
41
-
42
- def __init__(self,
43
- phoneme_path: str,
44
- semantic_path: str,
45
- max_sample: int = None,
46
- max_sec: int = 100,
47
- pad_val: int = 1024,
48
- # min value of phoneme/sec
49
- min_ps_ratio: int = 3,
50
- # max value of phoneme/sec
51
- max_ps_ratio: int = 25) -> None:
52
- super().__init__()
53
-
54
- self.semantic_data = pd.read_csv(semantic_path, delimiter='\t', encoding="utf-8")
55
- # get dict
56
- self.path2=phoneme_path#"%s/2-name2text.txt"%exp_dir#phoneme_path
57
- self.path3="%s/3-bert"%(os.path.dirname(phoneme_path))#"%s/3-bert"%exp_dir#bert_dir
58
- self.path6=semantic_path#"%s/6-name2semantic.tsv"%exp_dir#semantic_path
59
- assert os.path.exists(self.path2)
60
- assert os.path.exists(self.path6)
61
- self.phoneme_data={}
62
- with open(self.path2,"r",encoding="utf8")as f:
63
- lines=f.read().strip("\n").split("\n")
64
-
65
- for line in lines:
66
- tmp=line.split("\t")
67
- if(len(tmp)!=4):continue
68
- self.phoneme_data[tmp[0]]=[tmp[1],tmp[2],tmp[3]]
69
-
70
- # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
71
- # pad for semantic tokens
72
- self.PAD: int = pad_val
73
- # self.hz = 25
74
- # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
75
- # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
76
- # self.hz=int(data[:-2])#
77
- self.hz=int(os.environ.get("hz","25hz")[:-2])
78
-
79
- # max seconds of semantic token
80
- self.max_sec = max_sec
81
- self.min_ps_ratio = min_ps_ratio
82
- self.max_ps_ratio = max_ps_ratio
83
-
84
- if max_sample is not None:
85
- self.semantic_data = self.semantic_data[:max_sample]
86
-
87
- # {idx: (semantic, phoneme)}
88
- # semantic list, phoneme list
89
- self.semantic_phoneme = []
90
- self.item_names = []
91
-
92
- self.inited = False
93
-
94
- if not self.inited:
95
- # call the initialization routine
96
- self.init_batch()
97
- self.inited = True
98
- del self.semantic_data
99
- del self.phoneme_data
100
- # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
101
- # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
102
-
103
-
104
- def init_batch(self):
105
- semantic_data_len = len(self.semantic_data)
106
- phoneme_data_len = len(self.phoneme_data.keys())
107
- print("semantic_data_len:", semantic_data_len)
108
- print("phoneme_data_len:", phoneme_data_len)
109
- idx = 0
110
- num_not_in = 0
111
- num_deleted_bigger = 0
112
- num_deleted_ps = 0
113
- for i in range(semantic_data_len):
114
- # iterate through the entries in order first
115
- # get str
116
- item_name = self.semantic_data['item_name'][i]
117
- # print(self.phoneme_data)
118
- try:
119
- phoneme, word2ph, text = self.phoneme_data[item_name]
120
- except Exception:
121
- traceback.print_exc()
122
- # print(f"{item_name} not in self.phoneme_data !")
123
- num_not_in += 1
124
- continue
125
-
126
- semantic_str = self.semantic_data['semantic_audio'][i]
127
- # get token list
128
- semantic_ids = [int(idx) for idx in semantic_str.split(' ')]
129
- # (T); does it need to become (1, T)? -> no, because we need to take len()
130
- # filter out samples that are too long
131
- if len(semantic_ids) > self.max_sec * self.hz:#########1### estimate the total duration from the token count and drop clips longer than max_sec (60 s in the config)#40*25=1k
132
- num_deleted_bigger += 1
133
- continue
134
- # (T,); this is not slow, so it can be handled up front instead of per item in __getitem__ ####
135
- phoneme = phoneme.split(' ')
136
-
137
- try:
138
- phoneme_ids = cleaned_text_to_sequence(phoneme)
139
- except:
140
- traceback.print_exc()
141
- # print(f"{item_name} not in self.phoneme_data !")
142
- num_not_in += 1
143
- continue
144
- # if len(phoneme_ids) >400:###########2: change to a fixed cap of semantic_len/2.5 instead
145
- if len(phoneme_ids) >self.max_sec * self.hz/2.5:###########2: change to a fixed cap of semantic_len/2.5 instead
146
- num_deleted_ps += 1
147
- continue
148
- # if len(semantic_ids) > 1000:###########3
149
- # num_deleted_bigger += 1
150
- # continue
151
-
152
- ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
153
-
154
- if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:##########4# 3~25 # phonemes per second
155
- num_deleted_ps += 1
156
- # print(item_name)
157
- continue
158
-
159
- self.semantic_phoneme.append((semantic_ids, phoneme_ids))
160
- idx += 1
161
- self.item_names.append(item_name)
162
-
163
- min_num=100# with 20, no repetition-padding at all; with 30, padded but the ckpt is still not saved
164
- leng =len(self.semantic_phoneme)
165
- if(leng<min_num):
166
- tmp1=self.semantic_phoneme
167
- tmp2=self.item_names
168
- self.semantic_phoneme=[]
169
- self.item_names=[]
170
- for _ in range(max(2,int(min_num/leng))):
171
- self.semantic_phoneme+=tmp1
172
- self.item_names+=tmp2
173
- if num_not_in > 0:
174
- print(f"there are {num_not_in} semantic entries not found in the phoneme data")
175
- if num_deleted_bigger > 0:
176
- print(
177
- f"deleted {num_deleted_bigger} audios whose duration is longer than {self.max_sec} seconds"
178
- )
179
- if num_deleted_ps > 0:
180
- # 4702 for LibriTTS; LibriTTS is labelled data, does it still need filtering? => yes, there are extreme values of 100
181
- print(
182
- f"deleted {num_deleted_ps} audios whose phoneme/sec is higher than {self.max_ps_ratio} or lower than {self.min_ps_ratio}"
183
- )
184
- '''
185
- there are 31 semantic datas not in phoneme datas
186
- deleted 34 audios who's duration are bigger than 54 seconds
187
- deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
188
- dataset.__len__(): 366463
189
-
190
- '''
191
- # 345410 for LibriTTS
192
- print("dataset.__len__():", self.__len__())
193
-
194
- def __get_item_names__(self) -> List[str]:
195
- return self.item_names
196
-
197
- def __len__(self) -> int:
198
- return len(self.semantic_phoneme)
199
-
200
- def __getitem__(self, idx: int) -> Dict:
201
- semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
202
- item_name = self.item_names[idx]
203
- phoneme_ids_len = len(phoneme_ids)
204
- # semantic tokens target
205
- semantic_ids_len = len(semantic_ids)
206
-
207
- flag=0
208
- path_bert = "%s/%s.pt" % (self.path3, item_name)
209
- if(os.path.exists(path_bert)==True):bert_feature = torch.load(path_bert,map_location="cpu")
210
- else:flag=1
211
- if(flag==1):
212
- # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
213
- bert_feature=None
214
- else:
215
- assert bert_feature.shape[-1] == len(phoneme_ids)
216
- return {
217
- 'idx': idx,
218
- 'phoneme_ids': phoneme_ids,
219
- 'phoneme_ids_len': phoneme_ids_len,
220
- 'semantic_ids': semantic_ids,
221
- 'semantic_ids_len': semantic_ids_len,
222
- 'bert_feature': bert_feature,
223
- }
224
-
225
- def get_sample_length(self, idx: int):
226
- semantic_ids = self.semantic_phoneme[idx][0]
227
- sec = 1.0 * len(semantic_ids) / self.hz
228
- return sec
229
-
230
- def collate(self, examples: List[Dict]) -> Dict:
231
- sample_index: List[int] = []
232
- phoneme_ids: List[torch.Tensor] = []
233
- phoneme_ids_lens: List[int] = []
234
- semantic_ids: List[torch.Tensor] = []
235
- semantic_ids_lens: List[int] = []
236
- # return
237
-
238
-
239
- for item in examples:
240
- sample_index.append(item["idx"])
241
- phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
242
- semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
243
- phoneme_ids_lens.append(item["phoneme_ids_len"])
244
- semantic_ids_lens.append(item["semantic_ids_len"])
245
-
246
- # pad 0
247
- phoneme_ids = batch_sequences(phoneme_ids)
248
- semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
249
-
250
- # # convert each batch to torch.tensor
251
- phoneme_ids = torch.tensor(phoneme_ids)
252
- semantic_ids = torch.tensor(semantic_ids)
253
- phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
254
- semantic_ids_lens = torch.tensor(semantic_ids_lens)
255
- bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
256
- bert_padded.zero_()
257
-
258
- for idx, item in enumerate(examples):
259
- bert = item['bert_feature']
260
- if(bert!=None):
261
- bert_padded[idx, :, :bert.shape[-1]] = bert
262
-
263
- return {
264
- # List[int]
265
- "ids": sample_index,
266
- # torch.Tensor (B, max_phoneme_length)
267
- "phoneme_ids": phoneme_ids,
268
- # torch.Tensor (B)
269
- "phoneme_ids_len": phoneme_ids_lens,
270
- # torch.Tensor (B, max_semantic_ids_length)
271
- "semantic_ids": semantic_ids,
272
- # torch.Tensor (B)
273
- "semantic_ids_len": semantic_ids_lens,
274
- # torch.Tensor (B, 1024, max_phoneme_length)
275
- "bert_feature": bert_padded,
276
- }
277
-
278
-
279
- if __name__ == '__main__':
280
- root_dir = '/data/docker/liujing04/gpt-vits/prepare/dump_mix/'
281
- dataset = Text2SemanticDataset(
282
- phoneme_path=root_dir + 'phoneme_train.npy',
283
- semantic_path=root_dir + 'semantic_train.tsv')
284
-
285
- batch_size = 12
286
- dataloader = DataLoader(
287
- dataset,
288
- batch_size=batch_size,
289
- collate_fn=dataset.collate,
290
- shuffle=False)
291
- for i, batch in enumerate(dataloader):
292
- if(i%1000==0):print(i)
293
- # if i == 0:
294
- # print('batch["ids"]:', batch["ids"])
295
- # print('batch["phoneme_ids"]:', batch["phoneme_ids"],
296
- # batch["phoneme_ids"].shape)
297
- # print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
298
- # batch["phoneme_ids_len"].shape)
299
- # print('batch["semantic_ids"]:', batch["semantic_ids"],
300
- # batch["semantic_ids"].shape)
301
- # print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
302
- # batch["semantic_ids_len"].shape)
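The collate path above right-pads variable-length id sequences before stacking them into a batch. Below is a small self-contained sketch of that padding step; it does not import the deleted module, the toy sequences are made up, and the pad value of 1024 mirrors the `pad_val` default above.

```python
import numpy as np
import torch

# toy variable-length id sequences, like the ones collate() receives
seqs = [np.array([5, 7, 9], dtype=np.int64),
        np.array([2, 4], dtype=np.int64),
        np.array([8, 8, 8, 8], dtype=np.int64)]

# right-pad to the longest sequence, which is what batch_sequences() does with np.pad
pad_value = 1024  # semantic PAD id above; phoneme ids are padded with 0 instead
max_len = max(s.shape[0] for s in seqs)
batch = np.stack([np.pad(s, (0, max_len - s.shape[0]),
                         mode="constant", constant_values=pad_value) for s in seqs])

lengths = torch.tensor([s.shape[0] for s in seqs])
print(torch.tensor(batch))  # (3, 4) padded batch
print(lengths)              # kept alongside the batch so the model can mask the padding
```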
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/__init__.py DELETED
File without changes
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/BEATs.py DELETED
@@ -1,179 +0,0 @@
1
- # --------------------------------------------------------
2
- # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
- # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
- # Copyright (c) 2022 Microsoft
5
- # Licensed under The MIT License [see LICENSE for details]
6
- # Based on fairseq code bases
7
- # https://github.com/pytorch/fairseq
8
- # --------------------------------------------------------
9
- import logging
10
- from typing import Optional
11
-
12
- import torch
13
- import torch.nn as nn
14
- import torchaudio.compliance.kaldi as ta_kaldi
15
- from torch.nn import LayerNorm
16
-
17
- from .backbone import TransformerEncoder
18
-
19
- logger = logging.getLogger(__name__)
20
-
21
-
22
- class BEATsConfig:
23
- def __init__(self, cfg=None):
24
- self.input_patch_size: int = -1 # path size of patch embedding
25
- self.embed_dim: int = 512 # patch embedding dimension
26
- self.conv_bias: bool = False # include bias in conv encoder
27
-
28
- self.encoder_layers: int = 12 # num encoder layers in the transformer
29
- self.encoder_embed_dim: int = 768 # encoder embedding dimension
30
- self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
31
- self.encoder_attention_heads: int = 12 # num encoder attention heads
32
- self.activation_fn: str = "gelu" # activation function to use
33
-
34
- self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay
35
- self.layer_norm_first: bool = False # apply layernorm first in the transformer
36
- self.deep_norm: bool = False # apply deep_norm first in the transformer
37
-
38
- # dropouts
39
- self.dropout: float = 0.1 # dropout probability for the transformer
40
- self.attention_dropout: float = 0.1 # dropout probability for attention weights
41
- self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
42
- self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
43
- self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
44
-
45
- # positional embeddings
46
- self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
47
- self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
48
-
49
- # relative position embedding
50
- self.relative_position_embedding: bool = False # apply relative position embedding
51
- self.num_buckets: int = 320 # number of buckets for relative position embedding
52
- self.max_distance: int = 1280 # maximum distance for relative position embedding
53
- self.gru_rel_pos: bool = False # apply gated relative position embedding
54
-
55
- # label predictor
56
- self.finetuned_model: bool = False # whether the model is a fine-tuned model.
57
- self.predictor_dropout: float = 0.1 # dropout probability for the predictor
58
- self.predictor_class: int = 527 # target class number for the predictor
59
-
60
- if cfg is not None:
61
- self.update(cfg)
62
-
63
- def update(self, cfg: dict):
64
- self.__dict__.update(cfg)
65
-
66
-
67
- class BEATs(nn.Module):
68
- def __init__(
69
- self,
70
- cfg: BEATsConfig, ) -> None:
71
- super().__init__()
72
- logger.info(f"BEATs Config: {cfg.__dict__}")
73
-
74
- self.cfg = cfg
75
-
76
- self.embed = cfg.embed_dim
77
- self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
78
- if self.embed != cfg.encoder_embed_dim else
79
- None)
80
-
81
- self.input_patch_size = cfg.input_patch_size
82
- self.patch_embedding = nn.Conv2d(
83
- 1,
84
- self.embed,
85
- kernel_size=self.input_patch_size,
86
- stride=self.input_patch_size,
87
- bias=cfg.conv_bias)
88
-
89
- self.dropout_input = nn.Dropout(cfg.dropout_input)
90
-
91
- assert not cfg.deep_norm or not cfg.layer_norm_first
92
- self.encoder = TransformerEncoder(cfg)
93
- self.layer_norm = LayerNorm(self.embed)
94
-
95
- if cfg.finetuned_model:
96
- self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
97
- self.predictor = nn.Linear(cfg.encoder_embed_dim,
98
- cfg.predictor_class)
99
- else:
100
- self.predictor = None
101
-
102
- def forward_padding_mask(
103
- self,
104
- features: torch.Tensor,
105
- padding_mask: torch.Tensor, ) -> torch.Tensor:
106
- extra = padding_mask.size(1) % features.size(1)
107
- if extra > 0:
108
- padding_mask = padding_mask[:, :-extra]
109
- padding_mask = padding_mask.view(
110
- padding_mask.size(0), features.size(1), -1)
111
- padding_mask = padding_mask.all(-1)
112
- return padding_mask
113
-
114
- def preprocess(
115
- self,
116
- source: torch.Tensor,
117
- fbank_mean: float=15.41663,
118
- fbank_std: float=6.55582, ) -> torch.Tensor:
119
- fbanks = []
120
- for waveform in source:
121
- waveform = waveform.unsqueeze(0) * 2**15
122
- fbank = ta_kaldi.fbank(
123
- waveform,
124
- num_mel_bins=128,
125
- sample_frequency=16000,
126
- frame_length=25,
127
- frame_shift=10)
128
- fbanks.append(fbank)
129
- fbank = torch.stack(fbanks, dim=0)
130
- fbank = (fbank - fbank_mean) / (2 * fbank_std)
131
- return fbank
132
-
133
- def extract_features(
134
- self,
135
- source: torch.Tensor,
136
- padding_mask: Optional[torch.Tensor]=None,
137
- fbank_mean: float=15.41663,
138
- fbank_std: float=6.55582, ):
139
- fbank = self.preprocess(
140
- source, fbank_mean=fbank_mean, fbank_std=fbank_std)
141
-
142
- if padding_mask is not None:
143
- padding_mask = self.forward_padding_mask(fbank, padding_mask)
144
-
145
- fbank = fbank.unsqueeze(1)
146
- features = self.patch_embedding(fbank)
147
- features = features.reshape(features.shape[0], features.shape[1], -1)
148
- features = features.transpose(1, 2)
149
- features = self.layer_norm(features)
150
-
151
- if padding_mask is not None:
152
- padding_mask = self.forward_padding_mask(features, padding_mask)
153
-
154
- if self.post_extract_proj is not None:
155
- features = self.post_extract_proj(features)
156
-
157
- x = self.dropout_input(features)
158
-
159
- x, layer_results = self.encoder(
160
- x,
161
- padding_mask=padding_mask, )
162
-
163
- if self.predictor is not None:
164
- x = self.predictor_dropout(x)
165
- logits = self.predictor(x)
166
-
167
- if padding_mask is not None and padding_mask.any():
168
- logits[padding_mask] = 0
169
- logits = logits.sum(dim=1)
170
- logits = logits / (~padding_mask).sum(
171
- dim=1).unsqueeze(-1).expand_as(logits)
172
- else:
173
- logits = logits.mean(dim=1)
174
-
175
- lprobs = torch.sigmoid(logits)
176
-
177
- return lprobs, padding_mask
178
- else:
179
- return x, padding_mask
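`preprocess()` above turns raw 16 kHz audio into a 128-bin Kaldi fbank, which the patch embedding then cuts into square patches. A quick sketch of the resulting shapes follows, assuming torchaudio is installed and an `input_patch_size` of 16 (the value is not fixed in this config class; 16 is assumed here for illustration).

```python
import torch
import torchaudio.compliance.kaldi as ta_kaldi

# one second of 16 kHz audio, scaled the same way preprocess() scales it
wav = torch.randn(1, 16000) * 2 ** 15
fbank = ta_kaldi.fbank(wav, num_mel_bins=128, sample_frequency=16000,
                       frame_length=25, frame_shift=10)
print(fbank.shape)  # roughly (98, 128): 10 ms hop, 25 ms window, 128 mel bins

# the Conv2d patch embedding tiles this spectrogram with non-overlapping square patches,
# so with an assumed patch size of 16, a 1 s clip gives about (98 // 16) * (128 // 16) = 48 tokens
patch = 16
print((fbank.shape[0] // patch) * (fbank.shape[1] // patch))
```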
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/README.md DELETED
@@ -1,127 +0,0 @@
1
-
2
- # BEATs
3
-
4
- [**BEATs**](https://arxiv.org/abs/2212.09058): **Audio Pre-Training with Acoustic Tokenizers**
5
-
6
- Official PyTorch implementation and pretrained models of BEATs
7
-
8
- ## Pre-Trained and Fine-Tuned Tokenizers and Models
9
- Iterations | Tokenizer | Pre-Trained Model | AudioSet Fine-Tuned Model 1 | AudioSet Fine-Tuned Model 2
10
- |---|---|---|---|---
11
- Iter1 | Random Projection | [BEATs_iter1](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
12
- Iter2 | [Tokenizer_iter2](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter2](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
13
- Iter3 | [Tokenizer_iter3](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
14
- Iter3+ | [Tokenizer_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
15
- Iter3+ | [Tokenizer_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) |
16
-
17
-
18
- ### Load Tokenizers
19
-
20
- ```python
21
- import torch
22
- from Tokenizers import TokenizersConfig, Tokenizers
23
-
24
- # load the pre-trained checkpoints
25
- checkpoint = torch.load('/path/to/tokenizer.pt')
26
-
27
- cfg = TokenizersConfig(checkpoint['cfg'])
28
- BEATs_tokenizer = Tokenizers(cfg)
29
- BEATs_tokenizer.load_state_dict(checkpoint['model'])
30
- BEATs_tokenizer.eval()
31
-
32
- # tokenize the audio and generate the labels
33
- audio_input_16khz = torch.randn(1, 10000)
34
- padding_mask = torch.zeros(1, 10000).bool()
35
-
36
- labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)
37
- ```
38
-
39
-
40
- ### Load Pre-Trained Models
41
-
42
- ```python
43
- import torch
44
- from BEATs import BEATs, BEATsConfig
45
-
46
- # load the pre-trained checkpoints
47
- checkpoint = torch.load('/path/to/model.pt')
48
-
49
- cfg = BEATsConfig(checkpoint['cfg'])
50
- BEATs_model = BEATs(cfg)
51
- BEATs_model.load_state_dict(checkpoint['model'])
52
- BEATs_model.eval()
53
-
54
- # extract the the audio representation
55
- audio_input_16khz = torch.randn(1, 10000)
56
- padding_mask = torch.zeros(1, 10000).bool()
57
-
58
- representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
59
- ```
60
-
61
-
62
- ### Load Fine-tuned Models
63
-
64
- ```python
65
- import torch
66
- from BEATs import BEATs, BEATsConfig
67
-
68
- # load the fine-tuned checkpoints
69
- checkpoint = torch.load('/path/to/model.pt')
70
-
71
- cfg = BEATsConfig(checkpoint['cfg'])
72
- BEATs_model = BEATs(cfg)
73
- BEATs_model.load_state_dict(checkpoint['model'])
74
- BEATs_model.eval()
75
-
76
- # predict the classification probability of each class
77
- audio_input_16khz = torch.randn(3, 10000)
78
- padding_mask = torch.zeros(3, 10000).bool()
79
-
80
- probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
81
-
82
- for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
83
- top5_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
84
- print(f'Top 5 predicted labels of the {i}th audio are {top5_label} with probability of {top5_label_prob}')
85
- ```
86
-
87
- ## Evaluation Results
88
-
89
- ### Comparing with the SOTA Single Models
90
- ![alt text](Evaluation_Results/Comparing_with_the_SOTA_Single_Models.png)
91
-
92
-
93
- ### Comparing with the SOTA Ensemble Models
94
- ![alt text](Evaluation_Results/Comparing_with_the_SOTA_Ensemble_Models.png)
95
-
96
-
97
- ### Comparing Different BEATS Tokenizers
98
- ![alt text](Evaluation_Results/Comparing_Different_BEATS_Tokenizers.png)
99
-
100
-
101
- ### Comparing Different Pre-Training Targets
102
- ![alt text](Evaluation_Results/Comparing_Different_Pre-Training_Targets.png)
103
-
104
-
105
- ## License
106
- This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
107
- Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) and [VQGAN](https://github.com/CompVis/taming-transformers) project.
108
-
109
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
110
-
111
-
112
- ### Reference
113
- If you find our work is useful in your research, please cite the following paper:
114
- ``` latex
115
- @article{Chen2022beats,
116
- title = {BEATs: Audio Pre-Training with Acoustic Tokenizers},
117
- author = {Sanyuan Chen and Yu Wu and Chengyi Wang and Shujie Liu and Daniel Tompkins and Zhuo Chen and Furu Wei},
118
- eprint={2212.09058},
119
- archivePrefix={arXiv},
120
- year={2022}
121
- }
122
- ```
123
- ### Contact Information
124
-
125
- For help or issues using BEATs models, please submit a GitHub issue.
126
-
127
- For other communications related to BEATs, please contact Yu Wu (`[email protected]`).
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/Tokenizers.py DELETED
@@ -1,172 +0,0 @@
1
- # --------------------------------------------------------
2
- # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
- # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
- # Copyright (c) 2022 Microsoft
5
- # Licensed under The MIT License [see LICENSE for details]
6
- # Based on fairseq code bases
7
- # https://github.com/pytorch/fairseq
8
- # --------------------------------------------------------
9
- import logging
10
- from typing import Optional
11
-
12
- import torch
13
- import torch.nn as nn
14
- import torchaudio.compliance.kaldi as ta_kaldi
15
- from backbone import (
16
- TransformerEncoder, )
17
- from quantizer import (
18
- NormEMAVectorQuantizer, )
19
- from torch.nn import LayerNorm
20
-
21
- logger = logging.getLogger(__name__)
22
-
23
-
24
- class TokenizersConfig:
25
- def __init__(self, cfg=None):
26
- self.input_patch_size: int = -1 # path size of patch embedding
27
- self.embed_dim: int = 512 # patch embedding dimension
28
- self.conv_bias: bool = False # include bias in conv encoder
29
-
30
- self.encoder_layers: int = 12 # num encoder layers in the transformer
31
- self.encoder_embed_dim: int = 768 # encoder embedding dimension
32
- self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
33
- self.encoder_attention_heads: int = 12 # num encoder attention heads
34
- self.activation_fn: str = "gelu" # activation function to use
35
-
36
- self.layer_norm_first: bool = False # apply layernorm first in the transformer
37
- self.deep_norm: bool = False # apply deep_norm first in the transformer
38
-
39
- # dropouts
40
- self.dropout: float = 0.1 # dropout probability for the transformer
41
- self.attention_dropout: float = 0.1 # dropout probability for attention weights
42
- self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
43
- self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
44
- self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
45
-
46
- # positional embeddings
47
- self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
48
- self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
49
-
50
- # relative position embedding
51
- self.relative_position_embedding: bool = False # apply relative position embedding
52
- self.num_buckets: int = 320 # number of buckets for relative position embedding
53
- self.max_distance: int = 1280 # maximum distance for relative position embedding
54
- self.gru_rel_pos: bool = False # apply gated relative position embedding
55
-
56
- # quantizer
57
- self.quant_n: int = 1024 # codebook number in quantizer
58
- self.quant_dim: int = 256 # codebook dimension in quantizer
59
-
60
- if cfg is not None:
61
- self.update(cfg)
62
-
63
- def update(self, cfg: dict):
64
- self.__dict__.update(cfg)
65
-
66
-
67
- class Tokenizers(nn.Module):
68
- def __init__(
69
- self,
70
- cfg: TokenizersConfig, ) -> None:
71
- super().__init__()
72
- logger.info(f"Tokenizers Config: {cfg.__dict__}")
73
-
74
- self.cfg = cfg
75
-
76
- self.embed = cfg.embed_dim
77
- self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim)
78
- if self.embed != cfg.encoder_embed_dim else
79
- None)
80
-
81
- self.input_patch_size = cfg.input_patch_size
82
- self.patch_embedding = nn.Conv2d(
83
- 1,
84
- self.embed,
85
- kernel_size=self.input_patch_size,
86
- stride=self.input_patch_size,
87
- bias=cfg.conv_bias)
88
-
89
- self.dropout_input = nn.Dropout(cfg.dropout_input)
90
-
91
- assert not cfg.deep_norm or not cfg.layer_norm_first
92
- self.encoder = TransformerEncoder(cfg)
93
- self.layer_norm = LayerNorm(self.embed)
94
-
95
- self.quantize = NormEMAVectorQuantizer(
96
- n_embed=cfg.quant_n,
97
- embedding_dim=cfg.quant_dim,
98
- beta=1.0,
99
- kmeans_init=True,
100
- decay=0.99, )
101
- self.quant_n = cfg.quant_n
102
- self.quantize_layer = nn.Sequential(
103
- nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
104
- nn.Tanh(),
105
- nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize
106
- )
107
-
108
- def forward_padding_mask(
109
- self,
110
- features: torch.Tensor,
111
- padding_mask: torch.Tensor, ) -> torch.Tensor:
112
- extra = padding_mask.size(1) % features.size(1)
113
- if extra > 0:
114
- padding_mask = padding_mask[:, :-extra]
115
- padding_mask = padding_mask.view(
116
- padding_mask.size(0), features.size(1), -1)
117
- padding_mask = padding_mask.all(-1)
118
- return padding_mask
119
-
120
- def preprocess(
121
- self,
122
- source: torch.Tensor,
123
- fbank_mean: float=15.41663,
124
- fbank_std: float=6.55582, ) -> torch.Tensor:
125
- fbanks = []
126
- for waveform in source:
127
- waveform = waveform.unsqueeze(0) * 2**15
128
- fbank = ta_kaldi.fbank(
129
- waveform,
130
- num_mel_bins=128,
131
- sample_frequency=16000,
132
- frame_length=25,
133
- frame_shift=10)
134
- fbanks.append(fbank)
135
- fbank = torch.stack(fbanks, dim=0)
136
- fbank = (fbank - fbank_mean) / (2 * fbank_std)
137
- return fbank
138
-
139
- def extract_labels(
140
- self,
141
- source: torch.Tensor,
142
- padding_mask: Optional[torch.Tensor]=None,
143
- fbank_mean: float=15.41663,
144
- fbank_std: float=6.55582, ):
145
- fbank = self.preprocess(
146
- source, fbank_mean=fbank_mean, fbank_std=fbank_std)
147
-
148
- if padding_mask is not None:
149
- padding_mask = self.forward_padding_mask(fbank, padding_mask)
150
-
151
- fbank = fbank.unsqueeze(1)
152
- features = self.patch_embedding(fbank)
153
- features = features.reshape(features.shape[0], features.shape[1], -1)
154
- features = features.transpose(1, 2)
155
- features = self.layer_norm(features)
156
-
157
- if padding_mask is not None:
158
- padding_mask = self.forward_padding_mask(features, padding_mask)
159
-
160
- if self.post_extract_proj is not None:
161
- features = self.post_extract_proj(features)
162
-
163
- x = self.dropout_input(features)
164
-
165
- x, layer_results = self.encoder(
166
- x,
167
- padding_mask=padding_mask, )
168
-
169
- quantize_input = self.quantize_layer(x)
170
- quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
171
-
172
- return embed_ind
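`extract_labels()` above ends by mapping each frame embedding to a discrete codebook index. Below is a standalone sketch of that nearest-code lookup on normalized vectors; it is only an approximation of what `NormEMAVectorQuantizer` does at inference, the codebook here is random, and the sizes simply reuse the `quant_n` / `quant_dim` defaults from the config.

```python
import torch
import torch.nn.functional as F

quant_n, quant_dim = 1024, 256  # cfg.quant_n / cfg.quant_dim defaults above

# random stand-in for the learned codebook and for 48 encoder frames, both L2-normalized
codebook = F.normalize(torch.randn(quant_n, quant_dim), dim=-1)
frames = F.normalize(torch.randn(1, 48, quant_dim), dim=-1)  # (batch, time, quant_dim)

# cosine similarity against every code, argmax -> one discrete label per frame,
# analogous to the embed_ind returned by extract_labels()
sims = frames @ codebook.t()      # (1, 48, 1024)
embed_ind = sims.argmax(dim=-1)   # (1, 48) integer ids in [0, quant_n)
print(embed_ind.shape, int(embed_ind.min()), int(embed_ind.max()))
```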
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- # this folder is modified from https://github.com/microsoft/unilm/tree/master/beats
2
- # ontology.json is from https://github.com/audioset/ontology/
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/backbone.py DELETED
@@ -1,791 +0,0 @@
1
- # --------------------------------------------------------
2
- # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
- # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
- # Copyright (c) 2022 Microsoft
5
- # Licensed under The MIT License [see LICENSE for details]
6
- # Based on fairseq code bases
7
- # https://github.com/pytorch/fairseq
8
- # --------------------------------------------------------
9
- import math
10
- from typing import Dict
11
- from typing import Optional
12
- from typing import Tuple
13
-
14
- import numpy as np
15
- import torch
16
- import torch.nn.functional as F
17
- from torch import nn
18
- from torch import Tensor
19
- from torch.nn import LayerNorm
20
- from torch.nn import Parameter
21
-
22
- from .modules import get_activation_fn
23
- from .modules import GLU_Linear
24
- from .modules import GradMultiply
25
- from .modules import quant_noise
26
- from .modules import SamePad
27
-
28
-
29
- class TransformerEncoder(nn.Module):
30
- def __init__(self, args):
31
- super().__init__()
32
-
33
- self.dropout = args.dropout
34
- self.embedding_dim = args.encoder_embed_dim
35
-
36
- self.pos_conv = nn.Conv1d(
37
- self.embedding_dim,
38
- self.embedding_dim,
39
- kernel_size=args.conv_pos,
40
- padding=args.conv_pos // 2,
41
- groups=args.conv_pos_groups, )
42
- dropout = 0
43
- std = math.sqrt(
44
- (4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
45
- nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
46
- nn.init.constant_(self.pos_conv.bias, 0)
47
-
48
- self.pos_conv = nn.utils.weight_norm(
49
- self.pos_conv, name="weight", dim=2)
50
- self.pos_conv = nn.Sequential(self.pos_conv,
51
- SamePad(args.conv_pos), nn.GELU())
52
-
53
- if hasattr(args, "relative_position_embedding"):
54
- self.relative_position_embedding = args.relative_position_embedding
55
- self.num_buckets = args.num_buckets
56
- self.max_distance = args.max_distance
57
- else:
58
- self.relative_position_embedding = False
59
- self.num_buckets = 0
60
- self.max_distance = 0
61
-
62
- self.layers = nn.ModuleList([
63
- TransformerSentenceEncoderLayer(
64
- embedding_dim=self.embedding_dim,
65
- ffn_embedding_dim=args.encoder_ffn_embed_dim,
66
- num_attention_heads=args.encoder_attention_heads,
67
- dropout=self.dropout,
68
- attention_dropout=args.attention_dropout,
69
- activation_dropout=args.activation_dropout,
70
- activation_fn=args.activation_fn,
71
- layer_norm_first=args.layer_norm_first,
72
- deep_norm=args.deep_norm,
73
- has_relative_attention_bias=self.relative_position_embedding,
74
- num_buckets=self.num_buckets,
75
- max_distance=self.max_distance,
76
- gru_rel_pos=args.gru_rel_pos,
77
- encoder_layers=args.encoder_layers, )
78
- for i in range(args.encoder_layers)
79
- ])
80
- if self.relative_position_embedding:
81
- for i in range(1, args.encoder_layers):
82
- del self.layers[i].self_attn.relative_attention_bias
83
- self.layers[i].self_attn.relative_attention_bias = self.layers[
84
- 0].self_attn.relative_attention_bias
85
-
86
- self.layer_norm_first = args.layer_norm_first
87
- self.layer_norm = LayerNorm(self.embedding_dim)
88
- self.layerdrop = args.encoder_layerdrop
89
-
90
- self.apply(init_bert_params)
91
-
92
- if args.deep_norm:
93
- deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
94
- for i in range(args.encoder_layers):
95
- nn.init.xavier_normal_(
96
- self.layers[i].self_attn.k_proj.weight, gain=1)
97
- nn.init.xavier_normal_(
98
- self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
99
- nn.init.xavier_normal_(
100
- self.layers[i].self_attn.q_proj.weight, gain=1)
101
- nn.init.xavier_normal_(
102
- self.layers[i].self_attn.out_proj.weight,
103
- gain=deep_norm_beta)
104
- nn.init.xavier_normal_(
105
- self.layers[i].fc1.weight, gain=deep_norm_beta)
106
- nn.init.xavier_normal_(
107
- self.layers[i].fc2.weight, gain=deep_norm_beta)
108
-
109
- self.layer_wise_gradient_decay_ratio = getattr(
110
- args, "layer_wise_gradient_decay_ratio", 1)
111
-
112
- def forward(self, x, padding_mask=None, layer=None):
113
- x, layer_results = self.extract_features(x, padding_mask, layer)
114
-
115
- if self.layer_norm_first and layer is None:
116
- x = self.layer_norm(x)
117
-
118
- return x, layer_results
119
-
120
- def extract_features(self, x, padding_mask=None, tgt_layer=None):
121
-
122
- if padding_mask is not None:
123
- x[padding_mask] = 0
124
-
125
- x_conv = self.pos_conv(x.transpose(1, 2))
126
- x_conv = x_conv.transpose(1, 2)
127
- x = x + x_conv
128
-
129
- if not self.layer_norm_first:
130
- x = self.layer_norm(x)
131
-
132
- x = F.dropout(x, p=self.dropout, training=self.training)
133
-
134
- # B x T x C -> T x B x C
135
- x = x.transpose(0, 1)
136
-
137
- layer_results = []
138
- z = None
139
- if tgt_layer is not None:
140
- layer_results.append((x, z))
141
- r = None
142
- pos_bias = None
143
- for i, layer in enumerate(self.layers):
144
- if self.layer_wise_gradient_decay_ratio != 1.0:
145
- x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
146
- dropout_probability = np.random.random()
147
- if not self.training or (dropout_probability > self.layerdrop):
148
- x, z, pos_bias = layer(
149
- x,
150
- self_attn_padding_mask=padding_mask,
151
- need_weights=False,
152
- pos_bias=pos_bias)
153
- if tgt_layer is not None:
154
- layer_results.append((x, z))
155
- if i == tgt_layer:
156
- r = x
157
- break
158
-
159
- if r is not None:
160
- x = r
161
-
162
- # T x B x C -> B x T x C
163
- x = x.transpose(0, 1)
164
-
165
- return x, layer_results
166
-
167
-
168
- class TransformerSentenceEncoderLayer(nn.Module):
169
- def __init__(
170
- self,
171
- embedding_dim: float=768,
172
- ffn_embedding_dim: float=3072,
173
- num_attention_heads: float=8,
174
- dropout: float=0.1,
175
- attention_dropout: float=0.1,
176
- activation_dropout: float=0.1,
177
- activation_fn: str="relu",
178
- layer_norm_first: bool=False,
179
- deep_norm: bool=False,
180
- has_relative_attention_bias: bool=False,
181
- num_buckets: int=0,
182
- max_distance: int=0,
183
- rescale_init: bool=False,
184
- gru_rel_pos: bool=False,
185
- encoder_layers: int=0, ) -> None:
186
-
187
- super().__init__()
188
- self.embedding_dim = embedding_dim
189
- self.dropout = dropout
190
- self.activation_dropout = activation_dropout
191
-
192
- self.activation_name = activation_fn
193
- self.activation_fn = get_activation_fn(activation_fn)
194
- self.self_attn = MultiheadAttention(
195
- self.embedding_dim,
196
- num_attention_heads,
197
- dropout=attention_dropout,
198
- self_attention=True,
199
- has_relative_attention_bias=has_relative_attention_bias,
200
- num_buckets=num_buckets,
201
- max_distance=max_distance,
202
- rescale_init=rescale_init,
203
- gru_rel_pos=gru_rel_pos, )
204
-
205
- self.dropout1 = nn.Dropout(dropout)
206
- self.dropout2 = nn.Dropout(self.activation_dropout)
207
- self.dropout3 = nn.Dropout(dropout)
208
-
209
- self.layer_norm_first = layer_norm_first
210
-
211
- self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
212
-
213
- if self.activation_name == "glu":
214
- self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim,
215
- "swish")
216
- else:
217
- self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
218
- self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
219
-
220
- self.final_layer_norm = LayerNorm(self.embedding_dim)
221
-
222
- self.deep_norm = deep_norm
223
- if self.deep_norm:
224
- self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
225
- else:
226
- self.deep_norm_alpha = 1
227
-
228
- def forward(self,
229
- x: torch.Tensor,
230
- self_attn_mask: torch.Tensor=None,
231
- self_attn_padding_mask: torch.Tensor=None,
232
- need_weights: bool=False,
233
- pos_bias=None):
234
- residual = x
235
-
236
- if self.layer_norm_first:
237
- x = self.self_attn_layer_norm(x)
238
- x, attn, pos_bias = self.self_attn(
239
- query=x,
240
- key=x,
241
- value=x,
242
- key_padding_mask=self_attn_padding_mask,
243
- need_weights=False,
244
- attn_mask=self_attn_mask,
245
- position_bias=pos_bias)
246
- x = self.dropout1(x)
247
- x = residual + x
248
-
249
- residual = x
250
- x = self.final_layer_norm(x)
251
- if self.activation_name == "glu":
252
- x = self.fc1(x)
253
- else:
254
- x = self.activation_fn(self.fc1(x))
255
- x = self.dropout2(x)
256
- x = self.fc2(x)
257
- x = self.dropout3(x)
258
- x = residual + x
259
- else:
260
- x, attn, pos_bias = self.self_attn(
261
- query=x,
262
- key=x,
263
- value=x,
264
- key_padding_mask=self_attn_padding_mask,
265
- need_weights=need_weights,
266
- attn_mask=self_attn_mask,
267
- position_bias=pos_bias)
268
-
269
- x = self.dropout1(x)
270
- x = residual * self.deep_norm_alpha + x
271
-
272
- x = self.self_attn_layer_norm(x)
273
-
274
- residual = x
275
- if self.activation_name == "glu":
276
- x = self.fc1(x)
277
- else:
278
- x = self.activation_fn(self.fc1(x))
279
- x = self.dropout2(x)
280
- x = self.fc2(x)
281
- x = self.dropout3(x)
282
- x = residual * self.deep_norm_alpha + x
283
- x = self.final_layer_norm(x)
284
-
285
- return x, attn, pos_bias
286
-
287
-
288
- class MultiheadAttention(nn.Module):
289
- """Multi-headed attention.
290
-
291
- See "Attention Is All You Need" for more details.
292
- """
293
-
294
- def __init__(
295
- self,
296
- embed_dim,
297
- num_heads,
298
- kdim=None,
299
- vdim=None,
300
- dropout=0.0,
301
- bias=True,
302
- add_bias_kv=False,
303
- add_zero_attn=False,
304
- self_attention=False,
305
- encoder_decoder_attention=False,
306
- q_noise=0.0,
307
- qn_block_size=8,
308
- has_relative_attention_bias=False,
309
- num_buckets=32,
310
- max_distance=128,
311
- gru_rel_pos=False,
312
- rescale_init=False, ):
313
- super().__init__()
314
- self.embed_dim = embed_dim
315
- self.kdim = kdim if kdim is not None else embed_dim
316
- self.vdim = vdim if vdim is not None else embed_dim
317
- self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
318
-
319
- self.num_heads = num_heads
320
- self.dropout_module = nn.Dropout(dropout)
321
-
322
- self.has_relative_attention_bias = has_relative_attention_bias
323
- self.num_buckets = num_buckets
324
- self.max_distance = max_distance
325
- if self.has_relative_attention_bias:
326
- self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
327
-
328
- self.head_dim = embed_dim // num_heads
329
- self.q_head_dim = self.head_dim
330
- self.k_head_dim = self.head_dim
331
- assert (self.head_dim * num_heads == self.embed_dim
332
- ), "embed_dim must be divisible by num_heads"
333
- self.scaling = self.head_dim**-0.5
334
-
335
- self.self_attention = self_attention
336
- self.encoder_decoder_attention = encoder_decoder_attention
337
-
338
- assert not self.self_attention or self.qkv_same_dim, (
339
- "Self-attention requires query, key and "
340
- "value to be of the same size")
341
-
342
- k_bias = True
343
- if rescale_init:
344
- k_bias = False
345
-
346
- k_embed_dim = embed_dim
347
- q_embed_dim = embed_dim
348
-
349
- self.k_proj = quant_noise(
350
- nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise,
351
- qn_block_size)
352
- self.v_proj = quant_noise(
353
- nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
354
- self.q_proj = quant_noise(
355
- nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise,
356
- qn_block_size)
357
-
358
- self.out_proj = quant_noise(
359
- nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
360
-
361
- if add_bias_kv:
362
- self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
363
- self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
364
- else:
365
- self.bias_k = self.bias_v = None
366
-
367
- self.add_zero_attn = add_zero_attn
368
-
369
- self.gru_rel_pos = gru_rel_pos
370
- if self.gru_rel_pos:
371
- self.grep_linear = nn.Linear(self.q_head_dim, 8)
372
- self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
373
-
374
- self.reset_parameters()
375
-
376
- def reset_parameters(self):
377
- if self.qkv_same_dim:
378
- # Empirically observed the convergence to be much better with
379
- # the scaled initialization
380
- nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
381
- nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
382
- nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
383
- else:
384
- nn.init.xavier_uniform_(self.k_proj.weight)
385
- nn.init.xavier_uniform_(self.v_proj.weight)
386
- nn.init.xavier_uniform_(self.q_proj.weight)
387
-
388
- nn.init.xavier_uniform_(self.out_proj.weight)
389
- if self.out_proj.bias is not None:
390
- nn.init.constant_(self.out_proj.bias, 0.0)
391
- if self.bias_k is not None:
392
- nn.init.xavier_normal_(self.bias_k)
393
- if self.bias_v is not None:
394
- nn.init.xavier_normal_(self.bias_v)
395
- if self.has_relative_attention_bias:
396
- nn.init.xavier_normal_(self.relative_attention_bias.weight)
397
-
398
- def _relative_positions_bucket(self, relative_positions,
399
- bidirectional=True):
400
- num_buckets = self.num_buckets
401
- max_distance = self.max_distance
402
- relative_buckets = 0
403
-
404
- if bidirectional:
405
- num_buckets = num_buckets // 2
406
- relative_buckets += (
407
- relative_positions > 0).to(torch.long) * num_buckets
408
- relative_positions = torch.abs(relative_positions)
409
- else:
410
- relative_positions = -torch.min(
411
- relative_positions, torch.zeros_like(relative_positions))
412
-
413
- max_exact = num_buckets // 2
414
- is_small = relative_positions < max_exact
415
-
416
- relative_postion_if_large = max_exact + (
417
- torch.log(relative_positions.float() / max_exact) / math.log(
418
- max_distance / max_exact) *
419
- (num_buckets - max_exact)).to(torch.long)
420
- relative_postion_if_large = torch.min(
421
- relative_postion_if_large,
422
- torch.full_like(relative_postion_if_large, num_buckets - 1))
423
-
424
- relative_buckets += torch.where(is_small, relative_positions,
425
- relative_postion_if_large)
426
- return relative_buckets
427
-
428
- def compute_bias(self, query_length, key_length):
429
- context_position = torch.arange(query_length, dtype=torch.long)[:, None]
430
- memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
431
- relative_position = memory_position - context_position
432
- relative_position_bucket = self._relative_positions_bucket(
433
- relative_position, bidirectional=True)
434
- relative_position_bucket = relative_position_bucket.to(
435
- self.relative_attention_bias.weight.device)
436
- values = self.relative_attention_bias(relative_position_bucket)
437
- values = values.permute([2, 0, 1])
438
- return values
439
-
440
- def forward(self,
441
- query,
442
- key: Optional[Tensor],
443
- value: Optional[Tensor],
444
- key_padding_mask: Optional[Tensor]=None,
445
- incremental_state: Optional[Dict[str, Dict[str, Optional[
446
- Tensor]]]]=None,
447
- need_weights: bool=True,
448
- static_kv: bool=False,
449
- attn_mask: Optional[Tensor]=None,
450
- before_softmax: bool=False,
451
- need_head_weights: bool=False,
452
- position_bias: Optional[Tensor]=None
453
- ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
454
- """Input shape: Time x Batch x Channel
455
-
456
- Args:
457
- key_padding_mask (ByteTensor, optional): mask to exclude
458
- keys that are pads, of shape `(batch, src_len)`, where
459
- padding elements are indicated by 1s.
460
- need_weights (bool, optional): return the attention weights,
461
- averaged over heads (default: False).
462
- attn_mask (ByteTensor, optional): typically used to
463
- implement causal attention, where the mask prevents the
464
- attention from looking forward in time (default: None).
465
- before_softmax (bool, optional): return the raw attention
466
- weights and values before the attention softmax.
467
- need_head_weights (bool, optional): return the attention
468
- weights for each head. Implies *need_weights*. Default:
469
- return the average attention weights over all heads.
470
- """
471
- if need_head_weights:
472
- need_weights = True
473
-
474
- is_tpu = query.device.type == "xla"
475
-
476
- tgt_len, bsz, embed_dim = query.size()
477
- src_len = tgt_len
478
- assert embed_dim == self.embed_dim
479
- assert list(query.size()) == [tgt_len, bsz, embed_dim]
480
- if key is not None:
481
- src_len, key_bsz, _ = key.size()
482
- if not torch.jit.is_scripting():
483
- assert key_bsz == bsz
484
- assert value is not None
485
- assert (src_len, bsz) == value.shape[:2]
486
-
487
- if self.has_relative_attention_bias and position_bias is None:
488
- position_bias = self.compute_bias(tgt_len, src_len)
489
- position_bias = position_bias.unsqueeze(0).repeat(
490
- bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
491
-
492
- if incremental_state is not None:
493
- saved_state = self._get_input_buffer(incremental_state)
494
- if saved_state is not None and "prev_key" in saved_state:
495
- # previous time steps are cached - no need to recompute
496
- # key and value if they are static
497
- if static_kv:
498
- assert self.encoder_decoder_attention and not self.self_attention
499
- key = value = None
500
- else:
501
- saved_state = None
502
-
503
- if self.self_attention:
504
- q = self.q_proj(query)
505
- k = self.k_proj(query)
506
- v = self.v_proj(query)
507
- elif self.encoder_decoder_attention:
508
- # encoder-decoder attention
509
- q = self.q_proj(query)
510
- if key is None:
511
- assert value is None
512
- k = v = None
513
- else:
514
- k = self.k_proj(key)
515
- v = self.v_proj(key)
516
-
517
- else:
518
- assert key is not None and value is not None
519
- q = self.q_proj(query)
520
- k = self.k_proj(key)
521
- v = self.v_proj(value)
522
- q *= self.scaling
523
- alpha = 32
524
- q *= 1 / alpha
525
-
526
- if self.bias_k is not None:
527
- assert self.bias_v is not None
528
- k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
529
- v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
530
- if attn_mask is not None:
531
- attn_mask = torch.cat(
532
- [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)],
533
- dim=1)
534
- if key_padding_mask is not None:
535
- key_padding_mask = torch.cat(
536
- [
537
- key_padding_mask,
538
- key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
539
- ],
540
- dim=1, )
541
-
542
- q = (q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim)
543
- .transpose(0, 1))
544
- if k is not None:
545
- k = (k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim)
546
- .transpose(0, 1))
547
- if v is not None:
548
- v = (v.contiguous().view(-1, bsz * self.num_heads, self.head_dim)
549
- .transpose(0, 1))
550
-
551
- if saved_state is not None:
552
- # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
553
- if "prev_key" in saved_state:
554
- _prev_key = saved_state["prev_key"]
555
- assert _prev_key is not None
556
- prev_key = _prev_key.view(bsz * self.num_heads, -1,
557
- self.head_dim)
558
- if static_kv:
559
- k = prev_key
560
- else:
561
- assert k is not None
562
- k = torch.cat([prev_key, k], dim=1)
563
- src_len = k.size(1)
564
- if "prev_value" in saved_state:
565
- _prev_value = saved_state["prev_value"]
566
- assert _prev_value is not None
567
- prev_value = _prev_value.view(bsz * self.num_heads, -1,
568
- self.head_dim)
569
- if static_kv:
570
- v = prev_value
571
- else:
572
- assert v is not None
573
- v = torch.cat([prev_value, v], dim=1)
574
- prev_key_padding_mask: Optional[Tensor] = None
575
- if "prev_key_padding_mask" in saved_state:
576
- prev_key_padding_mask = saved_state["prev_key_padding_mask"]
577
- assert k is not None and v is not None
578
- key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
579
- key_padding_mask=key_padding_mask,
580
- prev_key_padding_mask=prev_key_padding_mask,
581
- batch_size=bsz,
582
- src_len=k.size(1),
583
- static_kv=static_kv, )
584
-
585
- saved_state["prev_key"] = k.view(bsz, self.num_heads, -1,
586
- self.head_dim)
587
- saved_state["prev_value"] = v.view(bsz, self.num_heads, -1,
588
- self.head_dim)
589
- saved_state["prev_key_padding_mask"] = key_padding_mask
590
- # In this branch incremental_state is never None
591
- assert incremental_state is not None
592
- incremental_state = self._set_input_buffer(incremental_state,
593
- saved_state)
594
- assert k is not None
595
- assert k.size(1) == src_len
596
-
597
- # This is part of a workaround to get around fork/join parallelism
598
- # not supporting Optional types.
599
- if key_padding_mask is not None and key_padding_mask.dim() == 0:
600
- key_padding_mask = None
601
-
602
- if key_padding_mask is not None:
603
- assert key_padding_mask.size(0) == bsz
604
- assert key_padding_mask.size(1) == src_len
605
-
606
- if self.add_zero_attn:
607
- assert v is not None
608
- src_len += 1
609
- k = torch.cat(
610
- [k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
611
- v = torch.cat(
612
- [v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
613
- if attn_mask is not None:
614
- attn_mask = torch.cat(
615
- [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)],
616
- dim=1)
617
- if key_padding_mask is not None:
618
- key_padding_mask = torch.cat(
619
- [
620
- key_padding_mask,
621
- torch.zeros(key_padding_mask.size(0),
622
- 1).type_as(key_padding_mask),
623
- ],
624
- dim=1, )
625
-
626
- attn_weights = torch.bmm(q, k.transpose(1, 2))
627
- attn_weights = (
628
- attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
629
- attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len,
630
- bsz)
631
-
632
- assert list(
633
- attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
634
-
635
- if attn_mask is not None:
636
- attn_mask = attn_mask.unsqueeze(0)
637
- attn_weights += attn_mask
638
-
639
- if key_padding_mask is not None:
640
- # don't attend to padding symbols
641
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
642
- src_len)
643
- if not is_tpu:
644
- attn_weights = attn_weights.masked_fill(
645
- key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
646
- float("-inf"), )
647
- else:
648
- attn_weights = attn_weights.transpose(0, 2)
649
- attn_weights = attn_weights.masked_fill(key_padding_mask,
650
- float("-inf"))
651
- attn_weights = attn_weights.transpose(0, 2)
652
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
653
- src_len)
654
-
655
- if before_softmax:
656
- return attn_weights, v, position_bias
657
-
658
- if position_bias is not None:
659
- attn_mask_rel_pos = position_bias
660
- if self.gru_rel_pos == 1:
661
- query_layer = q.view(bsz, self.num_heads, tgt_len,
662
- self.q_head_dim) * alpha / self.scaling
663
- _B, _H, _L, __ = query_layer.size()
664
- gate_a, gate_b = torch.sigmoid(
665
- self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(
666
- -1, keepdim=False)).chunk(
667
- 2, dim=-1)
668
- gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
669
- attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len,
670
- 1) * position_bias
671
-
672
- attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
673
-
674
- attn_weights = attn_weights + attn_mask_rel_pos
675
-
676
- attn_weights_float = F.softmax(attn_weights, dim=-1)
677
- attn_weights = attn_weights_float.type_as(attn_weights)
678
- attn_probs = self.dropout_module(attn_weights)
679
-
680
- assert v is not None
681
- attn = torch.bmm(attn_probs, v)
682
- assert list(
683
- attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
684
- attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
685
- attn = self.out_proj(attn)
686
- attn_weights: Optional[Tensor] = None
687
- if need_weights:
688
- attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len,
689
- src_len).transpose(1, 0)
690
- if not need_head_weights:
691
- # average attention weights over heads
692
- attn_weights = attn_weights.mean(dim=0)
693
-
694
- return attn, attn_weights, position_bias
695
-
696
- @staticmethod
697
- def _append_prev_key_padding_mask(
698
- key_padding_mask: Optional[Tensor],
699
- prev_key_padding_mask: Optional[Tensor],
700
- batch_size: int,
701
- src_len: int,
702
- static_kv: bool, ) -> Optional[Tensor]:
703
- # saved key padding masks have shape (bsz, seq_len)
704
- if prev_key_padding_mask is not None and static_kv:
705
- new_key_padding_mask = prev_key_padding_mask
706
- elif prev_key_padding_mask is not None and key_padding_mask is not None:
707
- new_key_padding_mask = torch.cat(
708
- [prev_key_padding_mask.float(), key_padding_mask.float()],
709
- dim=1)
710
- # During incremental decoding, as the padding token enters and
711
- # leaves the frame, there will be a time when prev or current
712
- # is None
713
- elif prev_key_padding_mask is not None:
714
- if src_len > prev_key_padding_mask.size(1):
715
- filler = torch.zeros(
716
- (batch_size, src_len - prev_key_padding_mask.size(1)),
717
- device=prev_key_padding_mask.device, )
718
- new_key_padding_mask = torch.cat(
719
- [prev_key_padding_mask.float(), filler.float()], dim=1)
720
- else:
721
- new_key_padding_mask = prev_key_padding_mask.float()
722
- elif key_padding_mask is not None:
723
- if src_len > key_padding_mask.size(1):
724
- filler = torch.zeros(
725
- (batch_size, src_len - key_padding_mask.size(1)),
726
- device=key_padding_mask.device, )
727
- new_key_padding_mask = torch.cat(
728
- [filler.float(), key_padding_mask.float()], dim=1)
729
- else:
730
- new_key_padding_mask = key_padding_mask.float()
731
- else:
732
- new_key_padding_mask = prev_key_padding_mask
733
- return new_key_padding_mask
734
-
735
- def _get_input_buffer(
736
- self,
737
- incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
738
- ) -> Dict[str, Optional[Tensor]]:
739
- result = self.get_incremental_state(incremental_state, "attn_state")
740
- if result is not None:
741
- return result
742
- else:
743
- empty_result: Dict[str, Optional[Tensor]] = {}
744
- return empty_result
745
-
746
- def _set_input_buffer(
747
- self,
748
- incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
749
- buffer: Dict[str, Optional[Tensor]], ):
750
- return self.set_incremental_state(incremental_state, "attn_state",
751
- buffer)
752
-
753
- def apply_sparse_mask(self,
754
- attn_weights,
755
- tgt_len: int,
756
- src_len: int,
757
- bsz: int):
758
- return attn_weights
759
-
760
-
761
- def init_bert_params(module):
762
- """
763
- Initialize the weights specific to the BERT Model.
764
- This overrides the default initializations depending on the specified arguments.
765
- 1. If normal_init_linear_weights is set then weights of linear
766
- layer will be initialized using the normal distribution and
767
- bias will be set to the specified value.
768
- 2. If normal_init_embed_weights is set then weights of embedding
769
- layer will be initialized using the normal distribution.
770
- 3. If normal_init_proj_weights is set then weights of
771
- in_project_weight for MultiHeadAttention initialized using
772
- the normal distribution (to be validated).
773
- """
774
-
775
- def normal_(data):
776
- # with FSDP, module params will be on CUDA, so we cast them back to CPU
777
- # so that the RNG is consistent with and without FSDP
778
- data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
779
-
780
- if isinstance(module, nn.Linear):
781
- normal_(module.weight.data)
782
- if module.bias is not None:
783
- module.bias.data.zero_()
784
- if isinstance(module, nn.Embedding):
785
- normal_(module.weight.data)
786
- if module.padding_idx is not None:
787
- module.weight.data[module.padding_idx].zero_()
788
- if isinstance(module, MultiheadAttention):
789
- normal_(module.q_proj.weight.data)
790
- normal_(module.k_proj.weight.data)
791
- normal_(module.v_proj.weight.data)
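The _relative_positions_bucket / compute_bias pair above implements a T5-style relative position bias: offsets are split into sign-dependent halves, small offsets get exact buckets, and larger offsets are binned logarithmically up to max_distance. A minimal standalone sketch of that bucketing; num_buckets=320 and max_distance=800 are assumed values, since the real ones come from the module's constructor, which is not shown in this hunk:

import math
import torch

def relative_position_bucket(relative_positions, num_buckets=320, max_distance=800):
    # bidirectional: half the buckets for positive offsets, half for non-positive ones
    num_buckets //= 2
    buckets = (relative_positions > 0).to(torch.long) * num_buckets
    n = torch.abs(relative_positions)
    max_exact = num_buckets // 2
    is_small = n < max_exact
    # distant offsets are mapped logarithmically and capped at the last bucket
    large = max_exact + (torch.log(n.float() / max_exact) /
                         math.log(max_distance / max_exact) *
                         (num_buckets - max_exact)).to(torch.long)
    large = torch.min(large, torch.full_like(large, num_buckets - 1))
    return buckets + torch.where(is_small, n, large)

context = torch.arange(4)[:, None]
memory = torch.arange(4)[None, :]
print(relative_position_bucket(memory - context))  # (query_len, key_len) bucket ids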
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/config.py DELETED
@@ -1,19 +0,0 @@
1
- import json
2
- import os
3
-
4
- # directory containing the current script
5
- script_dir = os.path.dirname(os.path.abspath(__file__))
6
-
7
- # filename of the JSON file
8
- json_filename = "ontology.json"
9
-
10
- # build the full path to the JSON file
11
- json_path = os.path.join(script_dir, json_filename)
12
-
13
- id_name_dict = {}
14
-
15
- with open(json_path, 'r') as f:
16
- json_items = json.load(f)
17
- # e.g. '/m/0dgw9r' -> 'Human sounds'
18
- for item in json_items:
19
- id_name_dict[item['id']] = item['name']
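A short consumption sketch for id_name_dict, as used by get_beats_librilight.py further down: it maps AudioSet ontology ids to human-readable names (the example id and name are taken from the comment above; the import assumes AR/exps/beats/config.py is importable):

from AR.exps.beats.config import id_name_dict

# the audio-tagging script below uses this mapping to turn BEATs label ids into tag names
print(id_name_dict.get('/m/0dgw9r', 'unknown id'))  # expected: 'Human sounds'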
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/modules.py DELETED
@@ -1,220 +0,0 @@
1
- # --------------------------------------------------------
2
- # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
- # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
- # Copyright (c) 2022 Microsoft
5
- # Licensed under The MIT License [see LICENSE for details]
6
- # Based on fairseq code bases
7
- # https://github.com/pytorch/fairseq
8
- # --------------------------------------------------------
9
- import math
10
- import warnings
11
-
12
- import torch
13
- import torch.nn.functional as F
14
- from torch import nn
15
-
16
-
17
- class GradMultiply(torch.autograd.Function):
18
- @staticmethod
19
- def forward(ctx, x, scale):
20
- ctx.scale = scale
21
- res = x.new(x)
22
- return res
23
-
24
- @staticmethod
25
- def backward(ctx, grad):
26
- return grad * ctx.scale, None
27
-
28
-
29
- class SamePad(nn.Module):
30
- def __init__(self, kernel_size, causal=False):
31
- super().__init__()
32
- if causal:
33
- self.remove = kernel_size - 1
34
- else:
35
- self.remove = 1 if kernel_size % 2 == 0 else 0
36
-
37
- def forward(self, x):
38
- if self.remove > 0:
39
- x = x[:, :, :-self.remove]
40
- return x
41
-
42
-
43
- class Swish(nn.Module):
44
- def __init__(self):
45
- super(Swish, self).__init__()
46
- self.act = torch.nn.Sigmoid()
47
-
48
- def forward(self, x):
49
- return x * self.act(x)
50
-
51
-
52
- class GLU_Linear(nn.Module):
53
- def __init__(self,
54
- input_dim,
55
- output_dim,
56
- glu_type="sigmoid",
57
- bias_in_glu=True):
58
- super(GLU_Linear, self).__init__()
59
-
60
- self.glu_type = glu_type
61
- self.output_dim = output_dim
62
-
63
- if glu_type == "sigmoid":
64
- self.glu_act = torch.nn.Sigmoid()
65
- elif glu_type == "swish":
66
- self.glu_act = Swish()
67
- elif glu_type == "relu":
68
- self.glu_act = torch.nn.ReLU()
69
- elif glu_type == "gelu":
70
- self.glu_act = torch.nn.GELU()
71
-
72
- if bias_in_glu:
73
- self.linear = nn.Linear(input_dim, output_dim * 2, True)
74
- else:
75
- self.linear = nn.Linear(input_dim, output_dim * 2, False)
76
-
77
- def forward(self, x):
78
- # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
79
- x = self.linear(x)
80
-
81
- if self.glu_type == "bilinear":
82
- x = (x[:, :, 0:self.output_dim] *
83
- x[:, :, self.output_dim:self.output_dim * 2])
84
- else:
85
- x = (x[:, :, 0:self.output_dim] *
86
- self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
87
-
88
- return x
89
-
90
-
91
- def gelu_accurate(x):
92
- if not hasattr(gelu_accurate, "_a"):
93
- gelu_accurate._a = math.sqrt(2 / math.pi)
94
- return (0.5 * x * (1 + torch.tanh(gelu_accurate._a *
95
- (x + 0.044715 * torch.pow(x, 3)))))
96
-
97
-
98
- def gelu(x: torch.Tensor) -> torch.Tensor:
99
- return torch.nn.functional.gelu(x.float()).type_as(x)
100
-
101
-
102
- def get_activation_fn(activation: str):
103
- """Returns the activation function corresponding to `activation`"""
104
-
105
- if activation == "relu":
106
- return F.relu
107
- elif activation == "gelu":
108
- return gelu
109
- elif activation == "gelu_fast":
110
- warnings.warn(
111
- "--activation-fn=gelu_fast has been renamed to gelu_accurate")
112
- return gelu_accurate
113
- elif activation == "gelu_accurate":
114
- return gelu_accurate
115
- elif activation == "tanh":
116
- return torch.tanh
117
- elif activation == "linear":
118
- return lambda x: x
119
- elif activation == "glu":
120
- return lambda x: x
121
- else:
122
- raise RuntimeError(
123
- "--activation-fn {} not supported".format(activation))
124
-
125
-
126
- def quant_noise(module, p, block_size):
127
- """
128
- Wraps modules and applies quantization noise to the weights for
129
- subsequent quantization with Iterative Product Quantization as
130
- described in "Training with Quantization Noise for Extreme Model Compression"
131
-
132
- Args:
133
- - module: nn.Module
134
- - p: amount of Quantization Noise
135
- - block_size: size of the blocks for subsequent quantization with iPQ
136
-
137
- Remarks:
138
- - Module weights must have the right sizes wrt the block size
139
- - Only Linear, Embedding and Conv2d modules are supported for the moment
140
- - For more detail on how to quantize by blocks with convolutional weights,
141
- see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
142
- - We implement the simplest form of noise here as stated in the paper
143
- which consists in randomly dropping blocks
144
- """
145
-
146
- # if no quantization noise, don't register hook
147
- if p <= 0:
148
- return module
149
-
150
- # supported modules
151
- assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
152
-
153
- # test whether module.weight has the right sizes wrt block_size
154
- is_conv = module.weight.ndim == 4
155
-
156
- # 2D matrix
157
- if not is_conv:
158
- assert (
159
- module.weight.size(1) %
160
- block_size == 0), "Input features must be a multiple of block sizes"
161
-
162
- # 4D matrix
163
- else:
164
- # 1x1 convolutions
165
- if module.kernel_size == (1, 1):
166
- assert (module.in_channels % block_size == 0
167
- ), "Input channels must be a multiple of block sizes"
168
- # regular convolutions
169
- else:
170
- k = module.kernel_size[0] * module.kernel_size[1]
171
- assert k % block_size == 0, "Kernel size must be a multiple of block size"
172
-
173
- def _forward_pre_hook(mod, input):
174
- # no noise for evaluation
175
- if mod.training:
176
- if not is_conv:
177
- # gather weight and sizes
178
- weight = mod.weight
179
- in_features = weight.size(1)
180
- out_features = weight.size(0)
181
-
182
- # split weight matrix into blocks and randomly drop selected blocks
183
- mask = torch.zeros(
184
- in_features // block_size * out_features,
185
- device=weight.device)
186
- mask.bernoulli_(p)
187
- mask = mask.repeat_interleave(block_size, -1).view(-1,
188
- in_features)
189
-
190
- else:
191
- # gather weight and sizes
192
- weight = mod.weight
193
- in_channels = mod.in_channels
194
- out_channels = mod.out_channels
195
-
196
- # split weight matrix into blocks and randomly drop selected blocks
197
- if mod.kernel_size == (1, 1):
198
- mask = torch.zeros(
199
- int(in_channels // block_size * out_channels),
200
- device=weight.device, )
201
- mask.bernoulli_(p)
202
- mask = mask.repeat_interleave(block_size, -1).view(
203
- -1, in_channels)
204
- else:
205
- mask = torch.zeros(
206
- weight.size(0), weight.size(1), device=weight.device)
207
- mask.bernoulli_(p)
208
- mask = (
209
- mask.unsqueeze(2).unsqueeze(3)
210
- .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]))
211
-
212
- # scale weights and apply mask
213
- mask = mask.to(
214
- torch.
215
- bool) # x.bool() is not currently supported in TorchScript
216
- s = 1 / (1 - p)
217
- mod.weight.data = s * weight.masked_fill(mask, 0)
218
-
219
- module.register_forward_pre_hook(_forward_pre_hook)
220
- return module
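quant_noise above registers a forward pre-hook that, in training mode, zeroes random weight blocks and rescales the remainder by 1 / (1 - p), as its docstring describes. A minimal usage sketch; layer sizes are illustrative and it assumes AR/exps/beats/modules.py is importable:

import torch
from torch import nn
from AR.exps.beats.modules import quant_noise

# in_features (64) must be a multiple of block_size (8) for a Linear layer
layer = quant_noise(nn.Linear(64, 64), p=0.1, block_size=8)
layer.train()
y_noisy = layer(torch.randn(2, 64))   # random weight blocks dropped on this forward pass
layer.eval()
y_clean = layer(torch.randn(2, 64))   # no noise is applied in evaluation mode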
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/ontology.json DELETED
The diff for this file is too large to render. See raw diff
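config.py above only reads the 'id' and 'name' fields of each ontology entry, so for the purposes of this repo one item of the (assumed AudioSet) ontology reduces to the following shape:

# assumed minimal shape of one ontology.json item, inferred from how config.py indexes it;
# the real ontology entries carry additional fields (description, child_ids, ...)
example_item = {"id": "/m/0dgw9r", "name": "Human sounds"}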
 
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/beats/quantizer.py DELETED
@@ -1,235 +0,0 @@
1
- # --------------------------------------------------------
2
- # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
- # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
- # Copyright (c) 2022 Microsoft
5
- # Licensed under The MIT License [see LICENSE for details]
6
- # Based on VQGAN code bases
7
- # https://github.com/CompVis/taming-transformers
8
- # --------------------------------------------------------'
9
- import torch
10
- import torch.distributed as distributed
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
-
14
- try:
15
- from einops import rearrange, repeat
16
- except ImportError:
17
- pass
18
-
19
-
20
- def l2norm(t):
21
- return F.normalize(t, p=2, dim=-1)
22
-
23
-
24
- def ema_inplace(moving_avg, new, decay):
25
- moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
26
-
27
-
28
- def sample_vectors(samples, num):
29
- num_samples, device = samples.shape[0], samples.device
30
-
31
- if num_samples >= num:
32
- indices = torch.randperm(num_samples, device=device)[:num]
33
- else:
34
- indices = torch.randint(0, num_samples, (num, ), device=device)
35
-
36
- return samples[indices]
37
-
38
-
39
- def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
40
- dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
41
-
42
- means = sample_vectors(samples, num_clusters)
43
-
44
- for _ in range(num_iters):
45
- if use_cosine_sim:
46
- dists = samples @ means.t()
47
- else:
48
- diffs = rearrange(samples, 'n d -> n () d') \
49
- - rearrange(means, 'c d -> () c d')
50
- dists = -(diffs**2).sum(dim=-1)
51
-
52
- buckets = dists.max(dim=-1).indices
53
- bins = torch.bincount(buckets, minlength=num_clusters)
54
- zero_mask = bins == 0
55
- bins_min_clamped = bins.masked_fill(zero_mask, 1)
56
-
57
- new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
58
- new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
59
- new_means = new_means / bins_min_clamped[..., None]
60
-
61
- if use_cosine_sim:
62
- new_means = l2norm(new_means)
63
-
64
- means = torch.where(zero_mask[..., None], means, new_means)
65
-
66
- return means, bins
67
-
68
-
69
- class EmbeddingEMA(nn.Module):
70
- def __init__(self,
71
- num_tokens,
72
- codebook_dim,
73
- decay=0.99,
74
- eps=1e-5,
75
- kmeans_init=True,
76
- codebook_init_path=''):
77
- super().__init__()
78
- self.num_tokens = num_tokens
79
- self.codebook_dim = codebook_dim
80
- self.decay = decay
81
- self.eps = eps
82
- if codebook_init_path == '':
83
- if not kmeans_init:
84
- weight = torch.randn(num_tokens, codebook_dim)
85
- weight = l2norm(weight)
86
- else:
87
- weight = torch.zeros(num_tokens, codebook_dim)
88
- self.register_buffer('initted', torch.Tensor([not kmeans_init]))
89
- else:
90
- print(f"load init codebook weight from {codebook_init_path}")
91
- codebook_ckpt_weight = torch.load(
92
- codebook_init_path, map_location='cpu')
93
- weight = codebook_ckpt_weight.clone()
94
- self.register_buffer('initted', torch.Tensor([True]))
95
-
96
- self.weight = nn.Parameter(weight, requires_grad=False)
97
- self.cluster_size = nn.Parameter(
98
- torch.zeros(num_tokens), requires_grad=False)
99
- self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
100
- # self.register_buffer('initted', torch.Tensor([not kmeans_init]))
101
- self.update = True
102
-
103
- @torch.jit.ignore
104
- def init_embed_(self, data):
105
- if self.initted:
106
- return
107
- print("Performing Kemans init for codebook")
108
- embed, cluster_size = kmeans(
109
- data, self.num_tokens, 10, use_cosine_sim=True)
110
- self.weight.data.copy_(embed)
111
- self.cluster_size.data.copy_(cluster_size)
112
- self.initted.data.copy_(torch.Tensor([True]))
113
-
114
- def forward(self, embed_id):
115
- return F.embedding(embed_id, self.weight)
116
-
117
- def cluster_size_ema_update(self, new_cluster_size):
118
- self.cluster_size.data.mul_(self.decay).add_(
119
- new_cluster_size, alpha=1 - self.decay)
120
-
121
- def embed_avg_ema_update(self, new_embed_avg):
122
- self.embed_avg.data.mul_(self.decay).add_(
123
- new_embed_avg, alpha=1 - self.decay)
124
-
125
- def weight_update(self, num_tokens):
126
- n = self.cluster_size.sum()
127
- smoothed_cluster_size = (
128
- (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n)
129
- # normalize embedding average with smoothed cluster size
130
- embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
131
- # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
132
- self.weight.data.copy_(embed_normalized)
133
-
134
-
135
- def norm_ema_inplace(moving_avg, new, decay):
136
- moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
137
- moving_avg.data.copy_(l2norm(moving_avg.data))
138
-
139
-
140
- class NormEMAVectorQuantizer(nn.Module):
141
- def __init__(self,
142
- n_embed,
143
- embedding_dim,
144
- beta,
145
- decay=0.99,
146
- eps=1e-5,
147
- statistic_code_usage=True,
148
- kmeans_init=False,
149
- codebook_init_path=''):
150
- super().__init__()
151
- self.codebook_dim = embedding_dim
152
- self.num_tokens = n_embed
153
- self.beta = beta
154
- self.decay = decay
155
-
156
- # learnable = True if orthogonal_reg_weight > 0 else False
157
- self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay,
158
- eps, kmeans_init, codebook_init_path)
159
-
160
- self.statistic_code_usage = statistic_code_usage
161
- if statistic_code_usage:
162
- self.register_buffer('cluster_size', torch.zeros(n_embed))
163
- if distributed.is_available() and distributed.is_initialized():
164
- print(
165
- "ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!"
166
- )
167
- self.all_reduce_fn = distributed.all_reduce
168
- else:
169
- self.all_reduce_fn = nn.Identity()
170
-
171
- def reset_cluster_size(self, device):
172
- if self.statistic_code_usage:
173
- self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
174
- self.cluster_size = self.cluster_size.to(device)
175
-
176
- def forward(self, z):
177
- # reshape z -> (batch, height, width, channel) and flatten
178
- # z, 'b c h w -> b h w c'
179
- # z = rearrange(z, 'b c h w -> b h w c')
180
- # z = z.transpose(1, 2)
181
- z = l2norm(z)
182
- z_flattened = z.reshape(-1, self.codebook_dim)
183
-
184
- self.embedding.init_embed_(z_flattened)
185
-
186
- d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
187
- self.embedding.weight.pow(2).sum(dim=1) - 2 * \
188
- torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
189
-
190
- encoding_indices = torch.argmin(d, dim=1)
191
-
192
- z_q = self.embedding(encoding_indices).view(z.shape)
193
-
194
- encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
195
-
196
- if not self.training:
197
- with torch.no_grad():
198
- cluster_size = encodings.sum(0)
199
- self.all_reduce_fn(cluster_size)
200
- ema_inplace(self.cluster_size, cluster_size, self.decay)
201
-
202
- if self.training and self.embedding.update:
203
- # EMA cluster size
204
-
205
- bins = encodings.sum(0)
206
- self.all_reduce_fn(bins)
207
-
208
- # self.embedding.cluster_size_ema_update(bins)
209
- ema_inplace(self.cluster_size, bins, self.decay)
210
-
211
- zero_mask = (bins == 0)
212
- bins = bins.masked_fill(zero_mask, 1.)
213
-
214
- embed_sum = z_flattened.t() @ encodings
215
- self.all_reduce_fn(embed_sum)
216
-
217
- embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
218
- embed_normalized = l2norm(embed_normalized)
219
-
220
- embed_normalized = torch.where(
221
- zero_mask[..., None], self.embedding.weight, embed_normalized)
222
- norm_ema_inplace(self.embedding.weight, embed_normalized,
223
- self.decay)
224
-
225
- # compute loss for embedding
226
- loss = self.beta * F.mse_loss(z_q.detach(), z)
227
-
228
- # preserve gradients
229
- z_q = z + (z_q - z).detach()
230
-
231
- # reshape back to match original input shape
232
- # z_q, 'b h w c -> b c h w'
233
- # z_q = rearrange(z_q, 'b h w c -> b c h w')
234
- # z_q = z_q.transpose(1, 2)
235
- return z_q, loss, encoding_indices
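NormEMAVectorQuantizer above snaps L2-normalised inputs to the nearest codebook entry and maintains the codebook with EMA updates rather than gradients; only the beta-weighted commitment loss backpropagates. A usage sketch; sizes are illustrative, kmeans_init=False keeps the optional einops dependency out of the picture, and it assumes AR/exps/beats/quantizer.py is importable:

import torch
from AR.exps.beats.quantizer import NormEMAVectorQuantizer

quantizer = NormEMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=1.0,
                                   decay=0.99, kmeans_init=False)
quantizer.train()
z = torch.randn(400, 256)           # flattened frame embeddings, (N, embedding_dim)
z_q, loss, codes = quantizer(z)     # quantised vectors, commitment loss, code indices
print(z_q.shape, loss.item(), codes.shape)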
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_beats_librilight.py DELETED
@@ -1,321 +0,0 @@
1
- # Use the AudioTag tool BEATs to filter out audios whose top1 tag is not 'speech'
2
- # non_speech.npy stores a python dict of tags for non-speech audio; it is smaller and therefore faster to load and search
3
- # the audio_tag directory stores {utt_id}.txt, whose first line is the lowercase top1 tag
4
- import argparse
5
- import os
6
- import time
7
- import traceback
8
- from concurrent.futures import ThreadPoolExecutor
9
- from pathlib import Path
10
-
11
- import librosa
12
- import numpy as np
13
- import torch
14
- import tqdm
15
- from AR.exps.beats.BEATs import BEATs
16
- from AR.exps.beats.BEATs import BEATsConfig
17
- from AR.exps.beats.config import id_name_dict
18
- from soundstorm.s2.exps.hubert.feature_utils import get_shard_range
19
- from soundstorm.utils import check_txt_file
20
-
21
-
22
- def get_BEATs_top1(wav,
23
- BEATs_model,
24
- BEATs_label_dict,
25
- device: str='cpu',
26
- topk: int=1):
27
- wav = torch.tensor(wav).unsqueeze(0).to(device)
28
- padding_mask = torch.zeros(wav.shape).bool().to(device)
29
- probs = BEATs_model.extract_features(wav, padding_mask=padding_mask)[0]
30
- # single-sample inference
31
- probs = probs[0]
32
- topk_label_prob, topk_label_idx = probs.topk(k=topk)
33
- topk_label = [
34
- BEATs_label_dict[label_idx.item()] for label_idx in topk_label_idx
35
- ]
36
- topk_label_name = [id_name_dict[label] for label in topk_label]
37
- top1_label = topk_label_name[0]
38
- return top1_label
39
-
40
-
41
- def process_sentence(args,
42
- fp: Path,
43
- train_dump_dir: Path,
44
- dev_dump_dir: Path,
45
- test_dump_dir: Path,
46
- VAD_dict,
47
- BEATs_model,
48
- BEATs_label_dict,
49
- device: str='cpu'):
50
- utt_id = fp.stem
51
- sr = args.sr
52
- record = []
53
- train_audio_tag_dir = train_dump_dir / "audio_tag"
54
- train_audio_tag_dir.mkdir(parents=True, exist_ok=True)
55
-
56
- dev_audio_tag_dir = dev_dump_dir / "audio_tag"
57
- dev_audio_tag_dir.mkdir(parents=True, exist_ok=True)
58
-
59
- test_audio_tag_dir = test_dump_dir / "audio_tag"
60
- test_audio_tag_dir.mkdir(parents=True, exist_ok=True)
61
-
62
- try:
63
- # get info for path
64
- wav_path_list = str(fp).strip().split('/')
65
- sub_dataset, spk_id, book_name = wav_path_list[-4], wav_path_list[
66
- -3], wav_path_list[-2]
67
- wav_name = wav_path_list[-1][:-5]
68
- assert wav_name == utt_id
69
- # key_name for big wav
70
- key_name = f'{wav_name}#{sub_dataset}#{spk_id}#{book_name}'
71
- # handle the case where this audio has no entry in the VAD dict
72
- if key_name not in VAD_dict.keys():
73
- print(key_name, 'not in VAD_dict !')
74
- return record
75
- wav = None
76
- sorted_split_VAD_dict = sorted(VAD_dict[key_name].items())
77
- len_dict = len(sorted_split_VAD_dict)
78
- for index, item in enumerate(sorted_split_VAD_dict):
79
- split_name, value = item
80
- start, end = value
81
- # train | dev | test
82
- if index == len_dict - 1:
83
- subset = 'test'
84
- audio_tag_path = test_audio_tag_dir / (split_name + ".txt")
85
- elif index == len_dict - 2:
86
- subset = 'dev'
87
- audio_tag_path = dev_audio_tag_dir / (split_name + ".txt")
88
- else:
89
- subset = 'train'
90
- audio_tag_path = train_audio_tag_dir / (split_name + ".txt")
91
-
92
- if os.path.exists(audio_tag_path) and check_txt_file(
93
- audio_tag_path):
94
- # print(audio_tag_path, 'exists!')
95
- pass
96
- else:
97
- # this check ensures the big wav is loaded only once inside the sub-wav loop
98
- if wav is None:
99
- # load big wav
100
- # loading at the outermost level would waste the load time when all sub-wav features already exist
101
- wav, _ = librosa.load(str(fp), sr=sr)
102
- sub_wav = wav[int(start * sr):int(end * sr)]
103
- audio_tag_top1 = get_BEATs_top1(
104
- wav=sub_wav,
105
- BEATs_model=BEATs_model,
106
- BEATs_label_dict=BEATs_label_dict,
107
- device=device)
108
-
109
- with open(audio_tag_path, 'w') as f:
110
- f.write(audio_tag_top1)
111
-
112
- sub_record = {
113
- "utt_id": split_name,
114
- "audio_tag_path": audio_tag_path,
115
- "subset": subset
116
- }
117
- # record becomes a List of Dict
118
- record.append(sub_record)
119
- except Exception:
120
- print("occur Exception")
121
- traceback.print_exc()
122
- # record may be an incomplete List
123
- return record
124
- return record
125
-
126
-
127
- def process_sentences(args,
128
- fps: Path,
129
- train_dump_dir: Path,
130
- dev_dump_dir: Path,
131
- test_dump_dir: Path,
132
- VAD_dict,
133
- BEATs_model,
134
- BEATs_label_dict,
135
- device: str='cpu',
136
- nprocs: int=1):
137
- print("nprocs:", nprocs)
138
- if nprocs == 1:
139
- results = []
140
- for fp in tqdm.tqdm(fps, total=len(fps)):
141
- record = process_sentence(
142
- args=args,
143
- fp=fp,
144
- train_dump_dir=train_dump_dir,
145
- dev_dump_dir=dev_dump_dir,
146
- test_dump_dir=test_dump_dir,
147
- VAD_dict=VAD_dict,
148
- BEATs_model=BEATs_model,
149
- BEATs_label_dict=BEATs_label_dict,
150
- device=device)
151
- if record:
152
- results.append(record)
153
- else:
154
- with ThreadPoolExecutor(nprocs) as pool:
155
- futures = []
156
- with tqdm.tqdm(total=len(fps)) as progress:
157
- for fp in fps:
158
- future = pool.submit(process_sentence, args, fp,
159
- train_dump_dir, dev_dump_dir,
160
- test_dump_dir, VAD_dict, BEATs_model,
161
- BEATs_label_dict, device)
162
- future.add_done_callback(lambda p: progress.update())
163
- futures.append(future)
164
-
165
- results = []
166
- for ft in futures:
167
- record = ft.result()
168
- if record:
169
- results.append(record)
170
-
171
- # gather the results and save them with np.save() as .npy dicts
172
- non_speech_dict = dict()
173
- non_speech_dict['train'] = {}
174
- non_speech_dict['dev'] = {}
175
- non_speech_dict['test'] = {}
176
- # record is a List of Dict: one record per big wav, one sub_record per sub wav
177
- print(f"start to save {args.rank}_{args.nshard}.npy ...")
178
- save_start_time = time.time()
179
- for record in tqdm.tqdm(results, total=len(results), colour='green'):
180
- for sub_record in record:
181
- # wrap in try, because the txt file may be corrupted
182
- try:
183
- utt_id = sub_record["utt_id"]
184
- subset = sub_record["subset"]
185
- audio_tag_top1 = check_txt_file(sub_record["audio_tag_path"])
186
- if audio_tag_top1 is not False:
187
- if 'speech' not in audio_tag_top1.lower():
188
- non_speech_dict[subset][utt_id] = audio_tag_top1
189
- else:
190
- # print(f'audio tag result of {utt_id} is speech')
191
- pass
192
- else:
193
- print(f'audio tag result of {utt_id} is False')
194
- except Exception:
195
- print(f"{utt_id} occur Exception")
196
- traceback.print_exc()
197
- continue
198
-
199
- train_filename = train_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
200
- dev_filename = dev_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
201
- test_filename = test_dump_dir / f'non_speech_{args.rank}_{args.nshard}.npy'
202
- np.save(train_filename, non_speech_dict['train'])
203
- print(f"npy file '{train_filename}' write down")
204
-
205
- np.save(dev_filename, non_speech_dict['dev'])
206
- print(f"npy file '{dev_filename}' write down")
207
-
208
- np.save(test_filename, non_speech_dict['test'])
209
- print(f"npy file '{test_filename}' write down")
210
- print('time of save stage:', time.time() - save_start_time)
211
-
212
-
213
- def main():
214
- # parse config and args
215
- parser = argparse.ArgumentParser(
216
- description="Use AudioTag tool BEATs to filter out audios who's top1 tag is not 'speech'."
217
- )
218
-
219
- parser.add_argument(
220
- "--data_dir", default=None, type=str, help="directory to dataset.")
221
-
222
- parser.add_argument(
223
- "--dump_dir",
224
- type=str,
225
- required=True,
226
- help="directory to dump feature files.")
227
-
228
- parser.add_argument(
229
- "--num-cpu", type=int, default=1, help="number of process.")
230
-
231
- parser.add_argument(
232
- '--sr', type=int, default=16000, help='sample rate of model')
233
-
234
- # For LibriLight dataset
235
- parser.add_argument(
236
- "--sub_dataset",
237
- default="small",
238
- type=str,
239
- help="name of sub dataset of LibriLight",
240
- choices=['small', 'medium', 'large', 'duplicate'], )
241
- parser.add_argument(
242
- "--VAD_path", type=str, default='./VAD/librilight_segment_dict.npy')
243
- parser.add_argument("--nshard", type=int, default=3)
244
- parser.add_argument("--rank", type=int, default=0)
245
-
246
- # for BEATs
247
- parser.add_argument(
248
- "--BEATs_ckpt_path",
249
- type=str,
250
- default='./pretrained_model/BEATs_iter1_finetuned_on_AS2M_cpt1.pt')
251
-
252
- args = parser.parse_args()
253
-
254
- data_dir = Path(args.data_dir).expanduser()
255
- dump_dir = Path(args.dump_dir).expanduser()
256
- # use absolute path
257
- dump_dir = dump_dir.resolve()
258
- dump_dir.mkdir(parents=True, exist_ok=True)
259
-
260
- assert data_dir.is_dir()
261
-
262
- # sub_dataset here
263
- sub_dataset_dir = data_dir / args.sub_dataset
264
- # only spk_id in list, sort by lexicographical order
265
- speaker_list = sorted(os.listdir(sub_dataset_dir))
266
- start, end = get_shard_range(len(speaker_list), args.nshard, args.rank)
267
- # speaker_list for this rank
268
- speaker_list = speaker_list[start:end]
269
-
270
- all_wav_files = []
271
-
272
- for speaker in speaker_list:
273
- wav_files = sorted(list((sub_dataset_dir / speaker).rglob("*/*.flac")))
274
- # filter out ._*.flac
275
- wav_files = [
276
- file for file in wav_files if not file.name.startswith('._')
277
- ]
278
- all_wav_files += wav_files
279
-
280
- print(f"num of wav files in rank {args.rank}:", len(all_wav_files))
281
- # get VAD info
282
- VAD_dict = np.load(args.VAD_path, allow_pickle=True).item()
283
-
284
- sub_dataset_dump_dir = dump_dir / args.sub_dataset
285
- sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
286
- train_dump_dir = sub_dataset_dump_dir / "train"
287
- train_dump_dir.mkdir(parents=True, exist_ok=True)
288
- dev_dump_dir = sub_dataset_dump_dir / "dev"
289
- dev_dump_dir.mkdir(parents=True, exist_ok=True)
290
- test_dump_dir = sub_dataset_dump_dir / "test"
291
- test_dump_dir.mkdir(parents=True, exist_ok=True)
292
-
293
- BEATs_ckpt = torch.load(args.BEATs_ckpt_path)
294
-
295
- BEATs_cfg = BEATsConfig(BEATs_ckpt['cfg'])
296
- BEATs_model = BEATs(BEATs_cfg)
297
- BEATs_model.load_state_dict(BEATs_ckpt['model'])
298
- BEATs_model.eval()
299
- # cpu or cuda
300
- device = 'cpu'
301
- BEATs_model.to(device)
302
-
303
- BEATs_label_dict = BEATs_ckpt['label_dict']
304
-
305
- # each big wav contributes one dev and one test split, roughly a 96:2:2 ratio
306
- if all_wav_files:
307
- process_sentences(
308
- args=args,
309
- fps=all_wav_files,
310
- train_dump_dir=train_dump_dir,
311
- dev_dump_dir=dev_dump_dir,
312
- test_dump_dir=test_dump_dir,
313
- VAD_dict=VAD_dict,
314
- BEATs_model=BEATs_model,
315
- BEATs_label_dict=BEATs_label_dict,
316
- device=device,
317
- nprocs=args.num_cpu)
318
-
319
-
320
- if __name__ == "__main__":
321
- main()
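The script above writes one non_speech_{rank}_{nshard}.npy per split, a dict mapping utt_id to its BEATs top1 tag for clips that were not tagged as speech. A consumption sketch; the path assumes dump_dir=dump plus the default --sub_dataset/--rank/--nshard values:

import numpy as np

non_speech = np.load('dump/small/train/non_speech_0_3.npy', allow_pickle=True).item()
bad_utts = set(non_speech)                      # utterances to filter out of training
print(len(bad_utts), list(non_speech.items())[:3])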
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_phones.py DELETED
@@ -1,232 +0,0 @@
1
- """
2
- 1. read text of dataset
3
- 2. text -> IPA by GruutPhonemizer
4
- 3. save out a *.npy dict for all text
5
- my_dict = {"utt_id1": text1, "utt_id2": text2}
6
- np.save(output_filename, my_dict)
7
- my_dict = np.load(output_filename, allow_pickle=True).item()
8
- """
9
- import argparse
10
- import os
- import traceback
11
- from concurrent.futures import ThreadPoolExecutor
12
- from operator import itemgetter
13
- from pathlib import Path
14
- from typing import List
15
-
16
- import numpy as np
17
- import tqdm
18
- from AR.text_processing.phonemizer import GruutPhonemizer
19
-
20
-
21
- def read_txt(txt_file):
22
- utt_name = txt_file.stem
23
- utt_id = utt_name.split('.')[0]
24
- try:
25
- with open(txt_file, 'r') as file:
26
- txt = file.readline()
27
- record = {"utt_id": utt_id, "txt": txt}
28
- except Exception:
29
- print("occur Exception")
30
- traceback.print_exc()
31
- return None
32
- return record
33
-
34
-
35
- def read_txts(txt_files: List[Path], nprocs: int=1):
36
- if nprocs == 1:
37
- results = []
38
- for txt_file in tqdm.tqdm(txt_files, total=len(txt_files)):
39
- record = read_txt(txt_file=txt_file)
40
- if record:
41
- results.append(record)
42
- else:
43
- with ThreadPoolExecutor(nprocs) as pool:
44
- futures = []
45
- with tqdm.tqdm(total=len(txt_files)) as progress:
46
- for txt_file in txt_files:
47
- future = pool.submit(read_txt, txt_file)
48
- future.add_done_callback(lambda p: progress.update())
49
- futures.append(future)
50
-
51
- results = []
52
- for ft in futures:
53
- record = ft.result()
54
- if record:
55
- results.append(record)
56
-
57
- results.sort(key=itemgetter("utt_id"))
58
- return_list = []
59
- for item in results:
60
- return_list.append((item["utt_id"], item["txt"]))
61
- return return_list
62
-
63
-
64
- def process_sentence(item, phonemizer):
65
- utt_id, text = item
66
- try:
67
- phonemes = phonemizer.phonemize(text, espeak=False)
68
- record = {"utt_id": utt_id, "phonemes": phonemes}
69
- except Exception:
70
- print("occur Exception")
71
- traceback.print_exc()
72
- return None
73
- return record
74
-
75
-
76
- def process_sentences(items, phonemizer, output_dir, nprocs: int=1):
77
- if nprocs == 1:
78
- results = []
79
- for item in tqdm.tqdm(items, total=len(items)):
80
- record = process_sentence(item=item, phonemizer=phonemizer)
81
- if record:
82
- results.append(record)
83
- else:
84
- with ThreadPoolExecutor(nprocs) as pool:
85
- futures = []
86
- with tqdm.tqdm(total=len(items)) as progress:
87
- for item in items:
88
- future = pool.submit(process_sentence, item, phonemizer)
89
- future.add_done_callback(lambda p: progress.update())
90
- futures.append(future)
91
-
92
- results = []
93
- for ft in futures:
94
- record = ft.result()
95
- if record:
96
- results.append(record)
97
- results.sort(key=itemgetter("utt_id"))
98
- npy_dict = {}
99
- for item in results:
100
- utt_id = item["utt_id"]
101
- phonemes = item["phonemes"]
102
- npy_dict[utt_id] = phonemes
103
- filename = output_dir / 'phonemes.npy'
104
- np.save(filename, npy_dict)
105
- print(f"npy file '{filename}' write down")
106
-
107
-
108
- def main():
109
- # parse config and args
110
- parser = argparse.ArgumentParser(description="Get phones for datasets")
111
-
112
- parser.add_argument(
113
- "--dataset",
114
- default="ljspeech",
115
- type=str,
116
- help="name of dataset, should in {ljspeech, libritts} now")
117
-
118
- parser.add_argument(
119
- "--data_dir", default=None, type=str, help="directory to dataset.")
120
-
121
- parser.add_argument(
122
- "--dump_dir",
123
- type=str,
124
- required=True,
125
- help="directory to dump feature files.")
126
- parser.add_argument(
127
- "--num-cpu", type=int, default=1, help="number of process.")
128
-
129
- args = parser.parse_args()
130
-
131
- data_dir = Path(args.data_dir).expanduser()
132
- dump_dir = Path(args.dump_dir).expanduser()
133
- # use absolute path
134
- dump_dir = dump_dir.resolve()
135
- dump_dir.mkdir(parents=True, exist_ok=True)
136
-
137
- assert data_dir.is_dir()
138
-
139
- if args.dataset == "ljspeech":
140
- data_dict = {}
141
- text_path = data_dir / 'metadata.csv'
142
- with open(text_path, 'r') as rf:
143
- for line in rf:
144
- line_list = line.strip().split('|')
145
- utt_id = line_list[0]
146
- raw_text = line_list[-1]
147
- data_dict[utt_id] = raw_text
148
-
149
- sorted_dict = sorted(data_dict.items())
150
-
151
- num_train = 12900
152
- num_dev = 100
153
- # (utt_id, txt)
154
- train_txts = sorted_dict[:num_train]
155
- dev_txts = sorted_dict[num_train:num_train + num_dev]
156
- test_txts = sorted_dict[num_train + num_dev:]
157
-
158
- elif args.dataset == "libritts":
159
- '''
160
- we use train-clean-100、train-clean-360、train-other-500 here
161
- and split dev and test from them, don't use test-* and dev-* cause the speakers are disjoint
162
- the file structure is LibriTTS_R/train-clean-100/spkid/*/*.wav
163
- there are about 2311 in these subsets, we split 1 dev and 1 test wav out from each speaker
164
- '''
165
- txt_files = []
166
- train_txt_files = []
167
- dev_txt_files = []
168
- test_txt_files = []
169
- sub_num_dev = 1
170
- for sub_dataset_name in {
171
- "train-clean-100", "train-clean-360", "train-other-500"
172
- }:
173
- sub_dataset_dir = data_dir / sub_dataset_name
174
- # filter out hidden files
175
- speaker_list = [
176
- file for file in os.listdir(sub_dataset_dir)
177
- if not file.startswith('.')
178
- ]
179
- for speaker in speaker_list:
180
- txt_files = sorted(
181
- list((sub_dataset_dir / speaker).rglob(
182
- "*/*.normalized.txt")))
183
- # filter out ._*.wav
184
- txt_files = [
185
- file for file in txt_files if not file.name.startswith('._')
186
- ]
187
- train_txt_files += txt_files[:-sub_num_dev * 2]
188
- dev_txt_files += txt_files[-sub_num_dev * 2:-sub_num_dev]
189
- test_txt_files += txt_files[-sub_num_dev:]
190
- print("len(train_txt_files):", len(train_txt_files))
191
- print("len(dev_txt_files):", len(dev_txt_files))
192
- print("len(test_txt_files):", len(test_txt_files))
193
-
194
- train_txts = read_txts(train_txt_files)
195
- dev_txts = read_txts(dev_txt_files)
196
- test_txts = read_txts(test_txt_files)
197
-
198
- else:
199
- print("dataset should in {ljspeech, libritts} now!")
200
-
201
- train_dump_dir = dump_dir / "train"
202
- train_dump_dir.mkdir(parents=True, exist_ok=True)
203
- dev_dump_dir = dump_dir / "dev"
204
- dev_dump_dir.mkdir(parents=True, exist_ok=True)
205
- test_dump_dir = dump_dir / "test"
206
- test_dump_dir.mkdir(parents=True, exist_ok=True)
207
-
208
- phonemizer = GruutPhonemizer(language='en-us')
209
-
210
- # process for the 3 sections
211
- if train_txts:
212
- process_sentences(
213
- items=train_txts,
214
- output_dir=train_dump_dir,
215
- phonemizer=phonemizer,
216
- nprocs=args.num_cpu)
217
- if dev_txts:
218
- process_sentences(
219
- items=dev_txts,
220
- output_dir=dev_dump_dir,
221
- phonemizer=phonemizer,
222
- nprocs=args.num_cpu)
223
- if test_txts:
224
- process_sentences(
225
- items=test_txts,
226
- output_dir=test_dump_dir,
227
- phonemizer=phonemizer,
228
- nprocs=args.num_cpu)
229
-
230
-
231
- if __name__ == "__main__":
232
- main()
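The core of the script above is GruutPhonemizer.phonemize, and its output is one phonemes.npy dict per split. A short sketch of both ends; the sentence and the path are illustrative:

import numpy as np
from AR.text_processing.phonemizer import GruutPhonemizer

phonemizer = GruutPhonemizer(language='en-us')
ipa = phonemizer.phonemize("Printing, in the only sense with which we are concerned.",
                           espeak=False)
print(ipa)

phonemes = np.load('dump/train/phonemes.npy', allow_pickle=True).item()
print(next(iter(phonemes.items())))             # (utt_id, IPA phoneme string)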
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_phones_librilight.py DELETED
@@ -1,198 +0,0 @@
1
- """
2
- 1. read text of dataset; for LibriLight, read txt_*.npy and arrange it into a list of (utt_id, txt)
3
- 2. text -> IPA by GruutPhonemizer
4
- 3. save out a *.npy dict for all text
5
- 4. each LibriLight split is processed separately
6
- my_dict = {"utt_id1": text1, "utt_id2": text2}
7
- np.save(output_filename, my_dict)
8
- my_dict = np.load(output_filename, allow_pickle=True).item()
9
- """
10
- import argparse
11
- import os
12
- import time
13
- import traceback
14
- from concurrent.futures import ThreadPoolExecutor
15
- from operator import itemgetter
16
- from pathlib import Path
17
-
18
- import numpy as np
19
- import tqdm
20
- from AR.text_processing.phonemizer import GruutPhonemizer
21
- from soundstorm.utils import check_txt_file
22
-
23
-
24
- def read_txts(txt_file: Path, nprocs: int=1):
25
- '''
26
- txt_file: path of npy dict, {"utt_id1": text1, "utt_id2": text2}
27
- '''
28
- txt_dict = np.load(txt_file, allow_pickle=True).item()
29
- #[(utt_id, txt), ...]
30
- return_list = list(txt_dict.items())
31
- return return_list
32
-
33
-
34
- def process_sentence(item, phonemizer, output_dir):
35
- utt_id, text = item
36
- phonemes_dir = output_dir / "phonemes"
37
- phonemes_dir.mkdir(parents=True, exist_ok=True)
38
- phonemes_path = phonemes_dir / (utt_id + ".txt")
39
- try:
40
- if os.path.exists(phonemes_path) and check_txt_file(phonemes_path):
41
- # print(phonemes_path, 'exists!')
42
- pass
43
- else:
44
- phonemes = phonemizer.phonemize(text, espeak=False)
45
- with open(phonemes_path, 'w') as f:
46
- f.write(phonemes)
47
- record = {"utt_id": utt_id, "phonemes_path": phonemes_path}
48
- except Exception:
49
- print("occur Exception")
50
- traceback.print_exc()
51
- return None
52
- return record
53
-
54
-
55
- def process_sentences(args, items, phonemizer, output_dir, nprocs: int=1):
56
- print("nprocs:", nprocs)
57
- if nprocs == 1:
58
- results = []
59
- for item in tqdm.tqdm(items, total=len(items)):
60
- record = process_sentence(
61
- item=item, phonemizer=phonemizer, output_dir=output_dir)
62
- if record:
63
- results.append(record)
64
- else:
65
- with ThreadPoolExecutor(nprocs) as pool:
66
- futures = []
67
- with tqdm.tqdm(total=len(items)) as progress:
68
- for item in items:
69
- future = pool.submit(process_sentence, item, phonemizer,
70
- output_dir)
71
- future.add_done_callback(lambda p: progress.update())
72
- futures.append(future)
73
-
74
- results = []
75
- for ft in futures:
76
- record = ft.result()
77
- if record:
78
- results.append(record)
79
-
80
- results.sort(key=itemgetter("utt_id"))
81
-
82
- npy_dict = {}
83
- print(f"start to save {args.rank}_{args.nshard}.npy ...")
84
- save_start_time = time.time()
85
- for item in tqdm.tqdm(results, total=len(results), colour='green'):
86
- # 这里加 try, 因为 txt 文件可能损坏
87
- try:
88
- utt_id = item["utt_id"]
89
- phonemes = check_txt_file(item["phonemes_path"])
90
- if phonemes is not False:
91
- npy_dict[utt_id] = phonemes
92
- else:
93
- print(f'phonemes of {utt_id} is False')
94
- except Exception:
95
- print(f"{utt_id} occur Exception")
96
- traceback.print_exc()
97
- continue
98
-
99
- filename = output_dir / f'phonemes_{args.rank}_{args.nshard}.npy'
100
- np.save(filename, npy_dict)
101
- print(f"npy file '{filename}' write down")
102
- print('time of save stage:', time.time() - save_start_time)
103
-
104
-
105
- def main():
106
- # parse config and args
107
- parser = argparse.ArgumentParser(
108
- description="Get phones for LibriLight dataset from txt_*.npy")
109
-
110
- parser.add_argument(
111
- "--dump_dir",
112
- type=str,
113
- required=True,
114
- help="directory to dump feature files.")
115
- parser.add_argument(
116
- "--num-cpu", type=int, default=1, help="number of process.")
117
-
118
- parser.add_argument(
119
- '--train_txt_dir',
120
- type=str,
121
- default='dump/small/train/',
122
- help='dir of train txt files')
123
- parser.add_argument(
124
- '--dev_txt_dir',
125
- type=str,
126
- default='dump/small/dev/',
127
- help='dir of dev txt files')
128
- parser.add_argument(
129
- '--test_txt_dir',
130
- type=str,
131
- default='dump/small/test/',
132
- help='dir of test txt files')
133
-
134
- parser.add_argument(
135
- "--sub_dataset",
136
- default="small",
137
- type=str,
138
- help="name of sub dataset of LibriLight",
139
- choices=['small', 'medium', 'large', 'duplicate'], )
140
- parser.add_argument("--nshard", type=int, default=3)
141
- parser.add_argument("--rank", type=int, default=0)
142
-
143
- args = parser.parse_args()
144
- print(f"nshard: {args.nshard}, rank: {args.rank}")
145
-
146
- train_txt_dir = Path(args.train_txt_dir)
147
- dev_txt_dir = Path(args.dev_txt_dir)
148
- test_txt_dir = Path(args.test_txt_dir)
149
-
150
- dump_dir = Path(args.dump_dir).expanduser()
151
- # use absolute path
152
- dump_dir = dump_dir.resolve()
153
- dump_dir.mkdir(parents=True, exist_ok=True)
154
-
155
- train_txt_file = train_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
156
- dev_txt_file = dev_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
157
- test_txt_file = test_txt_dir / f'txt_{args.rank}_{args.nshard}.npy'
158
-
159
- train_txts = read_txts(train_txt_file)
160
- dev_txts = read_txts(dev_txt_file)
161
- test_txts = read_txts(test_txt_file)
162
-
163
- sub_dataset_dump_dir = dump_dir / args.sub_dataset
164
- sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
165
- train_dump_dir = sub_dataset_dump_dir / "train"
166
- train_dump_dir.mkdir(parents=True, exist_ok=True)
167
- dev_dump_dir = sub_dataset_dump_dir / "dev"
168
- dev_dump_dir.mkdir(parents=True, exist_ok=True)
169
- test_dump_dir = sub_dataset_dump_dir / "test"
170
- test_dump_dir.mkdir(parents=True, exist_ok=True)
171
- phonemizer = GruutPhonemizer(language='en-us')
172
-
173
- # process for the 3 sections
174
- if train_txts:
175
- process_sentences(
176
- args=args,
177
- items=train_txts,
178
- output_dir=train_dump_dir,
179
- phonemizer=phonemizer,
180
- nprocs=args.num_cpu)
181
- if dev_txts:
182
- process_sentences(
183
- args=args,
184
- items=dev_txts,
185
- output_dir=dev_dump_dir,
186
- phonemizer=phonemizer,
187
- nprocs=args.num_cpu)
188
- if test_txts:
189
- process_sentences(
190
- args=args,
191
- items=test_txts,
192
- output_dir=test_dump_dir,
193
- phonemizer=phonemizer,
194
- nprocs=args.num_cpu)
195
-
196
-
197
- if __name__ == "__main__":
198
- main()
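Because the work is sharded (txt_{rank}_{nshard}.npy in, phonemes_{rank}_{nshard}.npy out), downstream code has to merge the per-rank dicts. A merge sketch; the directory layout is assumed to be <dump_dir>/<sub_dataset>/train:

from pathlib import Path
import numpy as np

nshard = 3
train_dir = Path('dump/small/train')
merged = {}
for rank in range(nshard):
    shard = np.load(train_dir / f'phonemes_{rank}_{nshard}.npy',
                    allow_pickle=True).item()
    merged.update(shard)                        # {utt_id: phoneme string}
print(len(merged), 'utterances merged')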
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/get_txt_librilight.py DELETED
@@ -1,255 +0,0 @@
1
- import argparse
2
- import os
3
- import time
4
- import traceback
5
- from concurrent.futures import ThreadPoolExecutor
6
- from pathlib import Path
7
-
8
- import librosa
9
- import numpy as np
10
- import tqdm
11
- import whisper
12
- from soundstorm.s2.exps.hubert.feature_utils import get_shard_range
13
- from soundstorm.utils import check_txt_file
14
-
15
-
16
- def process_sentence(args,
17
- fp: Path,
18
- train_dump_dir: Path,
19
- dev_dump_dir: Path,
20
- test_dump_dir: Path,
21
- VAD_dict):
22
- asr_model = whisper.load_model("tiny.en")
23
- utt_id = fp.stem
24
- sr = args.sr
25
- record = []
26
- train_txt_dir = train_dump_dir / "txt"
27
- train_txt_dir.mkdir(parents=True, exist_ok=True)
28
-
29
- dev_txt_dir = dev_dump_dir / "txt"
30
- dev_txt_dir.mkdir(parents=True, exist_ok=True)
31
-
32
- test_txt_dir = test_dump_dir / "txt"
33
- test_txt_dir.mkdir(parents=True, exist_ok=True)
34
-
35
- try:
36
- # get info for path
37
- wav_path_list = str(fp).strip().split('/')
38
- sub_dataset, spk_id, book_name = wav_path_list[-4], wav_path_list[
39
- -3], wav_path_list[-2]
40
- wav_name = wav_path_list[-1][:-5]
41
- assert wav_name == utt_id
42
- # key_name for big wav
43
- key_name = f'{wav_name}#{sub_dataset}#{spk_id}#{book_name}'
44
- # handle the case where this audio clip has no entry in the VAD dict
45
- if key_name not in VAD_dict.keys():
46
- print(key_name, 'not in VAD_dict !')
47
- return record
48
- wav = None
49
- sorted_split_VAD_dict = sorted(VAD_dict[key_name].items())
50
- len_dict = len(sorted_split_VAD_dict)
51
- for index, item in enumerate(sorted_split_VAD_dict):
52
- split_name, value = item
53
- start, end = value
54
- # train | dev | test
55
- if index == len_dict - 1:
56
- subset = 'test'
57
- txt_path = test_txt_dir / (split_name + ".txt")
58
- elif index == len_dict - 2:
59
- subset = 'dev'
60
- txt_path = dev_txt_dir / (split_name + ".txt")
61
- else:
62
- subset = 'train'
63
- txt_path = train_txt_dir / (split_name + ".txt")
64
-
65
- if os.path.exists(txt_path) and check_txt_file(txt_path):
66
- # print(txt_path, 'exists!')
67
- pass
68
- else:
69
- # guard so the big wav is loaded only once inside the sub-wav loop
70
- if wav is None:
71
- # load big wav
72
- # loading at the outermost level would waste time when every sub wav's features already exist
73
- wav, _ = librosa.load(str(fp), sr=sr)
74
- sub_wav = wav[int(start * sr):int(end * sr)]
75
- asr_result = asr_model.transcribe(sub_wav)["text"]
76
- with open(txt_path, 'w') as f:
77
- f.write(asr_result)
78
-
79
- sub_record = {
80
- "utt_id": split_name,
81
- "txt_path": txt_path,
82
- "subset": subset
83
- }
84
- # record becomes a List of Dict
85
- record.append(sub_record)
86
- except Exception:
87
- print("occur Exception")
88
- traceback.print_exc()
89
- # record may be an incomplete List
90
- return record
91
- return record
92
-
93
-
94
- def process_sentences(args,
95
- fps: Path,
96
- train_dump_dir: Path,
97
- dev_dump_dir: Path,
98
- test_dump_dir: Path,
99
- VAD_dict,
100
- nprocs: int=1):
101
- print("nprocs:", nprocs)
102
- if nprocs == 1:
103
- results = []
104
- for fp in tqdm.tqdm(fps, total=len(fps)):
105
- record = process_sentence(
106
- args=args,
107
- fp=fp,
108
- train_dump_dir=train_dump_dir,
109
- dev_dump_dir=dev_dump_dir,
110
- test_dump_dir=test_dump_dir,
111
- VAD_dict=VAD_dict)
112
- if record:
113
- results.append(record)
114
- else:
115
- with ThreadPoolExecutor(nprocs) as pool:
116
- futures = []
117
- with tqdm.tqdm(total=len(fps)) as progress:
118
- for fp in fps:
119
- future = pool.submit(process_sentence, args, fp,
120
- train_dump_dir, dev_dump_dir,
121
- test_dump_dir, VAD_dict)
122
- future.add_done_callback(lambda p: progress.update())
123
- futures.append(future)
124
-
125
- results = []
126
- for ft in futures:
127
- record = ft.result()
128
- if record:
129
- results.append(record)
130
-
131
- # torch.save() to a large `.pth` file
132
- txt_dict = dict()
133
- txt_dict['train'] = {}
134
- txt_dict['dev'] = {}
135
- txt_dict['test'] = {}
136
- # record is a List of Dict: one record per big wav, one sub_record per small wav
137
- print(f"start to save {args.rank}_{args.nshard}.npy ...")
138
- save_start_time = time.time()
139
- for record in tqdm.tqdm(results, total=len(results), colour='green'):
140
- for sub_record in record:
141
- # wrap in try, since the txt file may be corrupted
142
- try:
143
- utt_id = sub_record["utt_id"]
144
- subset = sub_record["subset"]
145
- asr_result = check_txt_file(sub_record["txt_path"])
146
- if asr_result is not False:
147
- txt_dict[subset][utt_id] = asr_result
148
- else:
149
- print(f'asr result of {utt_id} is False')
150
- except Exception:
151
- print(f"{utt_id} occur Exception")
152
- traceback.print_exc()
153
- continue
154
-
155
- train_filename = train_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
156
- dev_filename = dev_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
157
- test_filename = test_dump_dir / f'txt_{args.rank}_{args.nshard}.npy'
158
- np.save(train_filename, txt_dict['train'])
159
- print(f"npy file '{train_filename}' write down")
160
-
161
- np.save(dev_filename, txt_dict['dev'])
162
- print(f"npy file '{dev_filename}' write down")
163
-
164
- np.save(test_filename, txt_dict['test'])
165
- print(f"npy file '{test_filename}' write down")
166
- print('time of save stage:', time.time() - save_start_time)
167
-
168
-
169
- def main():
170
- # parse config and args
171
- parser = argparse.ArgumentParser(
172
- description="Preprocess audio and then extract features for LibriLight.")
173
-
174
- parser.add_argument(
175
- "--data_dir", default=None, type=str, help="directory to dataset.")
176
-
177
- parser.add_argument(
178
- "--dump_dir",
179
- type=str,
180
- required=True,
181
- help="directory to dump feature files.")
182
-
183
- parser.add_argument(
184
- "--num-cpu", type=int, default=1, help="number of process.")
185
-
186
- parser.add_argument(
187
- '--sr', type=int, default=16000, help='sample rate of model')
188
-
189
- # For LibriLight dataset
190
- parser.add_argument(
191
- "--sub_dataset",
192
- default="small",
193
- type=str,
194
- help="name of sub dataset of LibriLight",
195
- choices=['small', 'medium', 'large', 'duplicate'], )
196
- parser.add_argument(
197
- "--VAD_path", type=str, default='./VAD/librilight_segment_dict.npy')
198
- parser.add_argument("--nshard", type=int, default=3)
199
- parser.add_argument("--rank", type=int, default=0)
200
-
201
- args = parser.parse_args()
202
-
203
- data_dir = Path(args.data_dir).expanduser()
204
- dump_dir = Path(args.dump_dir).expanduser()
205
- # use absolute path
206
- dump_dir = dump_dir.resolve()
207
- dump_dir.mkdir(parents=True, exist_ok=True)
208
-
209
- assert data_dir.is_dir()
210
-
211
- # sub_dataset here
212
- sub_dataset_dir = data_dir / args.sub_dataset
213
- # only spk_ids in the list, sorted in lexicographical order
214
- speaker_list = sorted(os.listdir(sub_dataset_dir))
215
- start, end = get_shard_range(len(speaker_list), args.nshard, args.rank)
216
- # speaker_list for this rank
217
- speaker_list = speaker_list[start:end]
218
-
219
- all_wav_files = []
220
-
221
- for speaker in speaker_list:
222
- wav_files = sorted(list((sub_dataset_dir / speaker).rglob("*/*.flac")))
223
- # filter out ._*.flac
224
- wav_files = [
225
- file for file in wav_files if not file.name.startswith('._')
226
- ]
227
- all_wav_files += wav_files
228
-
229
- print(f"num of wav files in rank {args.rank}:", len(all_wav_files))
230
- # get VAD info
231
- VAD_dict = np.load(args.VAD_path, allow_pickle=True).item()
232
-
233
- sub_dataset_dump_dir = dump_dir / args.sub_dataset
234
- sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True)
235
- train_dump_dir = sub_dataset_dump_dir / "train"
236
- train_dump_dir.mkdir(parents=True, exist_ok=True)
237
- dev_dump_dir = sub_dataset_dump_dir / "dev"
238
- dev_dump_dir.mkdir(parents=True, exist_ok=True)
239
- test_dump_dir = sub_dataset_dump_dir / "test"
240
- test_dump_dir.mkdir(parents=True, exist_ok=True)
241
-
242
- # each big wav contributes one dev and one test split; the ratio is roughly 96:2:2
243
- if all_wav_files:
244
- process_sentences(
245
- args=args,
246
- fps=all_wav_files,
247
- train_dump_dir=train_dump_dir,
248
- dev_dump_dir=dev_dump_dir,
249
- test_dump_dir=test_dump_dir,
250
- VAD_dict=VAD_dict,
251
- nprocs=args.num_cpu)
252
-
253
-
254
- if __name__ == "__main__":
255
- main()
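A sketch of the VAD dict layout this script expects, inferred from how VAD_dict is indexed above: keys follow the f'{wav_name}#{sub_dataset}#{spk_id}#{book_name}' pattern and map split names to (start, end) times in seconds. The concrete names and times below are invented:

import numpy as np

VAD_dict = {
    # '{wav_name}#{sub_dataset}#{spk_id}#{book_name}'
    'chapter_01#small#100#some_book': {
        # split_name -> (start_sec, end_sec) inside the big wav
        'chapter_01_0': (0.0, 11.2),
        'chapter_01_1': (11.2, 23.9),
        'chapter_01_2': (23.9, 35.0),
    },
}
np.save('librilight_segment_dict.npy', VAD_dict)
loaded = np.load('librilight_segment_dict.npy', allow_pickle=True).item()
assert loaded.keys() == VAD_dict.keys()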
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/split_train_val.py DELETED
@@ -1,35 +0,0 @@
1
- import numpy
2
- import pandas
3
-
4
- semantic_path = 'dump/semantic.tsv'
5
- phoneme_path = 'dump/phoneme.npy'
6
- train_semantic_path = 'dump/semantic_train.tsv'
7
- train_phoneme_path = 'dump/phoneme_train.npy'
8
- dev_semantic_path = 'dump/semantic_dev.tsv'
9
- dev_phoneme_path = 'dump/phoneme_dev.npy'
10
-
11
- # read dump/semantic.tsv
12
- semantic_df = pandas.read_csv(semantic_path, sep='\t')
13
- # pd.DataFrame(columns=["item_name", "semantic_audio"])
14
- # # read dump/phoneme.npy
15
- phoneme_dict = numpy.load(phoneme_path, allow_pickle=True).item()
16
-
17
- dev_num = 20
18
- # randomly sample dev_num rows from semantic_df
19
- dev_df = semantic_df.sample(n=dev_num)
20
- # the rest is train
21
- train_df = semantic_df.drop(dev_df.index)
22
- # save
23
- dev_df.to_csv(dev_semantic_path, sep='\t', index=False)
24
- train_df.to_csv(train_semantic_path, sep='\t', index=False)
25
-
26
- # take the item_name values in dev_df as the keys of dev_phoneme_dict
27
- dev_item_names = dev_df['item_name'].tolist()
28
- dev_phoneme_dict = {k: phoneme_dict[k] for k in dev_item_names if k in phoneme_dict}
29
- train_phoneme_dict = {k: phoneme_dict[k] for k in phoneme_dict.keys() if k not in dev_item_names}
30
-
31
- numpy.save(dev_phoneme_path, dev_phoneme_dict)
32
- numpy.save(train_phoneme_path, train_phoneme_dict)
33
-
34
-
35
-
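A quick consistency check one could run after the split above, reusing the script's own paths and column names (not part of the deleted file):

import numpy as np
import pandas as pd

train_df = pd.read_csv('dump/semantic_train.tsv', sep='\t')
dev_df = pd.read_csv('dump/semantic_dev.tsv', sep='\t')
dev_phoneme_dict = np.load('dump/phoneme_dev.npy', allow_pickle=True).item()

# no item should land in both splits, and every dev phoneme key should be a dev item
assert not set(train_df['item_name']) & set(dev_df['item_name'])
assert set(dev_phoneme_dict) <= set(dev_df['item_name'])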
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/t2s.py DELETED
@@ -1,197 +0,0 @@
1
- # text to semantic
2
- import argparse
3
- import os
4
- import re
5
- import time
6
- from pathlib import Path
7
-
8
- import librosa
9
- import numpy as np
10
- import torch
11
- import whisper
12
- from AR.models.t2s_lightning_module import Text2SemanticLightningModule
13
- from AR.text_processing.phonemizer import GruutPhonemizer
14
- from AR.utils.io import load_yaml_config
15
-
16
-
17
- def get_batch(text, phonemizer):
18
- # phoneme_ids and phoneme_ids_len are what we need
19
- phoneme = phonemizer.phonemize(text, espeak=False)
20
- phoneme_ids = phonemizer.transform(phoneme)
21
- phoneme_ids_len = len(phoneme_ids)
22
- phoneme_ids = np.array(phoneme_ids)
23
- # add batch axis here
24
- phoneme_ids = torch.tensor(phoneme_ids).unsqueeze(0)
25
- phoneme_ids_len = torch.tensor([phoneme_ids_len])
26
- print("phoneme:", phoneme)
27
- batch = {
28
- # torch.Tensor (B, max_phoneme_length)
29
- "phoneme_ids": phoneme_ids,
30
- # torch.Tensor (B)
31
- "phoneme_ids_len": phoneme_ids_len
32
- }
33
- return batch
34
-
35
-
36
- def get_prompt(prompt_wav_path, asr_model, phonemizer, semantic_tokenizer):
37
- sample_rate = 16000
38
- # to get prompt
39
- prompt_name = os.path.basename(prompt_wav_path).split('.')[0]
40
- wav, _ = librosa.load(prompt_wav_path, sr=sample_rate)
41
- # take the last 3 s, excluding the final 0.1 s, to keep AR S1 infer from stopping early
42
- wav = wav[-sample_rate * 3:-int(sample_rate * 0.1)]
43
- # trailing silence must be trimmed from the wav, otherwise inference may also stop early
44
- prompt_text = asr_model.transcribe(wav)["text"]
45
- # remove the trailing period to keep AR S1 infer from stopping early; keeping the period may introduce a pause
46
- prompt_text = prompt_text.replace(".", "")
47
- prompt_phoneme = phonemizer.phonemize(prompt_text, espeak=False)
48
- prompt_phoneme_ids = phonemizer.transform(prompt_phoneme)
49
- prompt_phoneme_ids_len = len(prompt_phoneme_ids)
50
- # get prompt_semantic
51
- # (T) -> (1, T)
52
- wav = torch.tensor(wav).unsqueeze(0)
53
- wav = wav.cuda()
54
- # (1, T)
55
- prompt_semantic_tokens = semantic_tokenizer.tokenize(wav).to(torch.int32)
56
- prompt_phoneme_ids = torch.tensor(prompt_phoneme_ids).unsqueeze(0)
57
- prompt_phoneme_ids_len = torch.tensor([prompt_phoneme_ids_len])
58
-
59
- result = {
60
- 'prompt_name': prompt_name,
61
- 'prompt_phoneme_ids': prompt_phoneme_ids,
62
- 'prompt_semantic_tokens': prompt_semantic_tokens,
63
- 'prompt_phoneme_ids_len': prompt_phoneme_ids_len
64
- }
65
-
66
- return result
67
-
68
-
69
- def parse_args():
70
- # parse args and config
71
- parser = argparse.ArgumentParser(
72
- description="Run SoundStorm AR S1 model for input text file")
73
-
74
- parser.add_argument(
75
- '--config_file',
76
- type=str,
77
- default='conf/default.yaml',
78
- help='path of config file')
79
-
80
- parser.add_argument(
81
- "--text_file",
82
- type=str,
83
- help="text file to be convert to semantic tokens, a 'utt_id sentence' pair per line."
84
- )
85
-
86
- parser.add_argument(
87
- '--ckpt_path',
88
- type=str,
89
- default='exp/default/ckpt/epoch=99-step=49000.ckpt',
90
- help='Checkpoint file of SoundStorm AR S1 model.')
91
-
92
- parser.add_argument(
93
- '--prompt_wav_path',
94
- type=str,
95
- default=None,
96
- help='extract prompt semantic and prompt phonemes from prompt wav')
97
-
98
- # to get semantic tokens from prompt_wav
99
- parser.add_argument("--hubert_path", type=str, default=None)
100
- parser.add_argument("--quantizer_path", type=str, default=None)
101
-
102
- parser.add_argument("--output_dir", type=str, help="output dir.")
103
-
104
- args = parser.parse_args()
105
- return args
106
-
107
-
108
- def main():
109
- args = parse_args()
110
- config = load_yaml_config(args.config_file)
111
-
112
- output_dir = Path(args.output_dir)
113
- output_dir.mkdir(parents=True, exist_ok=True)
114
-
115
- hz = 50
116
- max_sec = config['data']['max_sec']
117
-
118
- # get models
119
- t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
120
- checkpoint_path=args.ckpt_path, config=config)
121
- t2s_model.cuda()
122
- t2s_model.eval()
123
-
124
- phonemizer: GruutPhonemizer = GruutPhonemizer(language='en-us')
125
-
126
- # models for prompt
127
- asr_model = whisper.load_model("tiny.en")
128
-
129
- semantic_tokenizer = SemanticTokenizer(
130
- hubert_path=args.hubert_path,
131
- quantizer_path=args.quantizer_path,
132
- duplicate=True)
133
-
134
- prompt_result = get_prompt(
135
- prompt_wav_path=args.prompt_wav_path,
136
- asr_model=asr_model,
137
- phonemizer=phonemizer,
138
- semantic_tokenizer=semantic_tokenizer)
139
-
140
- # zero prompt => the output semantic tokens have the right content, but the timbre is scrambled
141
- # (B, 1)
142
- # prompt = torch.ones(
143
- # batch['phoneme_ids'].size(0), 1, dtype=torch.int32) * 0
144
-
145
- prompt = prompt_result['prompt_semantic_tokens']
146
- prompt_phoneme_ids_len = prompt_result['prompt_phoneme_ids_len']
147
- prompt_phoneme_ids = prompt_result['prompt_phoneme_ids']
148
-
149
- sentences = []
150
- with open(args.text_file, 'rt', encoding='utf-8') as f:
151
- for line in f:
152
- if line.strip() != "":
153
- items = re.split(r"\s+", line.strip(), 1)
154
- utt_id = items[0]
155
- sentence = " ".join(items[1:])
156
- sentences.append((utt_id, sentence))
157
- semantic_data = [['item_name', 'semantic_audio']]
158
- for utt_id, sentence in sentences[1:]:
159
- # build a pseudo batch by hand to feed the model
160
- batch = get_batch(sentence, phonemizer)
161
- # concatenate the prompt with the real input
162
- all_phoneme_ids = torch.cat(
163
- [prompt_phoneme_ids, batch['phoneme_ids']], dim=1)
164
- # alternatively, just take shape[-1] of all_phoneme_ids
165
- all_phoneme_len = prompt_phoneme_ids_len + batch['phoneme_ids_len']
166
- st = time.time()
167
- with torch.no_grad():
168
- pred_semantic = t2s_model.model.infer(
169
- all_phoneme_ids.cuda(),
170
- all_phoneme_len.cuda(),
171
- prompt.cuda(),
172
- top_k=config['inference']['top_k'],
173
- early_stop_num=hz * max_sec)
174
- print(f'{time.time() - st} sec used in T2S')
175
-
176
- # drop the part corresponding to the prompt
177
- prompt_len = prompt.shape[-1]
178
- pred_semantic = pred_semantic[:, prompt_len:]
179
-
180
- # bs = 1
181
- pred_semantic = pred_semantic[0]
182
- semantic_token = pred_semantic.detach().cpu().numpy().tolist()
183
- semantic_token_str = ' '.join(str(x) for x in semantic_token)
184
- semantic_data.append([utt_id, semantic_token_str])
185
-
186
- delimiter = '\t'
187
- filename = output_dir / f'{utt_id}_p_{prompt_result["prompt_name"]}_semantic_token.tsv'
188
- with open(filename, 'w', encoding='utf-8') as writer:
189
- for row in semantic_data:
190
- line = delimiter.join(row)
191
- writer.write(line + '\n')
192
- # reset semantic_data for the next sentence
193
- semantic_data = [['item_name', 'semantic_audio']]
194
-
195
-
196
- if __name__ == "__main__":
197
- main()
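The --text_file argument above expects one 'utt_id sentence' pair per line, the same layout as the deleted AR/exps/text.txt further down; note that the loop iterates over sentences[1:], so the first line of the file is skipped. A minimal way to produce such a file (the file name and the placeholder first line are made up):

lines = [
    '000 placeholder, skipped because the loop starts at sentences[1:]',
    '001 Life was like a box of chocolates, you never know what you are gonna get.',
    '002 With great power there must come great responsibility.',
]
with open('sentences.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(lines) + '\n')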
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/test.py DELETED
@@ -1,139 +0,0 @@
1
- # test from dump file
2
- import argparse
3
- import time
4
- from pathlib import Path
5
-
6
- import numpy as np
7
- import torch
8
- from AR.data.dataset import Text2SemanticDataset
9
- from AR.models.t2s_lightning_module import Text2SemanticLightningModule
10
- from AR.utils.io import load_yaml_config
11
- from torch.utils.data import DataLoader
12
-
13
-
14
- def parse_args():
15
- # parse args and config
16
- parser = argparse.ArgumentParser(
17
- description="Run SoundStorm AR S1 model for test set.")
18
-
19
- parser.add_argument(
20
- '--config_file',
21
- type=str,
22
- default='conf/default.yaml',
23
- help='path of config file')
24
-
25
- # args for dataset
26
- parser.add_argument(
27
- '--test_semantic_path',
28
- type=str,
29
- default='dump/test/semantic_token.tsv')
30
- parser.add_argument(
31
- '--test_phoneme_path', type=str, default='dump/test/phonemes.npy')
32
-
33
- parser.add_argument(
34
- '--ckpt_path',
35
- type=str,
36
- default='exp/default/ckpt/epoch=99-step=49000.ckpt',
37
- help='Checkpoint file of SoundStorm AR S1 model.')
38
-
39
- parser.add_argument("--output_dir", type=str, help="output dir.")
40
-
41
- args = parser.parse_args()
42
- return args
43
-
44
-
45
- def main():
46
- args = parse_args()
47
-
48
- config = load_yaml_config(args.config_file)
49
-
50
- output_dir = Path(args.output_dir)
51
- output_dir.mkdir(parents=True, exist_ok=True)
52
-
53
- batch_size = 1
54
- hz = 50
55
- max_sec = config['data']['max_sec']
56
-
57
- # get dataset
58
- test_dataset = Text2SemanticDataset(
59
- phoneme_path=args.test_phoneme_path,
60
- semantic_path=args.test_semantic_path,
61
- # max_sec should match the training setting, otherwise quality may drop (repetitions, dropped words, etc.)
62
- # but setting it too short here filters out long samples; to avoid that, truncate at infer time instead
63
- max_sec=100,
64
- max_sample=8,
65
- pad_val=config['data']['pad_val'])
66
- # get model
67
- t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
68
- checkpoint_path=args.ckpt_path, config=config)
69
- t2s_model.cuda()
70
- t2s_model.eval()
71
-
72
- # fetch batch_size items
73
- # build the DataLoader and pass the collate_fn
74
- dataloader = DataLoader(
75
- test_dataset,
76
- batch_size=batch_size,
77
- shuffle=False,
78
- collate_fn=test_dataset.collate)
79
-
80
- item_names = test_dataset.__get_item_names__()
81
-
82
- # read the data batch by batch; with bs=1 and shuffle=False the batches line up with __get_item_names__()
83
- semantic_data = [['item_name', 'semantic_audio']]
84
- for i, batch in enumerate(dataloader):
85
- # requires bs = 1
86
- utt_id = item_names[i]
87
- if i == 0:
88
- print("utt_id:", utt_id)
89
- # zero-padding would be added when bs > 1
90
- # keep consistent with validation_step()
91
- semantic_len = batch['semantic_ids'].size(1)
92
- # use the first 150 tokens of batch['semantic_ids'] as the prompt
93
- # across multiple syntheses, the first prompt_len tokens are identical and equal to the prompt
94
- prompt_len = min(int(semantic_len * 0.5), 150)
95
- # what should the prompt be for plain-text input? => see t2s.py
96
- prompt = batch['semantic_ids'][:, :prompt_len]
97
- # # zero prompt => still produces semantic tokens with the right text content, but the timbre is scrambled
98
- # which shows the semantic tokens still carry timbre information
99
- # prompt = torch.ones(
100
- # batch['semantic_ids'].size(0), 1, dtype=torch.int32) * 0
101
- # print("prompt:", prompt)
102
- # print("prompt.shape:", prompt.shape)
103
- np.save(output_dir / 'prompt.npy', prompt.detach().cpu().numpy())
104
-
105
- st = time.time()
106
- with torch.no_grad():
107
- # calculate acc for test
108
- loss, acc = t2s_model.model.forward(
109
- batch['phoneme_ids'].cuda(),
110
- batch['phoneme_ids_len'].cuda(),
111
- batch['semantic_ids'].cuda(),
112
- batch['semantic_ids_len'].cuda())
113
- print("top_3_acc of this batch:", acc)
114
- pred_semantic = t2s_model.model.infer(
115
- batch['phoneme_ids'].cuda(),
116
- batch['phoneme_ids_len'].cuda(),
117
- prompt.cuda(),
118
- top_k=config['inference']['top_k'],
119
- # hz * max_sec in train dataloader
120
- # the generated length is 1002, so there should be some padding
121
- early_stop_num=hz * max_sec)
122
- # bs = 1
123
- pred_semantic = pred_semantic[0]
124
- print(f'{time.time() - st} sec used in T2S')
125
- semantic_token = pred_semantic.detach().cpu().numpy().tolist()
126
- semantic_token_str = ' '.join(str(x) for x in semantic_token)
127
- semantic_data.append([utt_id, semantic_token_str])
128
- else:
129
- break
130
- delimiter = '\t'
131
- filename = output_dir / "semantic_token.tsv"
132
- with open(filename, 'w', encoding='utf-8') as writer:
133
- for row in semantic_data:
134
- line = delimiter.join(row)
135
- writer.write(line + '\n')
136
-
137
-
138
- if __name__ == "__main__":
139
- main()
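The semantic_token.tsv written above is tab separated, with an 'item_name'/'semantic_audio' header row and space-joined token ids. Reading it back (the output path is hypothetical):

import csv

with open('output/semantic_token.tsv', newline='', encoding='utf-8') as f:
    rows = list(csv.reader(f, delimiter='\t'))

header, data = rows[0], rows[1:]
semantic = {utt_id: [int(t) for t in tok_str.split()] for utt_id, tok_str in data}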
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/text.txt DELETED
@@ -1,10 +0,0 @@
1
- 001 Life was like a box of chocolates, you never know what you're gonna get.
2
- 002 With great power there must come great responsibility.
3
- 003 To be or not to be, that’s a question.
4
- 004 A man can be destroyed but not defeated
5
- 005 Do not, for one repulse, give up the purpose that you resolved to effort.
6
- 006 Death is just a part of life, something we're all destined to do.
7
- 007 I think it's hard winning a war with words.
8
- 008 Don’t argue with the people of strong determination, because they may change the fact!
9
- 009 Love you three thousand times.
10
- 010 tidy tiger tied a tie tighter to tidy her tiny tall.
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/train.py DELETED
@@ -1,103 +0,0 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
2
- import argparse
3
- import logging
4
- import os
5
- from pathlib import Path
6
-
7
- import torch
8
- from pytorch_lightning import seed_everything
9
- from pytorch_lightning import Trainer
10
- from pytorch_lightning.callbacks import ModelCheckpoint
11
- from pytorch_lightning.loggers import WandbLogger
12
- from pytorch_lightning.strategies import DDPStrategy
13
- from AR.data.data_module import Text2SemanticDataModule
14
- from AR.models.t2s_lightning_module import Text2SemanticLightningModule
15
- from soundstorm.utils.io import load_yaml_config
16
- logging.getLogger('numba').setLevel(logging.WARNING)
17
- logging.getLogger('matplotlib').setLevel(logging.WARNING)
18
- torch.set_float32_matmul_precision('high')
19
- from soundstorm.utils import get_newest_ckpt
20
-
21
-
22
- def main(args):
23
- output_dir = Path(args.output_dir)
24
- output_dir.mkdir(parents=True, exist_ok=True)
25
-
26
- ckpt_dir = output_dir / 'ckpt'
27
- ckpt_dir.mkdir(parents=True, exist_ok=True)
28
-
29
- config = load_yaml_config(args.config_file)
30
-
31
- seed_everything(config["train"]["seed"], workers=True)
32
- ckpt_callback: ModelCheckpoint = ModelCheckpoint(
33
- save_top_k=-1,
34
- save_on_train_epoch_end=False,
35
- every_n_epochs=config["train"]["save_every_n_epoch"],
36
- dirpath=ckpt_dir)
37
- logger = WandbLogger(
38
- project="AR_S1",
39
- name=output_dir.stem,
40
- save_dir=output_dir,
41
- # resume the loss curve
42
- resume=True,
43
- # id='k19kvsq8'
44
- )
45
- trainer: Trainer = Trainer(
46
- max_epochs=config["train"]["epochs"],
47
- accelerator='gpu',
48
- devices=-1,
49
- benchmark=False,
50
- fast_dev_run=False,
51
- strategy=DDPStrategy(find_unused_parameters=True),
52
- precision=config["train"]["precision"],
53
- logger=logger,
54
- callbacks=[ckpt_callback])
55
-
56
- model: Text2SemanticLightningModule = Text2SemanticLightningModule(
57
- config, output_dir)
58
-
59
- data_module: Text2SemanticDataModule = Text2SemanticDataModule(
60
- config,
61
- train_semantic_path=args.train_semantic_path,
62
- train_phoneme_path=args.train_phoneme_path,
63
- dev_semantic_path=args.dev_semantic_path,
64
- dev_phoneme_path=args.dev_phoneme_path)
65
-
66
- try:
67
- # use a regex to match the numeric part of the filenames and sort by the number
68
- newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
69
- ckpt_path = ckpt_dir / newest_ckpt_name
70
- except Exception:
71
- ckpt_path = None
72
- print("ckpt_path:", ckpt_path)
73
- trainer.fit(model, data_module, ckpt_path=ckpt_path)
74
-
75
-
76
- # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
77
- if __name__ == '__main__':
78
- parser = argparse.ArgumentParser()
79
- parser.add_argument(
80
- '--config_file',
81
- type=str,
82
- default='conf/default.yaml',
83
- help='path of config file')
84
- # args for dataset
85
- parser.add_argument(
86
- '--train_semantic_path',
87
- type=str,
88
- default='dump/train/semantic_token.tsv')
89
- parser.add_argument(
90
- '--train_phoneme_path', type=str, default='dump/train/phonemes.npy')
91
- parser.add_argument(
92
- '--dev_semantic_path', type=str, default='dump/dev/semantic_token.tsv')
93
- parser.add_argument(
94
- '--dev_phoneme_path', type=str, default='dump/dev/phonemes.npy')
95
- parser.add_argument(
96
- '--output_dir',
97
- type=str,
98
- default='exp/default',
99
- help='directory to save the results')
100
-
101
- args = parser.parse_args()
102
- logging.info(str(args))
103
- main(args)
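get_newest_ckpt comes from soundstorm.utils and is not shown in this diff; per the comment above it matches the numbers in the checkpoint filenames and sorts by them. Also note the srun example comment still says --path-to-configuration, while the parser only defines --config_file. A hypothetical stand-in with the intended behaviour:

import re

def newest_ckpt(names):
    # pick the file whose last embedded number (e.g. the step count) is largest
    def last_number(name):
        nums = re.findall(r'\d+', name)
        return int(nums[-1]) if nums else -1
    return max(names, key=last_number)

assert newest_ckpt(['epoch=9-step=4900.ckpt', 'epoch=99-step=49000.ckpt']) == 'epoch=99-step=49000.ckpt'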
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/exps/train_librilight_6k.py DELETED
@@ -1,170 +0,0 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
2
- import argparse
3
- import logging
4
- import os
5
- from pathlib import Path
6
-
7
- import torch
8
- from pytorch_lightning import seed_everything
9
- from pytorch_lightning import Trainer
10
- from pytorch_lightning.callbacks import ModelCheckpoint
11
- from pytorch_lightning.loggers import WandbLogger
12
- from pytorch_lightning.strategies import DDPStrategy
13
- from AR.data.data_module_librilight_6k import Text2SemanticDataModule
14
- from AR.models.t2s_lightning_module import Text2SemanticLightningModule
15
- from soundstorm.utils import get_newest_ckpt
16
- from soundstorm.utils.io import load_yaml_config
17
-
18
- logging.getLogger('numba').setLevel(logging.WARNING)
19
- logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
- torch.set_float32_matmul_precision('high')
21
-
22
-
23
- def main(args):
24
- output_dir = Path(args.output_dir)
25
- output_dir.mkdir(parents=True, exist_ok=True)
26
-
27
- ckpt_dir = output_dir / 'ckpt'
28
- ckpt_dir.mkdir(parents=True, exist_ok=True)
29
-
30
- config = load_yaml_config(args.config_file)
31
-
32
- seed_everything(config["train"]["seed"], workers=True)
33
-
34
- ckpt_callback: ModelCheckpoint = ModelCheckpoint(
35
- save_top_k=-1,
36
- save_on_train_epoch_end=False,
37
- every_n_train_steps=config["train"]["every_n_train_steps"],
38
- dirpath=ckpt_dir)
39
- logger = WandbLogger(
40
- project="AR_S1_LibriLight",
41
- name=output_dir.stem,
42
- save_dir=output_dir,
43
- # resume the loss curve
44
- resume=True,
45
- # id='k19kvsq8'
46
- )
47
- trainer: Trainer = Trainer(
48
- max_epochs=config["train"]["epochs"],
49
- accelerator='gpu',
50
- devices=-1,
51
- benchmark=False,
52
- fast_dev_run=False,
53
- strategy=DDPStrategy(find_unused_parameters=True),
54
- precision=config["train"]["precision"],
55
- logger=logger,
56
- callbacks=[ckpt_callback])
57
-
58
- model: Text2SemanticLightningModule = Text2SemanticLightningModule(
59
- config, output_dir)
60
-
61
- data_module: Text2SemanticDataModule = Text2SemanticDataModule(
62
- config,
63
- train_semantic_dirs=args.train_semantic_dirs,
64
- train_phoneme_dirs=args.train_phoneme_dirs,
65
- dev_semantic_dirs=args.dev_semantic_dirs,
66
- dev_phoneme_dirs=args.dev_phoneme_dirs,
67
- train_non_speech_dirs=args.train_non_speech_dirs,
68
- dev_non_speech_dirs=args.dev_non_speech_dirs)
69
- try:
70
- newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
71
- ckpt_path = ckpt_dir / newest_ckpt_name
72
- except Exception:
73
- ckpt_path = None
74
-
75
- print("ckpt_path:", ckpt_path)
76
- trainer.fit(model, data_module, ckpt_path=ckpt_path)
77
-
78
-
79
- # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
80
- if __name__ == '__main__':
81
- parser = argparse.ArgumentParser()
82
- parser.add_argument(
83
- '--config_file',
84
- type=str,
85
- default='conf/default.yaml',
86
- help='path of config file')
87
- # args for dataset
88
- parser.add_argument(
89
- '--train_semantic_dirs',
90
- type=list,
91
- nargs='+',
92
- default=["dump/small/train/"],
93
- help='dirs of train semantic')
94
- parser.add_argument(
95
- '--train_phoneme_dirs',
96
- type=list,
97
- nargs='+',
98
- default=["dump/small/train/"],
99
- help='dirs of train phoneme')
100
- parser.add_argument(
101
- '--dev_semantic_dirs',
102
- type=list,
103
- nargs='+',
104
- default=["dump/small/dev/"],
105
- help='dirs of dev semantic')
106
- parser.add_argument(
107
- '--dev_phoneme_dirs',
108
- type=list,
109
- nargs='+',
110
- default=["dump/small/dev/"],
111
- help='dirs of dev phoneme')
112
- parser.add_argument(
113
- '--output_dir',
114
- type=str,
115
- default='exp/default',
116
- help='directory to save the results')
117
-
118
- parser.add_argument(
119
- '--train_non_speech_dirs',
120
- type=list,
121
- nargs='+',
122
- default=None,
123
- help='dirs of train non_speech data')
124
-
125
- parser.add_argument(
126
- '--dev_non_speech_dirs',
127
- type=list,
128
- nargs='+',
129
- default=None,
130
- help='dirs of dev non_speech data')
131
-
132
- args = parser.parse_args()
133
-
134
- new_train_semantic_dirs = []
135
- new_train_phoneme_dirs = []
136
- new_dev_semantic_dirs = []
137
- new_dev_phoneme_dirs = []
138
-
139
- new_train_non_speech_dirs = []
140
- new_dev_non_speech_dirs = []
141
-
142
- # format dataset dirs
143
- for item in args.train_semantic_dirs:
144
- new_train_semantic_dirs.append(''.join(item))
145
- args.train_semantic_dirs = new_train_semantic_dirs
146
-
147
- for item in args.train_phoneme_dirs:
148
- new_train_phoneme_dirs.append(''.join(item))
149
- args.train_phoneme_dirs = new_train_phoneme_dirs
150
-
151
- for item in args.dev_semantic_dirs:
152
- new_dev_semantic_dirs.append(''.join(item))
153
- args.dev_semantic_dirs = new_dev_semantic_dirs
154
-
155
- for item in args.dev_phoneme_dirs:
156
- new_dev_phoneme_dirs.append(''.join(item))
157
- args.dev_phoneme_dirs = new_dev_phoneme_dirs
158
-
159
- if args.train_non_speech_dirs is not None:
160
- for item in args.train_non_speech_dirs:
161
- new_train_non_speech_dirs.append(''.join(item))
162
- args.train_non_speech_dirs = new_train_non_speech_dirs
163
-
164
- if args.dev_non_speech_dirs is not None:
165
- for item in args.dev_non_speech_dirs:
166
- new_dev_non_speech_dirs.append(''.join(item))
167
- args.dev_non_speech_dirs = new_dev_non_speech_dirs
168
-
169
- logging.info(str(args))
170
- main(args)
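The ''.join(item) loops above are needed because the directory arguments are declared with type=list and nargs='+': argparse applies list() to every token, so each path arrives as a list of single characters. A small demonstration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train_semantic_dirs', type=list, nargs='+', default=['dump/small/train/'])
args = parser.parse_args(['--train_semantic_dirs', 'dump/small/train/', 'dump/medium/train/'])

print(args.train_semantic_dirs[0])            # ['d', 'u', 'm', 'p', '/', 's', ...]
print(''.join(args.train_semantic_dirs[0]))   # 'dump/small/train/'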
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/__init__.py DELETED
File without changes
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/t2s_lightning_module.py DELETED
@@ -1,128 +0,0 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
2
- import os,sys
3
- now_dir = os.getcwd()
4
- sys.path.append(now_dir)
5
- from typing import Dict
6
-
7
- import torch
8
- from pytorch_lightning import LightningModule
9
- from AR.models.t2s_model import Text2SemanticDecoder
10
- from AR.modules.lr_schedulers import WarmupCosineLRSchedule
11
- from AR.modules.optim import ScaledAdam
12
-
13
-
14
- class Text2SemanticLightningModule(LightningModule):
15
- def __init__(self, config, output_dir,is_train=True):
16
- super().__init__()
17
- self.config = config
18
- self.top_k = 3
19
- self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
20
- pretrained_s1=config.get("pretrained_s1")
21
- if(pretrained_s1 and is_train):
22
- # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
23
- print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["weight"]))
24
- if is_train:
25
- self.automatic_optimization = False
26
- self.save_hyperparameters()
27
- self.eval_dir = output_dir / 'eval'
28
- self.eval_dir.mkdir(parents=True, exist_ok=True)
29
-
30
- def training_step(self, batch: Dict, batch_idx: int):
31
-
32
- opt = self.optimizers()
33
- scheduler = self.lr_schedulers()
34
- loss, acc = self.model.forward(
35
- batch['phoneme_ids'], batch['phoneme_ids_len'],
36
- batch['semantic_ids'], batch['semantic_ids_len'],
37
- batch['bert_feature'])
38
- self.manual_backward(loss)
39
- if batch_idx > 0 and batch_idx % 4 == 0:
40
- opt.step()
41
- opt.zero_grad()
42
- scheduler.step()
43
-
44
- self.log(
45
- "total_loss",
46
- loss,
47
- on_step=True,
48
- on_epoch=True,
49
- prog_bar=True,
50
- sync_dist=True)
51
- self.log(
52
- "lr",
53
- scheduler.get_last_lr()[0],
54
- on_epoch=True,
55
- prog_bar=True,
56
- sync_dist=True)
57
- self.log(
58
- f"top_{self.top_k}_acc",
59
- acc,
60
- on_step=True,
61
- on_epoch=True,
62
- prog_bar=True,
63
- sync_dist=True)
64
-
65
- def validation_step(self, batch: Dict, batch_idx: int):return
66
- # # get loss
67
- # loss, acc = self.model.forward(
68
- # batch['phoneme_ids'], batch['phoneme_ids_len'],
69
- # batch['semantic_ids'], batch['semantic_ids_len'],
70
- # batch['bert_feature']
71
- # )
72
- #
73
- # self.log(
74
- # "val_total_loss",
75
- # loss,
76
- # on_step=True,
77
- # on_epoch=True,
78
- # prog_bar=True,
79
- # sync_dist=True)
80
- # self.log(
81
- # f"val_top_{self.top_k}_acc",
82
- # acc,
83
- # on_step=True,
84
- # on_epoch=True,
85
- # prog_bar=True,
86
- # sync_dist=True)
87
- #
88
- # # get infer output
89
- # semantic_len = batch['semantic_ids'].size(1)
90
- # prompt_len = min(int(semantic_len * 0.5), 150)
91
- # prompt = batch['semantic_ids'][:, :prompt_len]
92
- # pred_semantic = self.model.infer(batch['phoneme_ids'],
93
- # batch['phoneme_ids_len'], prompt,
94
- # batch['bert_feature']
95
- # )
96
- # save_name = f'semantic_toks_{batch_idx}.pt'
97
- # save_path = os.path.join(self.eval_dir, save_name)
98
- # torch.save(pred_semantic.detach().cpu(), save_path)
99
-
100
- def configure_optimizers(self):
101
- model_parameters = self.model.parameters()
102
- parameters_names = []
103
- parameters_names.append([
104
- name_param_pair[0]
105
- for name_param_pair in self.model.named_parameters()
106
- ])
107
- lm_opt = ScaledAdam(
108
- model_parameters,
109
- lr=0.01,
110
- betas=(0.9, 0.95),
111
- clipping_scale=2.0,
112
- parameters_names=parameters_names,
113
- show_dominant_parameters=False,
114
- clipping_update_period=1000, )
115
-
116
- return {
117
- "optimizer": lm_opt,
118
- "lr_scheduler": {
119
- "scheduler":
120
- WarmupCosineLRSchedule(
121
- lm_opt,
122
- init_lr=self.config['optimizer']['lr_init'],
123
- peak_lr=self.config['optimizer']['lr'],
124
- end_lr=self.config['optimizer']['lr_end'],
125
- warmup_steps=self.config['optimizer']['warmup_steps'],
126
- total_steps=self.config['optimizer']['decay_steps'])
127
- }
128
- }
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/t2s_model.py DELETED
@@ -1,298 +0,0 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
2
- import torch
3
- from tqdm import tqdm
4
-
5
- from AR.models.utils import make_pad_mask
6
- from AR.models.utils import topk_sampling,sample,logits_to_probs,multinomial_sample_one_no_sync
7
- from AR.modules.embedding import SinePositionalEmbedding
8
- from AR.modules.embedding import TokenEmbedding
9
- from AR.modules.transformer import LayerNorm
10
- from AR.modules.transformer import TransformerEncoder
11
- from AR.modules.transformer import TransformerEncoderLayer
12
- from torch import nn
13
- from torch.nn import functional as F
14
- from torchmetrics.classification import MulticlassAccuracy
15
-
16
- default_config = {
17
- "embedding_dim": 512,
18
- "hidden_dim": 512,
19
- "num_head": 8,
20
- "num_layers": 12,
21
- "num_codebook": 8,
22
- "p_dropout": 0.0,
23
- "vocab_size": 1024 + 1,
24
- "phoneme_vocab_size": 512,
25
- "EOS": 1024
26
- }
27
-
28
-
29
- class Text2SemanticDecoder(nn.Module):
30
- def __init__(self, config, norm_first=False, top_k=3):
31
- super(Text2SemanticDecoder, self).__init__()
32
- self.model_dim = config['model']["hidden_dim"]
33
- self.embedding_dim = config['model']["embedding_dim"]
34
- self.num_head = config['model']["head"]
35
- self.num_layers = config['model']["n_layer"]
36
- self.norm_first = norm_first
37
- self.vocab_size = config['model']["vocab_size"]
38
- self.phoneme_vocab_size = config['model']["phoneme_vocab_size"]
39
- self.p_dropout = config['model']["dropout"]
40
- self.EOS = config['model']["EOS"]
41
- self.norm_first = norm_first
42
- assert self.EOS == self.vocab_size - 1
43
- # should be same as num of kmeans bin
44
- # assert self.EOS == 1024
45
- self.bert_proj = nn.Linear(1024, self.embedding_dim)
46
- self.ar_text_embedding = TokenEmbedding(
47
- self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
48
- self.ar_text_position = SinePositionalEmbedding(
49
- self.embedding_dim, dropout=0.1, scale=False, alpha=True)
50
- self.ar_audio_embedding = TokenEmbedding(
51
- self.embedding_dim, self.vocab_size, self.p_dropout)
52
- self.ar_audio_position = SinePositionalEmbedding(
53
- self.embedding_dim, dropout=0.1, scale=False, alpha=True)
54
-
55
- self.h = TransformerEncoder(
56
- TransformerEncoderLayer(
57
- d_model=self.model_dim,
58
- nhead=self.num_head,
59
- dim_feedforward=self.model_dim * 4,
60
- dropout=0.1,
61
- batch_first=True,
62
- norm_first=norm_first, ),
63
- num_layers=self.num_layers,
64
- norm=LayerNorm(self.model_dim) if norm_first else None, )
65
-
66
- self.ar_predict_layer = nn.Linear(
67
- self.model_dim, self.vocab_size, bias=False)
68
- self.loss_fct = nn.CrossEntropyLoss(reduction='sum')
69
-
70
- self.ar_accuracy_metric = MulticlassAccuracy(
71
- self.vocab_size,
72
- top_k=top_k,
73
- average="micro",
74
- multidim_average="global",
75
- ignore_index=self.EOS, )
76
-
77
- def forward(self, x, x_lens, y, y_lens, bert_feature):
78
- '''
79
- x: phoneme_ids
80
- y: semantic_ids
81
- '''
82
- x = self.ar_text_embedding(x)
83
- x = x + self.bert_proj(bert_feature.transpose(1,2))
84
- x = self.ar_text_position(x)
85
- x_mask = make_pad_mask(x_lens)
86
-
87
- y_mask = make_pad_mask(y_lens)
88
- y_mask_int = y_mask.type(torch.int64)
89
- codes = y.type(torch.int64) * (1 - y_mask_int)
90
-
91
- # Training
92
- # AR Decoder
93
- y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
94
- x_len = x_lens.max()
95
- y_len = y_lens.max()
96
- y_emb = self.ar_audio_embedding(y)
97
- y_pos = self.ar_audio_position(y_emb)
98
-
99
- xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
100
- ar_xy_padding_mask = xy_padding_mask
101
-
102
- x_attn_mask = F.pad(
103
- torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
104
- (0, y_len),
105
- value=True, )
106
- y_attn_mask = F.pad(
107
- torch.triu(
108
- torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
109
- diagonal=1, ),
110
- (x_len, 0),
111
- value=False, )
112
- xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
113
- bsz, src_len = x.shape[0], x_len + y_len
114
- _xy_padding_mask = (ar_xy_padding_mask.view(bsz, 1, 1, src_len)
115
- .expand(-1, self.num_head, -1, -1)
116
- .reshape(bsz * self.num_head, 1, src_len))
117
- xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
118
- new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
119
- new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
120
- xy_attn_mask = new_attn_mask
121
- # x and the full y are fed to the model in a single pass
122
- xy_pos = torch.concat([x, y_pos], dim=1)
123
- xy_dec, _ = self.h(
124
- (xy_pos, None),
125
- mask=xy_attn_mask, )
126
- logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
127
- # loss
128
- # from feiteng: the more frames (duration) per step, the larger the gradient update should be, hence reduction='sum'
129
- loss = F.cross_entropy(logits, targets, reduction='sum')
130
- acc = self.ar_accuracy_metric(logits.detach(), targets).item()
131
- return loss, acc
132
-
133
- # TODO: check how this differs from forward, and what prompts should be when no semantic tokens are available
134
- def infer(self,
135
- x,
136
- x_lens,
137
- prompts,
138
- bert_feature,
139
- top_k: int=-100,
140
- early_stop_num: int=-1,
141
- temperature: float=1.0):
142
-
143
- x = self.ar_text_embedding(x)
144
- x = x + self.bert_proj(bert_feature.transpose(1,2))
145
- x = self.ar_text_position(x)
146
-
147
- # AR Decoder
148
- y = prompts
149
- prefix_len = y.shape[1]
150
- x_len = x.shape[1]
151
- x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
152
- stop = False
153
- for _ in tqdm(range(1500)):
154
- y_emb = self.ar_audio_embedding(y)
155
- y_pos = self.ar_audio_position(y_emb)
156
- # feed x together with the growing y into the model
157
- xy_pos = torch.concat([x, y_pos], dim=1)
158
- y_len = y.shape[1]
159
- x_attn_mask_pad = F.pad(
160
- x_attn_mask,
161
- (0, y_len),
162
- value=True, )
163
- y_attn_mask = F.pad(
164
- torch.triu(
165
- torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
166
- (x_len, 0),
167
- value=False, )
168
- xy_attn_mask = torch.concat(
169
- [x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
170
-
171
- xy_dec, _ = self.h(
172
- (xy_pos, None),
173
- mask=xy_attn_mask, )
174
- logits = self.ar_predict_layer(xy_dec[:, -1])
175
- samples = topk_sampling(
176
- logits, top_k=top_k, top_p=1.0, temperature=temperature)
177
-
178
- if early_stop_num != -1 and (y.shape[1] - prefix_len
179
- ) > early_stop_num:
180
- print("use early stop num:", early_stop_num)
181
- stop = True
182
-
183
- if torch.argmax(
184
- logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
185
- # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
186
- stop = True
187
- if stop:
188
- if prompts.shape[1] == y.shape[1]:
189
- y = torch.concat([y, torch.zeros_like(samples)], dim=1)
190
- print('bad zero prediction')
191
- print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
192
- break
193
- # the newly generated semantic_ids are appended to the previous y to form the new y
194
- # print(samples.shape)  # [1, 1]; the first 1 is the batch size
195
- # import os
196
- # os._exit(2333)
197
- y = torch.concat([y, samples], dim=1)
198
- return y
199
-
200
- def pad_y_eos(self, y, y_mask_int, eos_id):
201
- targets = F.pad(
202
- y, (0, 1), value=0) + eos_id * F.pad(
203
- y_mask_int, (0, 1), value=1)
204
- # shift by one position
205
- return targets[:, :-1], targets[:, 1:]
206
-
207
- def infer_panel(self,
208
- x,  ##### all text tokens
209
- x_lens,
210
- prompts,  #### reference audio tokens
211
- bert_feature,
212
- top_k: int=-100,
213
- early_stop_num: int=-1,
214
- temperature: float=1.0):
215
-
216
- x = self.ar_text_embedding(x)
217
- x = x + self.bert_proj(bert_feature.transpose(1,2))
218
- x = self.ar_text_position(x)
219
-
220
- # AR Decoder
221
- y = prompts
222
- prefix_len = y.shape[1]
223
- x_len = x.shape[1]
224
- x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
225
- stop = False
226
- # print(1111111,self.num_layers)
227
- cache={
228
- "all_stage":self.num_layers,
229
- "k":[None]*self.num_layers,###根据配置自己手写
230
- "v":[None]*self.num_layers,
231
- # "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
232
- "y_emb":None,##只需要对最新的samples求emb,再拼历史的就行
233
- # "logits":None,###原版就已经只对结尾求再拼接了,不用管
234
- # "xy_dec":None,###不需要,本来只需要最后一个做logits
235
- "first_infer":1,
236
- "stage":0
237
- }
238
- for idx in tqdm(range(1500)):
239
- if(cache["first_infer"]==1):
240
- y_emb = self.ar_audio_embedding(y)
241
- else:
242
- y_emb = torch.cat([cache["y_emb"],self.ar_audio_embedding(y[:,-1:])],1)
243
- cache["y_emb"]=y_emb
244
- y_pos = self.ar_audio_position(y_emb)
245
- # feed x together with the growing y into the model
246
- if(cache["first_infer"]==1):
247
- xy_pos = torch.concat([x, y_pos], dim=1)
248
- else:
249
- xy_pos=y_pos[:,-1:]
250
- y_len = y_pos.shape[1]
251
- ### the following three are not cached
252
- if (cache["first_infer"] == 1):
253
- x_attn_mask_pad = F.pad(
254
- x_attn_mask,
255
- (0, y_len),  ### extend the all-zero xx block into all-zero xx plus all-one xy; shape (x, x+y)
256
- value=True, )
257
- y_attn_mask = F.pad(  ### extend yy's upper-right ones with zeros for xy on the left; shape (y, x+y)
258
- torch.triu(
259
- torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
260
- (x_len, 0),
261
- value=False, )
262
- xy_attn_mask = torch.concat(
263
- [x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
264
- else:
265
- ### only the rightmost column (this is wrong)
266
- # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
267
- # xy_attn_mask[:,-1]=False
268
- ### only the bottom row (this is correct)
269
- xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool, device=xy_pos.device)
270
- # pdb.set_trace()
271
- ### the key part of the caching
272
- # print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
273
- xy_dec, _ = self.h(
274
- (xy_pos, None),
275
- mask=xy_attn_mask,cache=cache )
276
- logits = self.ar_predict_layer(xy_dec[:, -1])  ## no change needed: with the cache there is only one frame by default, so taking the last frame gives the same result
277
- # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
278
- samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
279
- if early_stop_num != -1 and (y.shape[1] - prefix_len
280
- ) > early_stop_num:
281
- print("use early stop num:", early_stop_num)
282
- stop = True
283
-
284
- if torch.argmax(
285
- logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
286
- # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
287
- stop = True
288
- if stop:
289
- if prompts.shape[1] == y.shape[1]:
290
- y = torch.concat([y, torch.zeros_like(samples)], dim=1)
291
- print('bad zero prediction')
292
- print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
293
- break
294
- # the newly generated semantic_ids are appended to the previous y to form the new y
295
- # print(samples.shape)#[1,1]#第一个1是bs
296
- y = torch.concat([y, samples], dim=1)
297
- cache["first_infer"]=0
298
- return y,idx
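A standalone sketch of the attention mask assembled in forward() and infer() above: the text prefix x may only attend within x, while y attends to all of x plus a causal prefix of itself (True means a masked position, as in the code):

import torch
import torch.nn.functional as F

x_len, y_len = 3, 4
x_attn_mask = F.pad(
    torch.zeros((x_len, x_len), dtype=torch.bool), (0, y_len), value=True)
y_attn_mask = F.pad(
    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
    (x_len, 0), value=False)
xy_attn_mask = torch.cat([x_attn_mask, y_attn_mask], dim=0)

print(xy_attn_mask.int())
# rows 0..2 (x) see only x; rows 3..6 (y) see all of x and a causal prefix of y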
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/models/utils.py DELETED
@@ -1,164 +0,0 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py\
2
- import torch
3
- import torch.nn.functional as F
4
- import torchaudio
5
-
6
-
7
- def sequence_mask(length, max_length=None):
8
- if max_length is None:
9
- max_length = length.max()
10
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
11
- return x.unsqueeze(0) < length.unsqueeze(1)
12
-
13
-
14
- def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor:
15
- """
16
- Args:
17
- lengths:
18
- A 1-D tensor containing sentence lengths.
19
- max_len:
20
- The length of masks.
21
- Returns:
22
- Return a 2-D bool tensor, where masked positions
23
- are filled with `True` and non-masked positions are
24
- filled with `False`.
25
-
26
- #>>> lengths = torch.tensor([1, 3, 2, 5])
27
- #>>> make_pad_mask(lengths)
28
- tensor([[False, True, True, True, True],
29
- [False, False, False, True, True],
30
- [False, False, True, True, True],
31
- [False, False, False, False, False]])
32
- """
33
- assert lengths.ndim == 1, lengths.ndim
34
- max_len = max(max_len, lengths.max())
35
- n = lengths.size(0)
36
- seq_range = torch.arange(0, max_len, device=lengths.device)
37
- expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
38
-
39
- return expaned_lengths >= lengths.unsqueeze(-1)
40
-
41
-
42
- # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
43
- def top_k_top_p_filtering(logits,
44
- top_k=0,
45
- top_p=1.0,
46
- filter_value=-float("Inf"),
47
- min_tokens_to_keep=1):
48
- """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
49
- Args:
50
- logits: logits distribution shape (batch size, vocabulary size)
51
- if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
52
- if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
53
- Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
54
- Make sure we keep at least min_tokens_to_keep per batch example in the output
55
- From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
56
- """
57
- if top_k > 0:
58
- top_k = min(max(top_k, min_tokens_to_keep),
59
- logits.size(-1)) # Safety check
60
- # Remove all tokens with a probability less than the last token of the top-k
61
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
62
- logits[indices_to_remove] = filter_value
63
-
64
- if top_p < 1.0:
65
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
66
- cumulative_probs = torch.cumsum(
67
- F.softmax(sorted_logits, dim=-1), dim=-1)
68
-
69
- # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
70
- sorted_indices_to_remove = cumulative_probs > top_p
71
- if min_tokens_to_keep > 1:
72
- # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
73
- sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
74
- # Shift the indices to the right to keep also the first token above the threshold
75
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
76
- ..., :-1].clone()
77
- sorted_indices_to_remove[..., 0] = 0
78
-
79
- # scatter sorted tensors to original indexing
80
- indices_to_remove = sorted_indices_to_remove.scatter(
81
- 1, sorted_indices, sorted_indices_to_remove)
82
- logits[indices_to_remove] = filter_value
83
- return logits
84
-
85
-
86
- def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
87
- # temperature: (`optional`) float
88
- # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
89
- # top_k: (`optional`) int
90
- # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
91
- # top_p: (`optional`) float
92
- # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
93
-
94
- # Temperature (higher temperature => more likely to sample low probability tokens)
95
- if temperature != 1.0:
96
- logits = logits / temperature
97
- # Top-p/top-k filtering
98
- logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
99
- # Sample
100
- token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
101
- return token
102
-
103
-
104
- from typing import Optional, Tuple
105
- def multinomial_sample_one_no_sync(
106
- probs_sort,
107
- ): # Does multinomial sampling without a cuda synchronization
108
- q = torch.empty_like(probs_sort).exponential_(1)
109
- return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
110
-
111
-
112
- def logits_to_probs(
113
- logits,
114
- previous_tokens: Optional[torch.Tensor] = None,
115
- temperature: float = 1.0,
116
- top_k: Optional[int] = None,
117
- top_p: Optional[int] = None,
118
- repetition_penalty: float = 1.0,
119
- ):
120
- previous_tokens=previous_tokens.squeeze()
121
- # print(logits.shape,previous_tokens.shape)
122
- # pdb.set_trace()
123
- if previous_tokens is not None and repetition_penalty != 1.0:
124
- previous_tokens = previous_tokens.long()
125
- score = torch.gather(logits, dim=0, index=previous_tokens)
126
- score = torch.where(
127
- score < 0, score * repetition_penalty, score / repetition_penalty
128
- )
129
- logits.scatter_(dim=0, index=previous_tokens, src=score)
130
-
131
- if top_p is not None and top_p < 1.0:
132
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
133
- cum_probs = torch.cumsum(
134
- torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
135
- )
136
- sorted_indices_to_remove = cum_probs > top_p
137
- sorted_indices_to_remove[0] = False # keep at least one option
138
- indices_to_remove = sorted_indices_to_remove.scatter(
139
- dim=0, index=sorted_indices, src=sorted_indices_to_remove
140
- )
141
- logits = logits.masked_fill(indices_to_remove, -float("Inf"))
142
-
143
- logits = logits / max(temperature, 1e-5)
144
-
145
- if top_k is not None:
146
- v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
147
- pivot = v.select(-1, -1).unsqueeze(-1)
148
- logits = torch.where(logits < pivot, -float("Inf"), logits)
149
-
150
- probs = torch.nn.functional.softmax(logits, dim=-1)
151
- return probs
152
-
153
-
154
- def sample(
155
- logits,
156
- previous_tokens: Optional[torch.Tensor] = None,
157
- **sampling_kwargs,
158
- ) -> Tuple[torch.Tensor, torch.Tensor]:
159
- probs = logits_to_probs(
160
- logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
161
- )
162
- idx_next = multinomial_sample_one_no_sync(probs)
163
- return idx_next, probs
164
-
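Usage sketch for the samplers above, mirroring how t2s_model.py calls them: topk_sampling works on (batch, vocab) logits, while sample() takes a 1-D logits vector plus the previously generated tokens, which drive the repetition penalty:

import torch
from AR.models.utils import topk_sampling, sample

vocab_size = 1025
tok = topk_sampling(torch.randn(1, vocab_size), top_k=10, temperature=1.0)  # shape (1, 1)

history = torch.randint(0, vocab_size, (1, 8))          # previously generated ids
next_tok, probs = sample(
    torch.randn(vocab_size), history,
    top_k=10, top_p=1.0, repetition_penalty=1.35)
print(tok.shape, next_tok.shape, probs.shape)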
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/__init__.py DELETED
File without changes
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/activation.py DELETED
@@ -1,397 +0,0 @@
1
- # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
2
- from typing import Optional
3
- from typing import Tuple
4
- import torch
5
- from torch import Tensor
6
- from torch.nn import Linear
7
- from torch.nn import Module
8
- from torch.nn.init import constant_
9
- from torch.nn.init import xavier_normal_
10
- from torch.nn.init import xavier_uniform_
11
- from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
12
- from torch.nn.parameter import Parameter
13
-
14
- from torch.nn import functional as F
15
- from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
16
- F.multi_head_attention_forward=multi_head_attention_forward_patched
17
-
18
- class MultiheadAttention(Module):
19
- r"""Allows the model to jointly attend to information
20
- from different representation subspaces as described in the paper:
21
- `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
22
-
23
- Multi-Head Attention is defined as:
24
-
25
- .. math::
26
- \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
27
-
28
- where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
29
-
30
- ``forward()`` will use a special optimized implementation if all of the following
31
- conditions are met:
32
-
33
- - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
34
- restriction will be loosened in the future.)
35
- - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
36
- - training is disabled (using ``.eval()``)
37
- - dropout is 0
38
- - ``add_bias_kv`` is ``False``
39
- - ``add_zero_attn`` is ``False``
40
- - ``batch_first`` is ``True`` and the input is batched
41
- - ``kdim`` and ``vdim`` are equal to ``embed_dim``
42
- - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
43
- - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
44
- nor ``attn_mask`` is passed
45
-
46
- If the optimized implementation is in use, a
47
- `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
48
- ``query``/``key``/``value`` to represent padding more efficiently than using a
49
- padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
50
- will be returned, and an additional speedup proportional to the fraction of the input
51
- that is padding can be expected.
52
-
53
- Args:
54
- embed_dim: Total dimension of the model.
55
- num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
56
- across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
57
- dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
58
- bias: If specified, adds bias to input / output projection layers. Default: ``True``.
59
- add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
60
- add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
61
- Default: ``False``.
62
- kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
63
- vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
64
- batch_first: If ``True``, then the input and output tensors are provided
65
- as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
66
-
67
- Examples::
68
-
69
- >>> # xdoctest: +SKIP
70
- >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
71
- >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
72
-
73
- """
74
- __constants__ = ["batch_first"]
75
- bias_k: Optional[torch.Tensor]
76
- bias_v: Optional[torch.Tensor]
77
-
78
- def __init__(
79
- self,
80
- embed_dim,
81
- num_heads,
82
- dropout=0.0,
83
- bias=True,
84
- add_bias_kv=False,
85
- add_zero_attn=False,
86
- kdim=None,
87
- vdim=None,
88
- batch_first=False,
89
- linear1_cls=Linear,
90
- linear2_cls=Linear,
91
- device=None,
92
- dtype=None, ) -> None:
93
- factory_kwargs = {"device": device, "dtype": dtype}
94
- super(MultiheadAttention, self).__init__()
95
- self.embed_dim = embed_dim
96
- self.kdim = kdim if kdim is not None else embed_dim
97
- self.vdim = vdim if vdim is not None else embed_dim
98
- self._qkv_same_embed_dim = (self.kdim == embed_dim and
99
- self.vdim == embed_dim)
100
-
101
- self.num_heads = num_heads
102
- self.dropout = dropout
103
- self.batch_first = batch_first
104
- self.head_dim = embed_dim // num_heads
105
- assert (self.head_dim * num_heads == self.embed_dim
106
- ), "embed_dim must be divisible by num_heads"
107
-
108
- if add_bias_kv:
109
- self.bias_k = Parameter(
110
- torch.empty((1, 1, embed_dim), **factory_kwargs))
111
- self.bias_v = Parameter(
112
- torch.empty((1, 1, embed_dim), **factory_kwargs))
113
- else:
114
- self.bias_k = self.bias_v = None
115
-
116
- if linear1_cls == Linear:
117
- if not self._qkv_same_embed_dim:
118
- self.q_proj_weight = Parameter(
119
- torch.empty((embed_dim, embed_dim), **factory_kwargs))
120
- self.k_proj_weight = Parameter(
121
- torch.empty((embed_dim, self.kdim), **factory_kwargs))
122
- self.v_proj_weight = Parameter(
123
- torch.empty((embed_dim, self.vdim), **factory_kwargs))
124
- self.register_parameter("in_proj_weight", None)
125
- else:
126
- self.in_proj_weight = Parameter(
127
- torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
128
- self.register_parameter("q_proj_weight", None)
129
- self.register_parameter("k_proj_weight", None)
130
- self.register_parameter("v_proj_weight", None)
131
-
132
- if bias:
133
- self.in_proj_bias = Parameter(
134
- torch.empty(3 * embed_dim, **factory_kwargs))
135
- else:
136
- self.register_parameter("in_proj_bias", None)
137
- self.out_proj = NonDynamicallyQuantizableLinear(
138
- embed_dim, embed_dim, bias=bias, **factory_kwargs)
139
-
140
- self._reset_parameters()
141
- else:
142
- if not self._qkv_same_embed_dim:
143
- raise NotImplementedError
144
- else:
145
- self.in_proj_linear = linear1_cls(
146
- embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
147
- self.in_proj_weight = self.in_proj_linear.weight
148
-
149
- self.register_parameter("q_proj_weight", None)
150
- self.register_parameter("k_proj_weight", None)
151
- self.register_parameter("v_proj_weight", None)
152
-
153
- if bias:
154
- self.in_proj_bias = self.in_proj_linear.bias
155
- else:
156
- self.register_parameter("in_proj_bias", None)
157
-
158
- self.out_proj = linear2_cls(
159
- embed_dim, embed_dim, bias=bias, **factory_kwargs)
160
-
161
- if self.bias_k is not None:
162
- xavier_normal_(self.bias_k)
163
- if self.bias_v is not None:
164
- xavier_normal_(self.bias_v)
165
-
166
- self.add_zero_attn = add_zero_attn
167
-
168
- def _reset_parameters(self):
169
- if self._qkv_same_embed_dim:
170
- xavier_uniform_(self.in_proj_weight)
171
- else:
172
- xavier_uniform_(self.q_proj_weight)
173
- xavier_uniform_(self.k_proj_weight)
174
- xavier_uniform_(self.v_proj_weight)
175
-
176
- if self.in_proj_bias is not None:
177
- constant_(self.in_proj_bias, 0.0)
178
- constant_(self.out_proj.bias, 0.0)
179
-
180
- if self.bias_k is not None:
181
- xavier_normal_(self.bias_k)
182
- if self.bias_v is not None:
183
- xavier_normal_(self.bias_v)
184
-
185
- def __setstate__(self, state):
186
- # Support loading old MultiheadAttention checkpoints generated by v1.1.0
187
- if "_qkv_same_embed_dim" not in state:
188
- state["_qkv_same_embed_dim"] = True
189
-
190
- super(MultiheadAttention, self).__setstate__(state)
191
-
192
- def forward(
193
- self,
194
- query: Tensor,
195
- key: Tensor,
196
- value: Tensor,
197
- key_padding_mask: Optional[Tensor]=None,
198
- need_weights: bool=True,
199
- attn_mask: Optional[Tensor]=None,
200
- average_attn_weights: bool=True,cache=None
201
- ) -> Tuple[Tensor, Optional[Tensor]]:
202
- r"""
203
- Args:
204
- query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
205
- or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
206
- :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
207
- Queries are compared against key-value pairs to produce the output.
208
- See "Attention Is All You Need" for more details.
209
- key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
210
- or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
211
- :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
212
- See "Attention Is All You Need" for more details.
213
- value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
214
- ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
215
- sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
216
- See "Attention Is All You Need" for more details.
217
- key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
218
- to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
219
- Binary and byte masks are supported.
220
- For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
221
- the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
222
- need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
223
- Default: ``True``.
224
- attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
225
- :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
226
- :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
227
- broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
228
- Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
229
- corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
230
- corresponding position is not allowed to attend. For a float mask, the mask values will be added to
231
- the attention weight.
232
- average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
233
- heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
234
- effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
235
-
236
- Outputs:
237
- - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
238
- :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
239
- where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
240
- embedding dimension ``embed_dim``.
241
- - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
242
- returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
243
- :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
244
- :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
245
- head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
246
-
247
- .. note::
248
- `batch_first` argument is ignored for unbatched inputs.
249
- """
250
- is_batched = query.dim() == 3
251
- if key_padding_mask is not None:
252
- _kpm_dtype = key_padding_mask.dtype
253
- if _kpm_dtype != torch.bool and not torch.is_floating_point(
254
- key_padding_mask):
255
- raise AssertionError(
256
- "only bool and floating types of key_padding_mask are supported"
257
- )
258
- why_not_fast_path = ""
259
- if not is_batched:
260
- why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
261
- elif query is not key or key is not value:
262
- # When lifting this restriction, don't forget to either
263
- # enforce that the dtypes all match or test cases where
264
- # they don't!
265
- why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
266
- elif (self.in_proj_bias is not None and
267
- query.dtype != self.in_proj_bias.dtype):
268
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
269
- elif (self.in_proj_weight is not None and
270
- query.dtype != self.in_proj_weight.dtype):
271
- # this case will fail anyway, but at least they'll get a useful error message.
272
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
273
- elif self.training:
274
- why_not_fast_path = "training is enabled"
275
- elif not self.batch_first:
276
- why_not_fast_path = "batch_first was not True"
277
- elif self.bias_k is not None:
278
- why_not_fast_path = "self.bias_k was not None"
279
- elif self.bias_v is not None:
280
- why_not_fast_path = "self.bias_v was not None"
281
- elif self.dropout:
282
- why_not_fast_path = f"dropout was {self.dropout}, required zero"
283
- elif self.add_zero_attn:
284
- why_not_fast_path = "add_zero_attn was enabled"
285
- elif not self._qkv_same_embed_dim:
286
- why_not_fast_path = "_qkv_same_embed_dim was not True"
287
- elif attn_mask is not None:
288
- why_not_fast_path = "attn_mask was not None"
289
- elif query.is_nested and key_padding_mask is not None:
290
- why_not_fast_path = (
291
- "key_padding_mask is not supported with NestedTensor input")
292
- elif self.num_heads % 2 == 1:
293
- why_not_fast_path = "num_heads is odd"
294
- elif torch.is_autocast_enabled():
295
- why_not_fast_path = "autocast is enabled"
296
-
297
- if not why_not_fast_path:
298
- tensor_args = (query, key, value, self.in_proj_weight,
299
- self.in_proj_bias, self.out_proj.weight,
300
- self.out_proj.bias, )
301
- # We have to use list comprehensions below because TorchScript does not support
302
- # generator expressions.
303
- if torch.overrides.has_torch_function(tensor_args):
304
- why_not_fast_path = "some Tensor argument has_torch_function"
305
- elif not all([(x is None or x.is_cuda or "cpu" in str(x.device))
306
- for x in tensor_args]):
307
- why_not_fast_path = (
308
- "some Tensor argument is neither CUDA nor CPU")
309
- elif torch.is_grad_enabled() and any(
310
- [x is not None and x.requires_grad for x in tensor_args]):
311
- why_not_fast_path = (
312
- "grad is enabled and at least one of query or the "
313
- "input/output projection weights or biases requires_grad")
314
- if not why_not_fast_path:
315
- return torch._native_multi_head_attention(
316
- query,
317
- key,
318
- value,
319
- self.embed_dim,
320
- self.num_heads,
321
- self.in_proj_weight,
322
- self.in_proj_bias,
323
- self.out_proj.weight,
324
- self.out_proj.bias,
325
- key_padding_mask
326
- if key_padding_mask is not None else attn_mask,
327
- need_weights,
328
- average_attn_weights,
329
- 1 if key_padding_mask is not None else 0
330
- if attn_mask is not None else None, )
331
-
332
- any_nested = query.is_nested or key.is_nested or value.is_nested
333
- assert not any_nested, (
334
- "MultiheadAttention does not support NestedTensor outside of its fast path. "
335
- + f"The fast path was not hit because {why_not_fast_path}")
336
-
337
- if self.batch_first and is_batched:
338
- # make sure that the transpose op does not affect the "is" property
339
- if key is value:
340
- if query is key:
341
- query = key = value = query.transpose(1, 0)
342
- else:
343
- query, key = [x.transpose(1, 0) for x in (query, key)]
344
- value = key
345
- else:
346
- query, key, value = [
347
- x.transpose(1, 0) for x in (query, key, value)
348
- ]
349
-
350
- if not self._qkv_same_embed_dim:
351
- attn_output, attn_output_weights = F.multi_head_attention_forward(
352
- query,
353
- key,
354
- value,
355
- self.embed_dim,
356
- self.num_heads,
357
- self.in_proj_weight,
358
- self.in_proj_bias,
359
- self.bias_k,
360
- self.bias_v,
361
- self.add_zero_attn,
362
- self.dropout,
363
- self.out_proj.weight,
364
- self.out_proj.bias,
365
- training=self.training,
366
- key_padding_mask=key_padding_mask,
367
- need_weights=need_weights,
368
- attn_mask=attn_mask,
369
- use_separate_proj_weight=True,
370
- q_proj_weight=self.q_proj_weight,
371
- k_proj_weight=self.k_proj_weight,
372
- v_proj_weight=self.v_proj_weight,
373
- average_attn_weights=average_attn_weights,cache=cache )
374
- else:
375
- attn_output, attn_output_weights = F.multi_head_attention_forward(
376
- query,
377
- key,
378
- value,
379
- self.embed_dim,
380
- self.num_heads,
381
- self.in_proj_weight,
382
- self.in_proj_bias,
383
- self.bias_k,
384
- self.bias_v,
385
- self.add_zero_attn,
386
- self.dropout,
387
- self.out_proj.weight,
388
- self.out_proj.bias,
389
- training=self.training,
390
- key_padding_mask=key_padding_mask,
391
- need_weights=need_weights,
392
- attn_mask=attn_mask,
393
- average_attn_weights=average_attn_weights,cache=cache )
394
- if self.batch_first and is_batched:
395
- return attn_output.transpose(1, 0), attn_output_weights
396
- else:
397
- return attn_output, attn_output_weights
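
A minimal sketch of how this cache-aware attention might be driven, assuming the deleted AR package were still on the import path; the cache layout is inferred from multi_head_attention_forward_patched further below and is illustrative, not an official API:

import torch
from AR.modules.activation import MultiheadAttention  # assumes the repo is importable

attn = MultiheadAttention(embed_dim=512, num_heads=8)  # seq-first layout by default
x = torch.randn(50, 1, 512)                            # (time, batch, embed_dim)

# One K/V slot per attention layer; "stage" cycles through the slots on each call.
cache = {"all_stage": 1, "stage": 0, "first_infer": 1, "k": [None], "v": [None]}

with torch.no_grad():
    out, _ = attn(x, x, x, cache=cache)
# cache["k"][0] / cache["v"][0] now hold this layer's keys and values; on the next
# decoding step set cache["first_infer"] = 0 so new K/V get concatenated instead.
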
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/embedding.py DELETED
@@ -1,78 +0,0 @@
1
- # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
- import math
3
-
4
- import torch
5
- from torch import nn
6
-
7
-
8
- class TokenEmbedding(nn.Module):
9
- def __init__(
10
- self,
11
- embedding_dim: int,
12
- vocab_size: int,
13
- dropout: float=0.0, ):
14
- super().__init__()
15
-
16
- self.vocab_size = vocab_size
17
- self.embedding_dim = embedding_dim
18
-
19
- self.dropout = torch.nn.Dropout(p=dropout)
20
- self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
21
-
22
- @property
23
- def weight(self) -> torch.Tensor:
24
- return self.word_embeddings.weight
25
-
26
- def embedding(self, index: int) -> torch.Tensor:
27
- return self.word_embeddings.weight[index:index + 1]
28
-
29
- def forward(self, x: torch.Tensor):
30
- x = self.word_embeddings(x)
31
- x = self.dropout(x)
32
- return x
33
-
34
-
35
- class SinePositionalEmbedding(nn.Module):
36
- def __init__(
37
- self,
38
- embedding_dim: int,
39
- dropout: float=0.0,
40
- scale: bool=False,
41
- alpha: bool=False, ):
42
- super().__init__()
43
- self.embedding_dim = embedding_dim
44
- self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
45
- self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
46
- self.dropout = torch.nn.Dropout(p=dropout)
47
-
48
- self.reverse = False
49
- self.pe = None
50
- self.extend_pe(torch.tensor(0.0).expand(1, 4000))
51
-
52
- def extend_pe(self, x):
53
- """Reset the positional encodings."""
54
- if self.pe is not None:
55
- if self.pe.size(1) >= x.size(1):
56
- if self.pe.dtype != x.dtype or self.pe.device != x.device:
57
- self.pe = self.pe.to(dtype=x.dtype, device=x.device)
58
- return
59
- pe = torch.zeros(x.size(1), self.embedding_dim)
60
- if self.reverse:
61
- position = torch.arange(
62
- x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
63
- else:
64
- position = torch.arange(
65
- 0, x.size(1), dtype=torch.float32).unsqueeze(1)
66
- div_term = torch.exp(
67
- torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) *
68
- -(math.log(10000.0) / self.embedding_dim))
69
- pe[:, 0::2] = torch.sin(position * div_term)
70
- pe[:, 1::2] = torch.cos(position * div_term)
71
- pe = pe.unsqueeze(0)
72
- self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
73
-
74
- def forward(self, x: torch.Tensor) -> torch.Tensor:
75
- self.extend_pe(x)
76
- output = x.unsqueeze(-1) if x.ndim == 2 else x
77
- output = output * self.x_scale + self.alpha * self.pe[:, :x.size(1)]
78
- return self.dropout(output)
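
For reference, a standalone sketch of the sin/cos table that extend_pe() builds (sine_positional_table is an illustrative name); with scale=False, dropout=0 and the default alpha this is roughly what the module adds to its input:

import math
import torch

def sine_positional_table(seq_len: int, dim: int) -> torch.Tensor:
    # Even channels get sin, odd channels get cos, as in extend_pe() above.
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32)
                         * -(math.log(10000.0) / dim))
    pe = torch.zeros(seq_len, dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)                    # (1, seq_len, dim), broadcasts over batch

x = torch.randn(2, 100, 512)                  # (batch, time, dim) token embeddings
x = x + sine_positional_table(100, 512)
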
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/lr_schedulers.py DELETED
@@ -1,85 +0,0 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/lr_schedulers.py
2
- import math
3
-
4
- import torch
5
- from matplotlib import pyplot as plt
6
- from torch import nn
7
- from torch.optim import Adam
8
-
9
-
10
- class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
11
- """
12
- Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers.
13
- """
14
-
15
- def __init__(self,
16
- optimizer,
17
- init_lr,
18
- peak_lr,
19
- end_lr,
20
- warmup_steps=10000,
21
- total_steps=400000,
22
- current_step=0):
23
- self.init_lr = init_lr
24
- self.peak_lr = peak_lr
25
- self.end_lr = end_lr
26
- self.optimizer = optimizer
27
- self._warmup_rate = (peak_lr - init_lr) / warmup_steps
28
- self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps)
29
- self._current_step = current_step
30
- self.lr = init_lr
31
- self.warmup_steps = warmup_steps
32
- self.total_steps = total_steps
33
- self._last_lr = [self.lr]
34
-
35
- def set_lr(self, lr):
36
- self._last_lr = [g['lr'] for g in self.optimizer.param_groups]
37
- for g in self.optimizer.param_groups:
38
- # g['lr'] = lr
39
- g['lr'] = self.end_lr###锁定用线性
40
-
41
- def step(self):
42
- if self._current_step < self.warmup_steps:
43
- lr = self.init_lr + self._warmup_rate * self._current_step
44
-
45
- elif self._current_step > self.total_steps:
46
- lr = self.end_lr
47
-
48
- else:
49
- decay_ratio = (self._current_step - self.warmup_steps) / (
50
- self.total_steps - self.warmup_steps)
51
- if decay_ratio < 0.0 or decay_ratio > 1.0:
52
- raise RuntimeError(
53
- "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
54
- )
55
- coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
56
- lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
57
-
58
- self.lr=lr=self.end_lr=0.002###锁定用线性###不听话,直接锁定!
59
- self.set_lr(lr)
60
- self.lr = lr
61
- self._current_step += 1
62
- return self.lr
63
-
64
-
65
-
66
- if __name__ == '__main__':
67
- m = nn.Linear(10, 10)
68
- opt = Adam(m.parameters(), lr=1e-4)
69
- s = WarmupCosineLRSchedule(
70
- opt,
71
- 1e-6,
72
- 2e-4,
73
- 1e-6,
74
- warmup_steps=2000,
75
- total_steps=20000,
76
- current_step=0)
77
- lrs = []
78
- for i in range(25000):
79
- s.step()
80
- lrs.append(s.lr)
81
- print(s.lr)
82
-
83
- plt.plot(lrs)
84
- plt.plot(range(0, 25000), lrs)
85
- plt.show()
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/optim.py DELETED
@@ -1,622 +0,0 @@
1
- # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
- #
3
- # See ../LICENSE for clarification regarding multiple authors
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- import contextlib
17
- import logging
18
- from collections import defaultdict
19
- from typing import List
20
- from typing import Tuple
21
-
22
- import torch
23
- from torch import Tensor
24
- from torch.optim import Optimizer
25
-
26
-
27
- class BatchedOptimizer(Optimizer):
28
- """
29
- This class adds to class Optimizer the capability to optimize parameters in batches:
30
- it will stack the parameters and their grads for you so the optimizer can work
31
- on tensors with an extra leading dimension. This is intended for speed with GPUs,
32
- as it reduces the number of kernels launched in the optimizer.
33
-
34
- Args:
35
- params:
36
- """
37
-
38
- def __init__(self, params, defaults):
39
- super(BatchedOptimizer, self).__init__(params, defaults)
40
-
41
- @contextlib.contextmanager
42
- def batched_params(self, param_group, group_params_names):
43
- """
44
- This function returns (technically, yields) a list of
45
- of tuples (p, state), where
46
- p is a `fake` parameter that is stacked (over axis 0) from real parameters
47
- that share the same shape, and its gradient is also stacked;
48
- `state` is the state corresponding to this batch of parameters
49
- (it will be physically located in the "state" for one of the real
50
- parameters, the last one that has any particular shape and dtype).
51
-
52
- This function is decorated as a context manager so that it can
53
- write parameters back to their "real" locations.
54
-
55
- The idea is, instead of doing:
56
- <code>
57
- for p in group["params"]:
58
- state = self.state[p]
59
- ...
60
- </code>
61
- you can do:
62
- <code>
63
- with self.batched_params(group["params"]) as batches:
64
- for p, state, p_names in batches:
65
- ...
66
- </code>
67
-
68
- Args:
69
- group: a parameter group, which is a list of parameters; should be
70
- one of self.param_groups.
71
- group_params_names: name for each parameter in group,
72
- which is List[str].
73
- """
74
- batches = defaultdict(
75
- list
76
- ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
77
- batches_names = defaultdict(
78
- list
79
- ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
80
-
81
- assert len(param_group) == len(group_params_names)
82
- for p, named_p in zip(param_group, group_params_names):
83
- key = (str(p.dtype), *p.shape)
84
- batches[key].append(p)
85
- batches_names[key].append(named_p)
86
-
87
- batches_names_keys = list(batches_names.keys())
88
- sorted_idx = sorted(
89
- range(len(batches_names)), key=lambda i: batches_names_keys[i])
90
- batches_names = [
91
- batches_names[batches_names_keys[idx]] for idx in sorted_idx
92
- ]
93
- batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
94
-
95
- stacked_params_dict = dict()
96
-
97
- # turn batches into a list, in deterministic order.
98
- # tuples will contain tuples of (stacked_param, state, stacked_params_names),
99
- # one for each batch in `batches`.
100
- tuples = []
101
-
102
- for batch, batch_names in zip(batches, batches_names):
103
- p = batch[0]
104
- # we arbitrarily store the state in the
105
- # state corresponding to the 1st parameter in the
106
- # group. class Optimizer will take care of saving/loading state.
107
- state = self.state[p]
108
- p_stacked = torch.stack(batch)
109
- grad = torch.stack([
110
- torch.zeros_like(p) if p.grad is None else p.grad for p in batch
111
- ])
112
- p_stacked.grad = grad
113
- stacked_params_dict[key] = p_stacked
114
- tuples.append((p_stacked, state, batch_names))
115
-
116
- yield tuples # <-- calling code will do the actual optimization here!
117
-
118
- for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
119
- for i, p in enumerate(batch): # batch is list of Parameter
120
- p.copy_(stacked_params[i])
121
-
122
-
123
- class ScaledAdam(BatchedOptimizer):
124
- """
125
- Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
126
- proportional to the norm of that parameter; and also learn the scale of the parameter,
127
- in log space, subject to upper and lower limits (as if we had factored each parameter as
128
- param = underlying_param * log_scale.exp())
129
-
130
-
131
- Args:
132
- params: The parameters or param_groups to optimize (like other Optimizer subclasses)
133
- lr: The learning rate. We will typically use a learning rate schedule that starts
134
- at 0.03 and decreases over time, i.e. much higher than other common
135
- optimizers.
136
- clipping_scale: (e.g. 2.0)
137
- A scale for gradient-clipping: if specified, the normalized gradients
138
- over the whole model will be clipped to have 2-norm equal to
139
- `clipping_scale` times the median 2-norm over the most recent period
140
- of `clipping_update_period` minibatches. By "normalized gradients",
141
- we mean after multiplying by the rms parameter value for this tensor
142
- [for non-scalars]; this is appropriate because our update is scaled
143
- by this quantity.
144
- betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
145
- Must satisfy 0 < beta <= beta2 < 1.
146
- scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
147
- scale of each parameter tensor and scalar parameters of the mode..
148
- If each parameter were decomposed
149
- as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
150
- would be a the scaling factor on the learning rate of p_scale.
151
- eps: A general-purpose epsilon to prevent division by zero
152
- param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
153
- learning the scale on the parameters (we'll constrain the rms of each non-scalar
154
- parameter tensor to be >= this value)
155
- param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
156
- learning the scale on the parameters (we'll constrain the rms of each non-scalar
157
- parameter tensor to be <= this value)
158
- scalar_max: Maximum absolute value for scalar parameters (applicable if your
159
- model has any parameters with numel() == 1).
160
- size_update_period: The periodicity, in steps, with which we update the size (scale)
161
- of the parameter tensor. This is provided to save a little time
162
- in the update.
163
- clipping_update_period: if clipping_scale is specified, this is the period
164
- """
165
-
166
- def __init__(
167
- self,
168
- params,
169
- lr=3e-02,
170
- clipping_scale=None,
171
- betas=(0.9, 0.98),
172
- scalar_lr_scale=0.1,
173
- eps=1.0e-08,
174
- param_min_rms=1.0e-05,
175
- param_max_rms=3.0,
176
- scalar_max=10.0,
177
- size_update_period=4,
178
- clipping_update_period=100,
179
- parameters_names=None,
180
- show_dominant_parameters=True, ):
181
-
182
- assert parameters_names is not None, (
183
- "Please prepare parameters_names,"
184
- "which is a List[List[str]]. Each List[str] is for a group"
185
- "and each str is for a parameter")
186
- defaults = dict(
187
- lr=lr,
188
- clipping_scale=clipping_scale,
189
- betas=betas,
190
- scalar_lr_scale=scalar_lr_scale,
191
- eps=eps,
192
- param_min_rms=param_min_rms,
193
- param_max_rms=param_max_rms,
194
- scalar_max=scalar_max,
195
- size_update_period=size_update_period,
196
- clipping_update_period=clipping_update_period, )
197
-
198
- super(ScaledAdam, self).__init__(params, defaults)
199
- assert len(self.param_groups) == len(parameters_names)
200
- self.parameters_names = parameters_names
201
- self.show_dominant_parameters = show_dominant_parameters
202
-
203
- def __setstate__(self, state):
204
- super(ScaledAdam, self).__setstate__(state)
205
-
206
- @torch.no_grad()
207
- def step(self, closure=None):
208
- """Performs a single optimization step.
209
-
210
- Arguments:
211
- closure (callable, optional): A closure that reevaluates the model
212
- and returns the loss.
213
- """
214
- loss = None
215
- if closure is not None:
216
- with torch.enable_grad():
217
- loss = closure()
218
-
219
- batch = True
220
-
221
- for group, group_params_names in zip(self.param_groups,
222
- self.parameters_names):
223
-
224
- with self.batched_params(group["params"],
225
- group_params_names) as batches:
226
-
227
- # batches is list of pairs (stacked_param, state). stacked_param is like
228
- # a regular parameter, and will have a .grad, but the 1st dim corresponds to
229
- # a stacking dim, it is not a real dim.
230
-
231
- if (len(batches[0][1]) ==
232
- 0): # if len(first state) == 0: not yet initialized
233
- clipping_scale = 1
234
- else:
235
- clipping_scale = self._get_clipping_scale(group, batches)
236
-
237
- for p, state, _ in batches:
238
- # Perform optimization step.
239
- # grad is not going to be None, we handled that when creating the batches.
240
- grad = p.grad
241
- if grad.is_sparse:
242
- raise RuntimeError(
243
- "ScaledAdam optimizer does not support sparse gradients"
244
- )
245
- # State initialization
246
- if len(state) == 0:
247
- self._init_state(group, p, state)
248
-
249
- self._step_one_batch(group, p, state, clipping_scale)
250
-
251
- return loss
252
-
253
- def _init_state(self, group: dict, p: Tensor, state: dict):
254
- """
255
- Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
256
- is actually the batch dimension, corresponding to batched-together
257
- parameters of a given shape.
258
-
259
-
260
- Args:
261
- group: Dict to look up configuration values.
262
- p: The parameter that we are initializing the state for
263
- state: Dict from string to whatever state we are initializing
264
- """
265
- size_update_period = group["size_update_period"]
266
-
267
- state["step"] = 0
268
-
269
- kwargs = {"device": p.device, "dtype": p.dtype}
270
-
271
- # 'delta' implements conventional momentum. There are
272
- # several different kinds of update going on, so rather than
273
- # compute "exp_avg" like in Adam, we store and decay a
274
- # parameter-change "delta", which combines all forms of
275
- # update. this is equivalent to how it's done in Adam,
276
- # except for the first few steps.
277
- state["delta"] = torch.zeros_like(
278
- p, memory_format=torch.preserve_format)
279
-
280
- batch_size = p.shape[0]
281
- numel = p.numel() // batch_size
282
- numel = p.numel()
283
-
284
- if numel > 1:
285
- # "param_rms" just periodically records the scalar root-mean-square value of
286
- # the parameter tensor.
287
- # it has a shape like (batch_size, 1, 1, 1, 1)
288
- param_rms = (
289
- (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
290
- state["param_rms"] = param_rms
291
-
292
- state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
293
- state["scale_grads"] = torch.zeros(size_update_period,
294
- *param_rms.shape, **kwargs)
295
-
296
- # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
297
- state["exp_avg_sq"] = torch.zeros_like(
298
- p, memory_format=torch.preserve_format)
299
-
300
- def _get_clipping_scale(self,
301
- group: dict,
302
- tuples: List[Tuple[Tensor, dict, List[str]]]
303
- ) -> float:
304
- """
305
- Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
306
- by this amount before applying the rest of the update.
307
-
308
- Args:
309
- group: the parameter group, an item in self.param_groups
310
- tuples: a list of tuples of (param, state, param_names)
311
- where param is a batched set of parameters,
312
- with a .grad (1st dim is batch dim)
313
- and state is the state-dict where optimization parameters are kept.
314
- param_names is a List[str] while each str is name for a parameter
315
- in batched set of parameters "param".
316
- """
317
- assert len(tuples) >= 1
318
- clipping_scale = group["clipping_scale"]
319
- (first_p, first_state, _) = tuples[0]
320
- step = first_state["step"]
321
- if clipping_scale is None or step == 0:
322
- # no clipping. return early on step == 0 because the other
323
- # parameters' state won't have been initialized yet.
324
- return 1.0
325
- clipping_update_period = group["clipping_update_period"]
326
-
327
- tot_sumsq = torch.tensor(0.0, device=first_p.device)
328
- for (p, state, param_names) in tuples:
329
- grad = p.grad
330
- if grad.is_sparse:
331
- raise RuntimeError(
332
- "ScaledAdam optimizer does not support sparse gradients")
333
- if p.numel() == p.shape[0]: # a batch of scalars
334
- tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
335
- else:
336
- tot_sumsq += ((grad * state["param_rms"])**2).sum()
337
-
338
- tot_norm = tot_sumsq.sqrt()
339
- if "model_norms" not in first_state:
340
- first_state["model_norms"] = torch.zeros(
341
- clipping_update_period, device=p.device)
342
- first_state["model_norms"][step % clipping_update_period] = tot_norm
343
-
344
- if step % clipping_update_period == 0:
345
- # Print some stats.
346
- # We don't reach here if step == 0 because we would have returned
347
- # above.
348
- sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
349
- quartiles = []
350
- for n in range(0, 5):
351
- index = min(
352
- clipping_update_period - 1,
353
- (clipping_update_period // 4) * n, )
354
- quartiles.append(sorted_norms[index].item())
355
-
356
- median = quartiles[2]
357
- threshold = clipping_scale * median
358
- first_state["model_norm_threshold"] = threshold
359
- percent_clipped = (first_state["num_clipped"] * 100.0 /
360
- clipping_update_period
361
- if "num_clipped" in first_state else 0.0)
362
- first_state["num_clipped"] = 0
363
- quartiles = " ".join(["%.3e" % x for x in quartiles])
364
- logging.info(
365
- f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
366
- f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
367
- )
368
-
369
- if step < clipping_update_period:
370
- return 1.0 # We have not yet estimated a norm to clip to.
371
- else:
372
- try:
373
- model_norm_threshold = first_state["model_norm_threshold"]
374
- except KeyError:
375
- logging.info(
376
- "Warning: model_norm_threshold not in state: possibly "
377
- "you changed config when restarting, adding clipping_scale option?"
378
- )
379
- return 1.0
380
- ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
381
- if ans < 1.0:
382
- first_state["num_clipped"] += 1
383
- if ans < 0.1:
384
- logging.warn(
385
- f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
386
- )
387
- if self.show_dominant_parameters:
388
- assert p.shape[0] == len(param_names)
389
- self._show_gradient_dominating_parameter(tuples, tot_sumsq)
390
- return ans
391
-
392
- def _show_gradient_dominating_parameter(
393
- self, tuples: List[Tuple[Tensor, dict, List[str]]],
394
- tot_sumsq: Tensor):
395
- """
396
- Show information of parameter wihch dominanting tot_sumsq.
397
-
398
- Args:
399
- tuples: a list of tuples of (param, state, param_names)
400
- where param is a batched set of parameters,
401
- with a .grad (1st dim is batch dim)
402
- and state is the state-dict where optimization parameters are kept.
403
- param_names is a List[str] while each str is name for a parameter
404
- in batched set of parameters "param".
405
- tot_sumsq: sumsq of all parameters. Though it's could be calculated
406
- from tuples, we still pass it to save some time.
407
- """
408
- all_sumsq_orig = {}
409
- for (p, state, batch_param_names) in tuples:
410
- # p is a stacked batch parameters.
411
- batch_grad = p.grad
412
- if p.numel() == p.shape[0]: # a batch of scalars
413
- batch_sumsq_orig = batch_grad**2
414
- # Dummpy values used by following `zip` statement.
415
- batch_rms_orig = torch.ones(p.shape[0])
416
- else:
417
- batch_rms_orig = state["param_rms"]
418
- batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
419
- dim=list(range(1, batch_grad.ndim)))
420
-
421
- for name, sumsq_orig, rms, grad in zip(batch_param_names,
422
- batch_sumsq_orig,
423
- batch_rms_orig, batch_grad):
424
-
425
- proportion_orig = sumsq_orig / tot_sumsq
426
- all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
427
-
428
- assert torch.isclose(
429
- sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
430
- torch.tensor(1.0), )
431
- sorted_by_proportion = {
432
- k: v
433
- for k, v in sorted(
434
- all_sumsq_orig.items(),
435
- key=lambda item: item[1][0],
436
- reverse=True, )
437
- }
438
- dominant_param_name = next(iter(sorted_by_proportion))
439
- (dominant_proportion, dominant_sumsq, dominant_rms,
440
- dominant_grad, ) = sorted_by_proportion[dominant_param_name]
441
- logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
442
- f" with proportion {dominant_proportion:.2f},"
443
- f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
444
- f"={dominant_sumsq:.3e},"
445
- f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
446
- f" orig_rms_sq={(dominant_rms**2).item():.3e}")
447
-
448
- def _step_one_batch(self,
449
- group: dict,
450
- p: Tensor,
451
- state: dict,
452
- clipping_scale: float):
453
- """
454
- Do the step for one parameter, which is actually going to be a batch of
455
- `real` parameters, with dim 0 as the batch dim.
456
- Args:
457
- group: dict to look up configuration values
458
- p: parameter to update (actually multiple parameters stacked together
459
- as a batch)
460
- state: state-dict for p, to look up the optimizer state
461
- """
462
- lr = group["lr"]
463
- size_update_period = group["size_update_period"]
464
- beta1 = group["betas"][0]
465
-
466
- grad = p.grad
467
- if clipping_scale != 1.0:
468
- grad = grad * clipping_scale
469
- step = state["step"]
470
- delta = state["delta"]
471
-
472
- delta.mul_(beta1)
473
- batch_size = p.shape[0]
474
- numel = p.numel() // batch_size
475
- if numel > 1:
476
- # Update the size/scale of p, and set param_rms
477
- scale_grads = state["scale_grads"]
478
- scale_grads[step % size_update_period] = (p * grad).sum(
479
- dim=list(range(1, p.ndim)), keepdim=True)
480
- if step % size_update_period == size_update_period - 1:
481
- param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
482
- param_rms.copy_((p**2)
483
- .mean(dim=list(range(1, p.ndim)), keepdim=True)
484
- .sqrt())
485
- if step > 0:
486
- # self._size_update() learns the overall scale on the
487
- # parameter, by shrinking or expanding it.
488
- self._size_update(group, scale_grads, p, state)
489
-
490
- if numel == 1:
491
- # For parameters with 1 element we just use regular Adam.
492
- # Updates delta.
493
- self._step_scalar(group, p, state)
494
- else:
495
- self._step(group, p, state)
496
-
497
- state["step"] = step + 1
498
-
499
- def _size_update(self,
500
- group: dict,
501
- scale_grads: Tensor,
502
- p: Tensor,
503
- state: dict) -> None:
504
- """
505
- Called only where p.numel() > 1, this updates the scale of the parameter.
506
- If we imagine: p = underlying_param * scale.exp(), and we are doing
507
- gradient descent on underlying param and on scale, this function does the update
508
- on `scale`.
509
-
510
- Args:
511
- group: dict to look up configuration values
512
- scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
513
- grads w.r.t. the scales.
514
- p: The parameter to update
515
- state: The state-dict of p
516
- """
517
-
518
- param_rms = state["param_rms"]
519
- beta1, beta2 = group["betas"]
520
- size_lr = group["lr"] * group["scalar_lr_scale"]
521
- param_min_rms = group["param_min_rms"]
522
- param_max_rms = group["param_max_rms"]
523
- eps = group["eps"]
524
- step = state["step"]
525
- batch_size = p.shape[0]
526
-
527
- size_update_period = scale_grads.shape[0]
528
- # correct beta2 for the size update period: we will have
529
- # faster decay at this level.
530
- beta2_corr = beta2**size_update_period
531
-
532
- scale_exp_avg_sq = state[
533
- "scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
534
- scale_exp_avg_sq.mul_(beta2_corr).add_(
535
- (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
536
- alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...)
537
-
538
- # The 1st time we reach here is when size_step == 1.
539
- size_step = (step + 1) // size_update_period
540
- bias_correction2 = 1 - beta2_corr**size_step
541
- # we don't bother with bias_correction1; this will help prevent divergence
542
- # at the start of training.
543
-
544
- denom = scale_exp_avg_sq.sqrt() + eps
545
-
546
- scale_step = (-size_lr * (bias_correction2**0.5) *
547
- scale_grads.sum(dim=0) / denom)
548
-
549
- is_too_small = param_rms < param_min_rms
550
- is_too_large = param_rms > param_max_rms
551
-
552
- # when the param gets too small, just don't shrink it any further.
553
- scale_step.masked_fill_(is_too_small, 0.0)
554
- # when it gets too large, stop it from getting any larger.
555
- scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
556
- delta = state["delta"]
557
- # the factor of (1-beta1) relates to momentum.
558
- delta.add_(p * scale_step, alpha=(1 - beta1))
559
-
560
- def _step(self, group: dict, p: Tensor, state: dict):
561
- """
562
- This function does the core update of self.step(), in the case where the members of
563
- the batch have more than 1 element.
564
-
565
- Args:
566
- group: A dict which will be used to look up configuration values
567
- p: The parameter to be updated
568
- grad: The grad of p
569
- state: The state-dict corresponding to parameter p
570
-
571
- This function modifies p.
572
- """
573
- grad = p.grad
574
- lr = group["lr"]
575
- beta1, beta2 = group["betas"]
576
- eps = group["eps"]
577
- param_min_rms = group["param_min_rms"]
578
- step = state["step"]
579
-
580
- exp_avg_sq = state["exp_avg_sq"]
581
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
582
-
583
- this_step = state["step"] - (state["zero_step"]
584
- if "zero_step" in state else 0)
585
- bias_correction2 = 1 - beta2**(this_step + 1)
586
- if bias_correction2 < 0.99:
587
- # note: not in-place.
588
- exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
589
-
590
- denom = exp_avg_sq.sqrt()
591
- denom += eps
592
- grad = grad / denom
593
-
594
- alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
595
-
596
- delta = state["delta"]
597
- delta.add_(grad * alpha)
598
- p.add_(delta)
599
-
600
- def _step_scalar(self, group: dict, p: Tensor, state: dict):
601
- """
602
- A simplified form of the core update for scalar tensors, where we cannot get a good
603
- estimate of the parameter rms.
604
- """
605
- beta1, beta2 = group["betas"]
606
- scalar_max = group["scalar_max"]
607
- eps = group["eps"]
608
- lr = group["lr"] * group["scalar_lr_scale"]
609
- grad = p.grad
610
-
611
- exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
612
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
613
-
614
- # bias_correction2 is like in Adam. Don't bother with bias_correction1;
615
- # slower update at the start will help stability anyway.
616
- bias_correction2 = 1 - beta2**(state["step"] + 1)
617
- denom = (exp_avg_sq / bias_correction2).sqrt() + eps
618
-
619
- delta = state["delta"]
620
- delta.add_(grad / denom, alpha=-lr * (1 - beta1))
621
- p.clamp_(min=-scalar_max, max=scalar_max)
622
- p.add_(delta)
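
A minimal usage sketch, assuming the deleted AR package were importable; ScaledAdam requires parameters_names (one List[str] per parameter group), which the example derives from named_parameters():

import torch
from torch import nn
from AR.modules.optim import ScaledAdam    # assumes the repo is on the import path

model = nn.Linear(256, 256)
names = [[n for n, _ in model.named_parameters()]]   # one List[str] per param group
opt = ScaledAdam(model.parameters(), lr=0.01, clipping_scale=2.0, parameters_names=names)

loss = model(torch.randn(8, 256)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
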
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/patched_mha_with_cache.py DELETED
@@ -1,388 +0,0 @@
1
- from torch.nn.functional import *
2
- from torch.nn.functional import _mha_shape_check,_canonical_mask,_none_or_dtype,_in_projection_packed
3
- # import torch
4
- # Tensor = torch.Tensor
5
- # from typing import Callable, List, Optional, Tuple, Union
6
-
7
- def multi_head_attention_forward_patched(
8
- query: Tensor,
9
- key: Tensor,
10
- value: Tensor,
11
- embed_dim_to_check: int,
12
- num_heads: int,
13
- in_proj_weight: Optional[Tensor],
14
- in_proj_bias: Optional[Tensor],
15
- bias_k: Optional[Tensor],
16
- bias_v: Optional[Tensor],
17
- add_zero_attn: bool,
18
- dropout_p: float,
19
- out_proj_weight: Tensor,
20
- out_proj_bias: Optional[Tensor],
21
- training: bool = True,
22
- key_padding_mask: Optional[Tensor] = None,
23
- need_weights: bool = True,
24
- attn_mask: Optional[Tensor] = None,
25
- use_separate_proj_weight: bool = False,
26
- q_proj_weight: Optional[Tensor] = None,
27
- k_proj_weight: Optional[Tensor] = None,
28
- v_proj_weight: Optional[Tensor] = None,
29
- static_k: Optional[Tensor] = None,
30
- static_v: Optional[Tensor] = None,
31
- average_attn_weights: bool = True,
32
- is_causal: bool = False,cache=None
33
- ) -> Tuple[Tensor, Optional[Tensor]]:
34
- r"""
35
- Args:
36
- query, key, value: map a query and a set of key-value pairs to an output.
37
- See "Attention Is All You Need" for more details.
38
- embed_dim_to_check: total dimension of the model.
39
- num_heads: parallel attention heads.
40
- in_proj_weight, in_proj_bias: input projection weight and bias.
41
- bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
42
- add_zero_attn: add a new batch of zeros to the key and
43
- value sequences at dim=1.
44
- dropout_p: probability of an element to be zeroed.
45
- out_proj_weight, out_proj_bias: the output projection weight and bias.
46
- training: apply dropout if is ``True``.
47
- key_padding_mask: if provided, specified padding elements in the key will
48
- be ignored by the attention. This is an binary mask. When the value is True,
49
- the corresponding value on the attention layer will be filled with -inf.
50
- need_weights: output attn_output_weights.
51
- Default: `True`
52
- Note: `needs_weight` defaults to `True`, but should be set to `False`
53
- For best performance when attention weights are not nedeeded.
54
- *Setting needs_weights to `True`
55
- leads to a significant performance degradation.*
56
- attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
57
- the batches while a 3D mask allows to specify a different mask for the entries of each batch.
58
- is_causal: If specified, applies a causal mask as attention mask, and ignores
59
- attn_mask for computing scaled dot product attention.
60
- Default: ``False``.
61
- .. warning::
62
- is_causal is provides a hint that the attn_mask is the
63
- causal mask.Providing incorrect hints can result in
64
- incorrect execution, including forward and backward
65
- compatibility.
66
- use_separate_proj_weight: the function accept the proj. weights for query, key,
67
- and value in different forms. If false, in_proj_weight will be used, which is
68
- a combination of q_proj_weight, k_proj_weight, v_proj_weight.
69
- q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
70
- static_k, static_v: static key and value used for attention operators.
71
- average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
72
- Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
73
- when ``need_weights=True.``. Default: True
74
-
75
-
76
- Shape:
77
- Inputs:
78
- - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
79
- the embedding dimension.
80
- - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
81
- the embedding dimension.
82
- - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
83
- the embedding dimension.
84
- - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
85
- If a FloatTensor is provided, it will be directly added to the value.
86
- If a BoolTensor is provided, the positions with the
87
- value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
88
- - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
89
- 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
90
- S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
91
- positions. If a BoolTensor is provided, positions with ``True``
92
- are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
93
- is provided, it will be added to the attention weight.
94
- - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
95
- N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
96
- - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
97
- N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
98
-
99
- Outputs:
100
- - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
101
- E is the embedding dimension.
102
- - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
103
- attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
104
- :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
105
- :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
106
- head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
107
- """
108
- tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
109
- if has_torch_function(tens_ops):
110
- return handle_torch_function(
111
- multi_head_attention_forward,
112
- tens_ops,
113
- query,
114
- key,
115
- value,
116
- embed_dim_to_check,
117
- num_heads,
118
- in_proj_weight,
119
- in_proj_bias,
120
- bias_k,
121
- bias_v,
122
- add_zero_attn,
123
- dropout_p,
124
- out_proj_weight,
125
- out_proj_bias,
126
- training=training,
127
- key_padding_mask=key_padding_mask,
128
- need_weights=need_weights,
129
- attn_mask=attn_mask,
130
- is_causal=is_causal,
131
- use_separate_proj_weight=use_separate_proj_weight,
132
- q_proj_weight=q_proj_weight,
133
- k_proj_weight=k_proj_weight,
134
- v_proj_weight=v_proj_weight,
135
- static_k=static_k,
136
- static_v=static_v,
137
- average_attn_weights=average_attn_weights,cache=cache
138
- )
139
-
140
- is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
141
-
142
- # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
143
- # is batched, run the computation and before returning squeeze the
144
- # batch dimension so that the output doesn't carry this temporary batch dimension.
145
- if not is_batched:
146
- # unsqueeze if the input is unbatched
147
- query = query.unsqueeze(1)
148
- key = key.unsqueeze(1)
149
- value = value.unsqueeze(1)
150
- if key_padding_mask is not None:
151
- key_padding_mask = key_padding_mask.unsqueeze(0)
152
-
153
- # set up shape vars
154
- tgt_len, bsz, embed_dim = query.shape
155
- src_len, _, _ = key.shape
156
-
157
- key_padding_mask = _canonical_mask(
158
- mask=key_padding_mask,
159
- mask_name="key_padding_mask",
160
- other_type=_none_or_dtype(attn_mask),
161
- other_name="attn_mask",
162
- target_type=query.dtype
163
- )
164
-
165
- if is_causal and attn_mask is None:
166
- raise RuntimeError(
167
- "Need attn_mask if specifying the is_causal hint. "
168
- "You may use the Transformer module method "
169
- "`generate_square_subsequent_mask` to create this mask."
170
- )
171
-
172
- if is_causal and key_padding_mask is None and not need_weights:
173
- # when we have a kpm or need weights, we need attn_mask
174
- # Otherwise, we use the is_causal hint go as is_causal
175
- # indicator to SDPA.
176
- attn_mask = None
177
- else:
178
- attn_mask = _canonical_mask(
179
- mask=attn_mask,
180
- mask_name="attn_mask",
181
- other_type=None,
182
- other_name="",
183
- target_type=query.dtype,
184
- check_other=False,
185
- )
186
-
187
-
188
- if key_padding_mask is not None:
189
- # We have the attn_mask, and use that to merge kpm into it.
190
- # Turn off use of is_causal hint, as the merged mask is no
191
- # longer causal.
192
- is_causal = False
193
-
194
- assert embed_dim == embed_dim_to_check, \
195
- f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
196
- if isinstance(embed_dim, torch.Tensor):
197
- # embed_dim can be a tensor when JIT tracing
198
- head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
199
- else:
200
- head_dim = embed_dim // num_heads
201
- assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
202
- if use_separate_proj_weight:
203
- # allow MHA to have different embedding dimensions when separate projection weights are used
204
- assert key.shape[:2] == value.shape[:2], \
205
- f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
206
- else:
207
- assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
208
-
209
- #
210
- # compute in-projection
211
- #
212
- if not use_separate_proj_weight:
213
- assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
214
- q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
215
- else:
216
- assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
217
- assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
218
- assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
219
- if in_proj_bias is None:
220
- b_q = b_k = b_v = None
221
- else:
222
- b_q, b_k, b_v = in_proj_bias.chunk(3)
223
- q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
224
- if(cache!=None):
225
- if(cache["first_infer"]==1):
226
- cache["k"][cache["stage"]]=k
227
- # print(0,cache["k"].shape)
228
- cache["v"][cache["stage"]]=v
229
- else:  ### each of the 12 layers keeps its own cached k/v
230
- # print(1,cache["k"].shape)
231
- cache["k"][cache["stage"]]=torch.cat([cache["k"][cache["stage"]],k],0)  ## the time axis was originally dim 1, but the projection may have transposed it, so time now sits on dim 0
232
- cache["v"][cache["stage"]]=torch.cat([cache["v"][cache["stage"]],v],0)
233
- # print(2, cache["k"].shape)
234
- src_len = cache["k"][cache["stage"]].shape[0]
235
- k=cache["k"][cache["stage"]]
236
- v=cache["v"][cache["stage"]]
237
- # if attn_mask is not None:
238
- # attn_mask=attn_mask[-1:,]
239
- # print(attn_mask.shape,attn_mask)
240
- cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
241
- # print(2333,cache)
242
- # prep attention mask
243
-
244
- attn_mask = _canonical_mask(
245
- mask=attn_mask,
246
- mask_name="attn_mask",
247
- other_type=None,
248
- other_name="",
249
- target_type=q.dtype,
250
- check_other=False,
251
- )
252
-
253
- if attn_mask is not None:
254
- # ensure attn_mask's dim is 3
255
- if attn_mask.dim() == 2:
256
- correct_2d_size = (tgt_len, src_len)
257
- if attn_mask.shape != correct_2d_size:
258
- raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
259
- attn_mask = attn_mask.unsqueeze(0)
260
- elif attn_mask.dim() == 3:
261
- correct_3d_size = (bsz * num_heads, tgt_len, src_len)
262
- if attn_mask.shape != correct_3d_size:
263
- raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
264
- else:
265
- raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
266
-
267
- # add bias along batch dimension (currently second)
268
- if bias_k is not None and bias_v is not None:
269
- assert static_k is None, "bias cannot be added to static key."
270
- assert static_v is None, "bias cannot be added to static value."
271
- k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
272
- v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
273
- if attn_mask is not None:
274
- attn_mask = pad(attn_mask, (0, 1))
275
- if key_padding_mask is not None:
276
- key_padding_mask = pad(key_padding_mask, (0, 1))
277
- else:
278
- assert bias_k is None
279
- assert bias_v is None
280
-
281
- #
282
- # reshape q, k, v for multihead attention and make em batch first
283
- #
284
- q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
285
- if static_k is None:
286
- k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
287
- else:
288
- # TODO finish disentangling control flow so we don't do in-projections when statics are passed
289
- assert static_k.size(0) == bsz * num_heads, \
290
- f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
291
- assert static_k.size(2) == head_dim, \
292
- f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
293
- k = static_k
294
- if static_v is None:
295
- v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
296
- else:
297
- # TODO finish disentangling control flow so we don't do in-projections when statics are passed
298
- assert static_v.size(0) == bsz * num_heads, \
299
- f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
300
- assert static_v.size(2) == head_dim, \
301
- f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
302
- v = static_v
303
-
304
- # add zero attention along batch dimension (now first)
305
- if add_zero_attn:
306
- zero_attn_shape = (bsz * num_heads, 1, head_dim)
307
- k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
308
- v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
309
- if attn_mask is not None:
310
- attn_mask = pad(attn_mask, (0, 1))
311
- if key_padding_mask is not None:
312
- key_padding_mask = pad(key_padding_mask, (0, 1))
313
-
314
- # update source sequence length after adjustments
315
- src_len = k.size(1)
316
-
317
- # merge key padding and attention masks
318
- if key_padding_mask is not None:
319
- assert key_padding_mask.shape == (bsz, src_len), \
320
- f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
321
- key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
322
- expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
323
- if attn_mask is None:
324
- attn_mask = key_padding_mask
325
- else:
326
- attn_mask = attn_mask + key_padding_mask
327
-
328
- # adjust dropout probability
329
- if not training:
330
- dropout_p = 0.0
331
-
332
- #
333
- # (deep breath) calculate attention and out projection
334
- #
335
-
336
- if need_weights:
337
- B, Nt, E = q.shape
338
- q_scaled = q / math.sqrt(E)
339
-
340
- assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
341
-
342
- if attn_mask is not None:
343
- attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
344
- else:
345
- attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
346
- attn_output_weights = softmax(attn_output_weights, dim=-1)
347
- if dropout_p > 0.0:
348
- attn_output_weights = dropout(attn_output_weights, p=dropout_p)
349
-
350
- attn_output = torch.bmm(attn_output_weights, v)
351
-
352
- attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
353
- attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
354
- attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
355
-
356
- # optionally average attention weights over heads
357
- attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
358
- if average_attn_weights:
359
- attn_output_weights = attn_output_weights.mean(dim=1)
360
-
361
- if not is_batched:
362
- # squeeze the output if input was unbatched
363
- attn_output = attn_output.squeeze(1)
364
- attn_output_weights = attn_output_weights.squeeze(0)
365
- return attn_output, attn_output_weights
366
- else:
367
- # attn_mask can be either (L,S) or (N*num_heads, L, S)
368
- # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
369
- # in order to match the input for SDPA of (N, num_heads, L, S)
370
- if attn_mask is not None:
371
- if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
372
- attn_mask = attn_mask.unsqueeze(0)
373
- else:
374
- attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
375
-
376
- q = q.view(bsz, num_heads, tgt_len, head_dim)
377
- k = k.view(bsz, num_heads, src_len, head_dim)
378
- v = v.view(bsz, num_heads, src_len, head_dim)
379
-
380
- attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
381
- attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
382
-
383
- attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
384
- attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
385
- if not is_batched:
386
- # squeeze the output if input was unbatched
387
- attn_output = attn_output.squeeze(1)
388
- return attn_output, None
 
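A minimal sketch, not part of the deleted file, of how the rolling per-layer k/v cache consumed by the patched attention above could be initialized and advanced during incremental decoding. The field names (first_infer, stage, all_stage, k, v) come from the code; the layer count and tensor shapes are illustrative assumptions.

import torch

def init_attn_cache(num_layers=12):
    # One k/v slot per transformer layer; "stage" cycles through the slots,
    # mirroring cache["stage"] = (cache["stage"] + 1) % cache["all_stage"] above.
    return {
        "all_stage": num_layers,
        "stage": 0,
        "first_infer": 1,          # 1 on the first forward pass, 0 afterwards
        "k": [None] * num_layers,  # cached keys, time-major (dim 0)
        "v": [None] * num_layers,  # cached values, time-major (dim 0)
    }

cache = init_attn_cache()
bsz, d_model = 1, 512
# First pass: the whole prefix's projected k/v is stored for the layer.
k0 = torch.randn(8, bsz, d_model)
cache["k"][0], cache["v"][0] = k0, k0.clone()
cache["first_infer"] = 0
# Later passes: only the newest token's k/v is concatenated along the time dim.
k_new = torch.randn(1, bsz, d_model)
cache["k"][0] = torch.cat([cache["k"][0], k_new], dim=0)
print(cache["k"][0].shape)  # torch.Size([9, 1, 512])
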
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/scaling.py DELETED
@@ -1,319 +0,0 @@
1
- # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
- #
3
- # See ../../../../LICENSE for clarification regarding multiple authors
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- import logging
17
- import math
18
- import random
19
- from typing import Optional
20
- from typing import Tuple
21
- from typing import Union
22
-
23
- import torch
24
- import torch.nn as nn
25
- from torch import Tensor
26
-
27
-
28
- class DoubleSwishFunction(torch.autograd.Function):
29
- """
30
- double_swish(x) = x * torch.sigmoid(x-1)
31
- This is a definition, originally motivated by its close numerical
32
- similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
33
-
34
- Memory-efficient derivative computation:
35
- double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
36
- double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
37
- Now, s'(x) = s(x) * (1-s(x)).
38
- double_swish'(x) = x * s'(x) + s(x).
39
- = x * s(x) * (1-s(x)) + s(x).
40
- = double_swish(x) * (1-s(x)) + s(x)
41
- ... so we just need to remember s(x) but not x itself.
42
- """
43
-
44
- @staticmethod
45
- def forward(ctx, x: Tensor) -> Tensor:
46
- requires_grad = x.requires_grad
47
- x_dtype = x.dtype
48
- if x.dtype == torch.float16:
49
- x = x.to(torch.float32)
50
-
51
- s = torch.sigmoid(x - 1.0)
52
- y = x * s
53
-
54
- if requires_grad:
55
- deriv = y * (1 - s) + s
56
- # notes on derivative of x * sigmoid(x - 1):
57
- # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
58
- # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound
59
- # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
60
- # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
61
- # floors), should be expectation-preserving.
62
- floor = -0.043637
63
- ceil = 1.2
64
- d_scaled = (deriv - floor) * (255.0 / (ceil - floor)
65
- ) + torch.rand_like(deriv)
66
- if __name__ == "__main__":
67
- # for self-testing only.
68
- assert d_scaled.min() >= 0.0
69
- assert d_scaled.max() < 256.0
70
- d_int = d_scaled.to(torch.uint8)
71
- ctx.save_for_backward(d_int)
72
- if x.dtype == torch.float16 or torch.is_autocast_enabled():
73
- y = y.to(torch.float16)
74
- return y
75
-
76
- @staticmethod
77
- def backward(ctx, y_grad: Tensor) -> Tensor:
78
- (d, ) = ctx.saved_tensors
79
- # the same constants as used in forward pass.
80
- floor = -0.043637
81
- ceil = 1.2
82
- d = d * ((ceil - floor) / 255.0) + floor
83
- return y_grad * d
84
-
85
-
86
- class DoubleSwish(torch.nn.Module):
87
- def forward(self, x: Tensor) -> Tensor:
88
- """Return double-swish activation function which is an approximation to Swish(Swish(x)),
89
- that we approximate closely with x * sigmoid(x-1).
90
- """
91
- if torch.jit.is_scripting() or torch.jit.is_tracing():
92
- return x * torch.sigmoid(x - 1.0)
93
- return DoubleSwishFunction.apply(x)
94
-
95
-
96
- class ActivationBalancerFunction(torch.autograd.Function):
97
- @staticmethod
98
- def forward(
99
- ctx,
100
- x: Tensor,
101
- scale_factor: Tensor,
102
- sign_factor: Optional[Tensor],
103
- channel_dim: int, ) -> Tensor:
104
- if channel_dim < 0:
105
- channel_dim += x.ndim
106
- ctx.channel_dim = channel_dim
107
- xgt0 = x > 0
108
- if sign_factor is None:
109
- ctx.save_for_backward(xgt0, scale_factor)
110
- else:
111
- ctx.save_for_backward(xgt0, scale_factor, sign_factor)
112
- return x
113
-
114
- @staticmethod
115
- def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
116
- if len(ctx.saved_tensors) == 3:
117
- xgt0, scale_factor, sign_factor = ctx.saved_tensors
118
- for _ in range(ctx.channel_dim, x_grad.ndim - 1):
119
- scale_factor = scale_factor.unsqueeze(-1)
120
- sign_factor = sign_factor.unsqueeze(-1)
121
- factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
122
- else:
123
- xgt0, scale_factor = ctx.saved_tensors
124
- for _ in range(ctx.channel_dim, x_grad.ndim - 1):
125
- scale_factor = scale_factor.unsqueeze(-1)
126
- factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
127
- neg_delta_grad = x_grad.abs() * factor
128
- return (x_grad - neg_delta_grad, None, None, None, )
129
-
130
-
131
- def _compute_scale_factor(
132
- x: Tensor,
133
- channel_dim: int,
134
- min_abs: float,
135
- max_abs: float,
136
- gain_factor: float,
137
- max_factor: float, ) -> Tensor:
138
- if channel_dim < 0:
139
- channel_dim += x.ndim
140
- sum_dims = [d for d in range(x.ndim) if d != channel_dim]
141
- x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
142
-
143
- if min_abs == 0.0:
144
- below_threshold = 0.0
145
- else:
146
- # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
147
- # x_abs_mean << min_abs.
148
- below_threshold = (
149
- (min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
150
- min=0, max=max_factor)
151
-
152
- above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
153
- min=0, max=max_factor)
154
-
155
- return below_threshold - above_threshold
156
-
157
-
158
- def _compute_sign_factor(
159
- x: Tensor,
160
- channel_dim: int,
161
- min_positive: float,
162
- max_positive: float,
163
- gain_factor: float,
164
- max_factor: float, ) -> Tensor:
165
- if channel_dim < 0:
166
- channel_dim += x.ndim
167
- sum_dims = [d for d in range(x.ndim) if d != channel_dim]
168
- proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
169
- if min_positive == 0.0:
170
- factor1 = 0.0
171
- else:
172
- # 0 if proportion_positive >= min_positive, else can be
173
- # as large as max_factor.
174
- factor1 = ((min_positive - proportion_positive) *
175
- (gain_factor / min_positive)).clamp_(
176
- min=0, max=max_factor)
177
-
178
- if max_positive == 1.0:
179
- factor2 = 0.0
180
- else:
181
- # 0 if self.proportion_positive <= max_positive, else can be
182
- # as large as -max_factor.
183
- factor2 = ((proportion_positive - max_positive) *
184
- (gain_factor / (1.0 - max_positive))).clamp_(
185
- min=0, max=max_factor)
186
- sign_factor = factor1 - factor2
187
- # require min_positive != 0 or max_positive != 1:
188
- assert not isinstance(sign_factor, float)
189
- return sign_factor
190
-
191
-
192
- class ActivationBalancer(torch.nn.Module):
193
- """
194
- Modifies the backpropped derivatives of a function to try to encourage, for
195
- each channel, that it is positive at least a proportion `threshold` of the
196
- time. It does this by multiplying negative derivative values by up to
197
- (1+max_factor), and positive derivative values by up to (1-max_factor),
198
- interpolated from 1 at the threshold to those extremal values when none
199
- of the inputs are positive.
200
-
201
- Args:
202
- num_channels: the number of channels
203
- channel_dim: the dimension/axis corresponding to the channel, e.g.
204
- -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
205
- min_positive: the minimum, per channel, of the proportion of the time
206
- that (x > 0), below which we start to modify the derivatives.
207
- max_positive: the maximum, per channel, of the proportion of the time
208
- that (x > 0), above which we start to modify the derivatives.
209
- max_factor: the maximum factor by which we modify the derivatives for
210
- either the sign constraint or the magnitude constraint;
211
- e.g. with max_factor=0.02, the derivatives would be multiplied by
212
- values in the range [0.98..1.02].
213
- sign_gain_factor: determines the 'gain' with which we increase the
214
- change in gradient once the constraints on min_positive and max_positive
215
- are violated.
216
- scale_gain_factor: determines the 'gain' with which we increase the
217
- change in gradient once the constraints on min_abs and max_abs
218
- are violated.
219
- min_abs: the minimum average-absolute-value difference from the mean
220
- value per channel, which we allow, before we start to modify
221
- the derivatives to prevent this.
222
- max_abs: the maximum average-absolute-value difference from the mean
223
- value per channel, which we allow, before we start to modify
224
- the derivatives to prevent this.
225
- min_prob: determines the minimum probability with which we modify the
226
- gradients for the {min,max}_positive and {min,max}_abs constraints,
227
- on each forward(). This is done randomly to prevent all layers
228
- from doing it at the same time. Early in training we may use
229
- higher probabilities than this; it will decay to this value.
230
- """
231
-
232
- def __init__(
233
- self,
234
- num_channels: int,
235
- channel_dim: int,
236
- min_positive: float=0.05,
237
- max_positive: float=0.95,
238
- max_factor: float=0.04,
239
- sign_gain_factor: float=0.01,
240
- scale_gain_factor: float=0.02,
241
- min_abs: float=0.2,
242
- max_abs: float=100.0,
243
- min_prob: float=0.1, ):
244
- super(ActivationBalancer, self).__init__()
245
- self.num_channels = num_channels
246
- self.channel_dim = channel_dim
247
- self.min_positive = min_positive
248
- self.max_positive = max_positive
249
- self.max_factor = max_factor
250
- self.min_abs = min_abs
251
- self.max_abs = max_abs
252
- self.min_prob = min_prob
253
- self.sign_gain_factor = sign_gain_factor
254
- self.scale_gain_factor = scale_gain_factor
255
-
256
- # count measures how many times the forward() function has been called.
257
- # We occasionally sync this to a tensor called `count`, that exists to
258
- # make sure it is synced to disk when we load and save the model.
259
- self.cpu_count = 0
260
- self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
261
-
262
- def forward(self, x: Tensor) -> Tensor:
263
- if (torch.jit.is_scripting() or not x.requires_grad or
264
- torch.jit.is_tracing()):
265
- return _no_op(x)
266
-
267
- count = self.cpu_count
268
- self.cpu_count += 1
269
-
270
- if random.random() < 0.01:
271
- # Occasionally sync self.cpu_count with self.count.
272
- # count affects the decay of 'prob'. don't do this on every iter,
273
- # because syncing with the GPU is slow.
274
- self.cpu_count = max(self.cpu_count, self.count.item())
275
- self.count.fill_(self.cpu_count)
276
-
277
- # the prob of doing some work exponentially decreases from 0.5 till it hits
278
- # a floor at min_prob (==0.1, by default)
279
- prob = max(self.min_prob, 0.5**(1 + (count / 4000.0)))
280
-
281
- if random.random() < prob:
282
- sign_gain_factor = 0.5
283
- if self.min_positive != 0.0 or self.max_positive != 1.0:
284
- sign_factor = _compute_sign_factor(
285
- x,
286
- self.channel_dim,
287
- self.min_positive,
288
- self.max_positive,
289
- gain_factor=self.sign_gain_factor / prob,
290
- max_factor=self.max_factor, )
291
- else:
292
- sign_factor = None
293
-
294
- scale_factor = _compute_scale_factor(
295
- x.detach(),
296
- self.channel_dim,
297
- min_abs=self.min_abs,
298
- max_abs=self.max_abs,
299
- gain_factor=self.scale_gain_factor / prob,
300
- max_factor=self.max_factor, )
301
- return ActivationBalancerFunction.apply(
302
- x,
303
- scale_factor,
304
- sign_factor,
305
- self.channel_dim, )
306
- else:
307
- return _no_op(x)
308
-
309
-
310
- def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0,
311
- min_prob=0.25) -> nn.Sequential:
312
- """
313
- ActivationBalancer -> DoubleSwish
314
- """
315
- balancer = ActivationBalancer(
316
- d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
317
- return nn.Sequential(
318
- balancer,
319
- DoubleSwish(), )
 
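A quick, self-contained sketch of the activation defined in scaling.py above: DoubleSwish computes x * sigmoid(x - 1), and BalancedDoubleSwish simply prepends an ActivationBalancer whose forward pass is the identity (it only reshapes gradients during training). The snippet reproduces the reference formula only and does not import the deleted module.

import torch

def double_swish(x):
    # Reference formula from DoubleSwishFunction / DoubleSwish.forward above.
    return x * torch.sigmoid(x - 1.0)

x = torch.randn(4, 10, 512, requires_grad=True)
y = double_swish(x)
y.sum().backward()
print(y.shape, x.grad.shape)  # both torch.Size([4, 10, 512])
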
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/modules/transformer.py DELETED
@@ -1,347 +0,0 @@
1
- # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
2
- import copy
3
- import numbers
4
- from functools import partial
5
- from typing import Any
6
- from typing import Callable
7
- from typing import List
8
- from typing import Optional
9
- from typing import Tuple
10
- from typing import Union
11
-
12
- import torch
13
- from AR.modules.activation import MultiheadAttention
14
- from AR.modules.scaling import BalancedDoubleSwish
15
- from torch import nn
16
- from torch import Tensor
17
- from torch.nn import functional as F
18
-
19
- _shape_t = Union[int, List[int], torch.Size]
20
-
21
-
22
- class LayerNorm(nn.Module):
23
- __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
24
- normalized_shape: Tuple[int, ...]
25
- eps: float
26
- elementwise_affine: bool
27
-
28
- def __init__(
29
- self,
30
- normalized_shape: _shape_t,
31
- eps: float=1e-5,
32
- elementwise_affine: bool=True,
33
- device=None,
34
- dtype=None, ) -> None:
35
- factory_kwargs = {"device": device, "dtype": dtype}
36
- super(LayerNorm, self).__init__()
37
- if isinstance(normalized_shape, numbers.Integral):
38
- # mypy error: incompatible types in assignment
39
- normalized_shape = (normalized_shape, ) # type: ignore[assignment]
40
- self.normalized_shape = tuple(
41
- normalized_shape) # type: ignore[arg-type]
42
- self.eps = eps
43
- self.elementwise_affine = elementwise_affine
44
- if self.elementwise_affine:
45
- self.weight = nn.Parameter(
46
- torch.empty(self.normalized_shape, **factory_kwargs))
47
- self.bias = nn.Parameter(
48
- torch.empty(self.normalized_shape, **factory_kwargs))
49
- else:
50
- self.register_parameter("weight", None)
51
- self.register_parameter("bias", None)
52
-
53
- self.reset_parameters()
54
-
55
- def reset_parameters(self) -> None:
56
- if self.elementwise_affine:
57
- nn.init.ones_(self.weight)
58
- nn.init.zeros_(self.bias)
59
-
60
- def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
61
- if isinstance(input, tuple):
62
- input, embedding = input
63
- return (F.layer_norm(
64
- input,
65
- self.normalized_shape,
66
- self.weight,
67
- self.bias,
68
- self.eps, ), embedding, )
69
-
70
- assert embedding is None
71
- return F.layer_norm(input, self.normalized_shape, self.weight,
72
- self.bias, self.eps)
73
-
74
- def extra_repr(self) -> str:
75
- return (
76
- "{normalized_shape}, eps={eps}, "
77
- "elementwise_affine={elementwise_affine}".format(**self.__dict__))
78
-
79
-
80
- class IdentityNorm(nn.Module):
81
- def __init__(
82
- self,
83
- d_model: int,
84
- eps: float=1e-5,
85
- device=None,
86
- dtype=None, ) -> None:
87
- super(IdentityNorm, self).__init__()
88
-
89
- def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
90
- if isinstance(input, tuple):
91
- return input
92
-
93
- assert embedding is None
94
- return input
95
-
96
-
97
- class TransformerEncoder(nn.Module):
98
- r"""TransformerEncoder is a stack of N encoder layers. Users can build the
99
- BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
100
-
101
- Args:
102
- encoder_layer: an instance of the TransformerEncoderLayer() class (required).
103
- num_layers: the number of sub-encoder-layers in the encoder (required).
104
- norm: the layer normalization component (optional).
105
- enable_nested_tensor: if True, input will automatically convert to nested tensor
106
- (and convert back on output). This will improve the overall performance of
107
- TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
108
-
109
- Examples::
110
- >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
111
- >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
112
- >>> src = torch.rand(10, 32, 512)
113
- >>> out = transformer_encoder(src)
114
- """
115
- __constants__ = ["norm"]
116
-
117
- def __init__(self, encoder_layer, num_layers, norm=None):
118
- super(TransformerEncoder, self).__init__()
119
- self.layers = _get_clones(encoder_layer, num_layers)
120
- self.num_layers = num_layers
121
- self.norm = norm
122
-
123
- def forward(
124
- self,
125
- src: Tensor,
126
- mask: Optional[Tensor]=None,
127
- src_key_padding_mask: Optional[Tensor]=None,
128
- return_layer_states: bool=False, cache=None) -> Tensor:
129
- r"""Pass the input through the encoder layers in turn.
130
-
131
- Args:
132
- src: the sequence to the encoder (required).
133
- mask: the mask for the src sequence (optional).
134
- src_key_padding_mask: the mask for the src keys per batch (optional).
135
- return_layer_states: return layers' state (optional).
136
-
137
- Shape:
138
- see the docs in Transformer class.
139
- """
140
- if return_layer_states:
141
- layer_states = [] # layers' output
142
- output = src
143
- for mod in self.layers:
144
- output = mod(
145
- output,
146
- src_mask=mask,
147
- src_key_padding_mask=src_key_padding_mask, cache=cache)
148
- layer_states.append(output[0])
149
-
150
- if self.norm is not None:
151
- output = self.norm(output)
152
-
153
- return layer_states, output
154
-
155
- output = src
156
- for mod in self.layers:
157
- output = mod(output,
158
- src_mask=mask,
159
- src_key_padding_mask=src_key_padding_mask, cache=cache)
160
-
161
- if self.norm is not None:
162
- output = self.norm(output)
163
-
164
- return output
165
-
166
-
167
- class TransformerEncoderLayer(nn.Module):
168
- __constants__ = ["batch_first", "norm_first"]
169
-
170
- def __init__(
171
- self,
172
- d_model: int,
173
- nhead: int,
174
- dim_feedforward: int=2048,
175
- dropout: float=0.1,
176
- activation: Union[str, Callable[[Tensor], Tensor]]=F.relu,
177
- batch_first: bool=False,
178
- norm_first: bool=False,
179
- device=None,
180
- dtype=None,
181
- linear1_self_attention_cls: nn.Module=nn.Linear,
182
- linear2_self_attention_cls: nn.Module=nn.Linear,
183
- linear1_feedforward_cls: nn.Module=nn.Linear,
184
- linear2_feedforward_cls: nn.Module=nn.Linear,
185
- layer_norm_cls: nn.Module=LayerNorm,
186
- layer_norm_eps: float=1e-5,
187
- adaptive_layer_norm=False, ) -> None:
188
- factory_kwargs = {"device": device, "dtype": dtype}
189
- super(TransformerEncoderLayer, self).__init__()
190
- # print(233333333333,d_model,nhead)
191
- # import os
192
- # os._exit(2333333)
193
- self.self_attn = MultiheadAttention(
194
- d_model,#512 16
195
- nhead,
196
- dropout=dropout,
197
- batch_first=batch_first,
198
- linear1_cls=linear1_self_attention_cls,
199
- linear2_cls=linear2_self_attention_cls,
200
- **factory_kwargs, )
201
-
202
- # Implementation of Feedforward model
203
- self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward,
204
- **factory_kwargs)
205
- self.dropout = nn.Dropout(dropout)
206
- self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model,
207
- **factory_kwargs)
208
-
209
- self.norm_first = norm_first
210
- self.dropout1 = nn.Dropout(dropout)
211
- self.dropout2 = nn.Dropout(dropout)
212
-
213
- # Legacy string support for activation function.
214
- if isinstance(activation, str):
215
- activation = _get_activation_fn(activation)
216
- elif isinstance(activation, partial):
217
- activation = activation(d_model)
218
- elif activation == BalancedDoubleSwish:
219
- activation = BalancedDoubleSwish(d_model)
220
-
221
- # # We can't test self.activation in forward() in TorchScript,
222
- # # so stash some information about it instead.
223
- # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
224
- # self.activation_relu_or_gelu = 1
225
- # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
226
- # self.activation_relu_or_gelu = 2
227
- # else:
228
- # self.activation_relu_or_gelu = 0
229
- self.activation = activation
230
-
231
- norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
232
- if layer_norm_cls == IdentityNorm:
233
- norm2 = BalancedBasicNorm(
234
- d_model, eps=layer_norm_eps, **factory_kwargs)
235
- else:
236
- norm2 = layer_norm_cls(
237
- d_model, eps=layer_norm_eps, **factory_kwargs)
238
-
239
- if adaptive_layer_norm:
240
- self.norm1 = AdaptiveLayerNorm(d_model, norm1)
241
- self.norm2 = AdaptiveLayerNorm(d_model, norm2)
242
- else:
243
- self.norm1 = norm1
244
- self.norm2 = norm2
245
-
246
- def __setstate__(self, state):
247
- super(TransformerEncoderLayer, self).__setstate__(state)
248
- if not hasattr(self, "activation"):
249
- self.activation = F.relu
250
-
251
- def forward(
252
- self,
253
- src: Tensor,
254
- src_mask: Optional[Tensor]=None,
255
- src_key_padding_mask: Optional[Tensor]=None, cache=None) -> Tensor:
256
- r"""Pass the input through the encoder layer.
257
-
258
- Args:
259
- src: the sequence to the encoder layer (required).
260
- src_mask: the mask for the src sequence (optional).
261
- src_key_padding_mask: the mask for the src keys per batch (optional).
262
-
263
- Shape:
264
- see the docs in Transformer class.
265
- """
266
- x, stage_embedding = src, None
267
- is_src_tuple = False
268
- if isinstance(src, tuple):
269
- x, stage_embedding = src
270
- is_src_tuple = True
271
-
272
- if src_key_padding_mask is not None:
273
- _skpm_dtype = src_key_padding_mask.dtype
274
- if _skpm_dtype != torch.bool and not torch.is_floating_point(
275
- src_key_padding_mask):
276
- raise AssertionError(
277
- "only bool and floating types of key_padding_mask are supported"
278
- )
279
-
280
- if self.norm_first:
281
- x = x + self._sa_block(
282
- self.norm1(x, stage_embedding),
283
- src_mask,
284
- src_key_padding_mask, cache=cache)
285
- x = x + self._ff_block(self.norm2(x, stage_embedding))
286
- else:
287
- x = self.norm1(
288
- x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
289
- stage_embedding, )
290
- x = self.norm2(x + self._ff_block(x), stage_embedding)
291
-
292
- if is_src_tuple:
293
- return (x, stage_embedding)
294
- return x
295
-
296
- # self-attention block
297
- def _sa_block(
298
- self,
299
- x: Tensor,
300
- attn_mask: Optional[Tensor],
301
- key_padding_mask: Optional[Tensor], cache=None) -> Tensor:
302
- # print(x.shape,attn_mask.shape,key_padding_mask)
303
- #torch.Size([1, 188, 512]) torch.Size([188, 188]) None
304
- # import os
305
- # os._exit(23333)
306
- x = self.self_attn(
307
- x,
308
- x,
309
- x,
310
- attn_mask=attn_mask,
311
- key_padding_mask=key_padding_mask,
312
- need_weights=False, cache=cache)[0]
313
- return self.dropout1(x)
314
-
315
- # feed forward block
316
- def _ff_block(self, x: Tensor) -> Tensor:
317
- x = self.linear2(self.dropout(self.activation(self.linear1(x))))
318
- return self.dropout2(x)
319
-
320
-
321
- class AdaptiveLayerNorm(nn.Module):
322
- r"""Adaptive Layer Normalization"""
323
-
324
- def __init__(self, d_model, norm) -> None:
325
- super(AdaptiveLayerNorm, self).__init__()
326
- self.project_layer = nn.Linear(d_model, 2 * d_model)
327
- self.norm = norm
328
- self.d_model = d_model
329
- self.eps = self.norm.eps
330
-
331
- def forward(self, input: Tensor, embedding: Tensor=None) -> Tensor:
332
- if isinstance(input, tuple):
333
- input, embedding = input
334
- weight, bias = torch.split(
335
- self.project_layer(embedding),
336
- split_size_or_sections=self.d_model,
337
- dim=-1, )
338
- return (weight * self.norm(input) + bias, embedding)
339
-
340
- weight, bias = torch.split(
341
- self.project_layer(embedding),
342
- split_size_or_sections=self.d_model,
343
- dim=-1, )
344
- return weight * self.norm(input) + bias
345
-
346
- def _get_clones(module, N):
347
- return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
 
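The TransformerEncoder / TransformerEncoderLayer pair above mirrors the stock torch.nn API while adding a cache argument and pluggable norm/linear classes. A self-contained analogue with the sizes from configs/s1.yaml further down (hidden_dim=512, head=16, linear_units=2048, n_layer=12), using the standard PyTorch classes purely for illustration rather than the deleted ones:

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=512, nhead=16, dim_feedforward=2048, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=12)

x = torch.randn(1, 188, 512)  # (batch, seq, d_model), matching the debug shape noted in _sa_block
out = encoder(x)
print(out.shape)              # torch.Size([1, 188, 512])
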
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/text_processing/__init__.py DELETED
File without changes
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/text_processing/phonemizer.py DELETED
@@ -1,80 +0,0 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/phonemizer.py
2
- import itertools
3
- import re
4
- from typing import Dict
5
- from typing import List
6
-
7
- import regex
8
- from gruut import sentences
9
- from gruut.const import Sentence
10
- from gruut.const import Word
11
- from AR.text_processing.symbols import SYMBOL_TO_ID
12
-
13
-
14
- class GruutPhonemizer:
15
- def __init__(self, language: str):
16
- self._phonemizer = sentences
17
- self.lang = language
18
- self.symbol_to_id = SYMBOL_TO_ID
19
- self._special_cases_dict: Dict[str, str] = {
20
- r"\.\.\.": "... ",
21
- ";": "; ",
22
- ":": ": ",
23
- ",": ", ",
24
- r"\.": ". ",
25
- "!": "! ",
26
- r"\?": "? ",
27
- "—": "—",
28
- "…": "… ",
29
- "«": "«",
30
- "»": "»"
31
- }
32
- self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
33
-
34
- def _normalize_punctuation(self, text: str) -> str:
35
- text = regex.sub(fr"\pZ+{self._punctuation_regexp}", r"\1", text)
36
- text = regex.sub(fr"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
37
- text = regex.sub(r"\pZ+", r" ", text)
38
- return text.strip()
39
-
40
- def _convert_punctuation(self, word: Word) -> str:
41
- if not word.phonemes:
42
- return ''
43
- if word.phonemes[0] in ['‖', '|']:
44
- return word.text.strip()
45
-
46
- phonemes = ''.join(word.phonemes)
47
- # remove modifier characters ˈˌː with regex
48
- phonemes = re.sub(r'[ˈˌː͡]', '', phonemes)
49
- return phonemes.strip()
50
-
51
- def phonemize(self, text: str, espeak: bool=False) -> str:
52
- text_to_phonemize: str = self._normalize_punctuation(text)
53
- sents: List[Sentence] = [
54
- sent
55
- for sent in self._phonemizer(
56
- text_to_phonemize, lang="en-us", espeak=espeak)
57
- ]
58
- words: List[str] = [
59
- self._convert_punctuation(word) for word in itertools.chain(*sents)
60
- ]
61
- return ' '.join(words)
62
-
63
- def transform(self, phonemes):
64
- # convert phonemes to ids
65
- # dictionary is in symbols.py
66
- return [
67
- self.symbol_to_id[p] for p in phonemes
68
- if p in self.symbol_to_id.keys()
69
- ]
70
-
71
-
72
- if __name__ == "__main__":
73
- phonemizer = GruutPhonemizer("en-us")
74
- # text -> IPA
75
- phonemes = phonemizer.phonemize("Hello, wor-ld ?")
76
- print("phonemes:", phonemes)
77
- print("len(phonemes):", len(phonemes))
78
- phoneme_ids = phonemizer.transform(phonemes)
79
- print("phoneme_ids:", phoneme_ids)
80
- print("len(phoneme_ids):", len(phoneme_ids))
 
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/text_processing/symbols.py DELETED
@@ -1,9 +0,0 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py
2
- PAD = '_'
3
- PUNCTUATION = ';:,.!?¡¿—…"«»“” '
4
- LETTERS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
5
- IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
6
- SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
7
- SPACE_ID = SYMBOLS.index(" ")
8
- SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
9
- ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)}
 
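A short round-trip sketch of the lookup tables defined in symbols.py above; the symbol list is truncated here for brevity (the deleted file also includes the IPA letters).

PAD = '_'
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS)
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)}

ids = [SYMBOL_TO_ID[c] for c in "Hello, world!" if c in SYMBOL_TO_ID]
print(ids)
print(''.join(ID_TO_SYMBOL[i] for i in ids))  # Hello, world!
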
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/utils/__init__.py DELETED
@@ -1,37 +0,0 @@
1
- import re
2
-
3
-
4
- def str2bool(str):
5
- return True if str.lower() == 'true' else False
6
-
7
-
8
- def get_newest_ckpt(string_list):
9
- # Regex pattern that captures the epoch and step numbers in a checkpoint filename
10
- pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'
11
-
12
- # Extract the numbers from every filename with the regex and collect them as tuples
13
- extracted_info = []
14
- for string in string_list:
15
- match = re.match(pattern, string)
16
- if match:
17
- epoch = int(match.group(1))
18
- step = int(match.group(2))
19
- extracted_info.append((epoch, step, string))
20
- # Sort by the epoch number first, then by the step number
21
- sorted_info = sorted(
22
- extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
23
- # Take the newest ckpt filename
24
- newest_ckpt = sorted_info[0][2]
25
- return newest_ckpt
26
-
27
-
28
- # Returns the text when the file exists and is not empty
29
- def check_txt_file(file_path):
30
- try:
31
- with open(file_path, 'r') as file:
32
- text = file.readline().strip()
33
- assert text.strip() != ''
34
- return text
35
- except Exception:
36
- return False
37
- return False
 
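Usage sketch for get_newest_ckpt above: it expects checkpoint names of the form epoch=E-step=S.ckpt and returns the one with the highest (epoch, step) pair. A stand-alone illustration of the same regex, with made-up filenames:

import re

pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'  # same pattern as in get_newest_ckpt
names = ["epoch=3-step=1200.ckpt", "epoch=12-step=800.ckpt", "epoch=12-step=4000.ckpt"]
parsed = [(int(m.group(1)), int(m.group(2)), n)
          for n in names if (m := re.match(pattern, n))]
print(max(parsed)[2])  # epoch=12-step=4000.ckpt
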
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/utils/initialize.py DELETED
@@ -1,38 +0,0 @@
1
- #!/usr/bin/env python3
2
- """Initialize modules for espnet2 neural networks."""
3
- import torch
4
- from typeguard import check_argument_types
5
-
6
-
7
- def initialize(model: torch.nn.Module, init: str):
8
- """Initialize weights of a neural network module.
9
-
10
- Parameters are initialized using the given method or distribution.
11
-
12
- Custom initialization routines can be implemented into submodules
13
- as function `espnet_initialization_fn` within the custom module.
14
-
15
- Args:
16
- model: Target.
17
- init: Method of initialization.
18
- """
19
- assert check_argument_types()
20
- print("init with", init)
21
-
22
- # weight init
23
- for p in model.parameters():
24
- if p.dim() > 1:
25
- if init == "xavier_uniform":
26
- torch.nn.init.xavier_uniform_(p.data)
27
- elif init == "xavier_normal":
28
- torch.nn.init.xavier_normal_(p.data)
29
- elif init == "kaiming_uniform":
30
- torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
31
- elif init == "kaiming_normal":
32
- torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
33
- else:
34
- raise ValueError("Unknown initialization: " + init)
35
- # bias init
36
- for name, p in model.named_parameters():
37
- if ".bias" in name and p.dim() == 1:
38
- p.data.zero_()
 
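A self-contained equivalent of calling initialize(model, "xavier_uniform") from the file above, applied to a toy model (the model itself is only an illustration):

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))

# Weight init: every parameter with more than one dimension gets Xavier-uniform.
for p in model.parameters():
    if p.dim() > 1:
        torch.nn.init.xavier_uniform_(p.data)

# Bias init: 1-D parameters whose name contains ".bias" are zeroed.
for name, p in model.named_parameters():
    if ".bias" in name and p.dim() == 1:
        p.data.zero_()
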
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/AR/utils/io.py DELETED
@@ -1,32 +0,0 @@
1
- import sys
2
-
3
- import torch
4
- import yaml
5
-
6
-
7
- def load_yaml_config(path):
8
- with open(path) as f:
9
- config = yaml.full_load(f)
10
- return config
11
-
12
-
13
- def save_config_to_yaml(config, path):
14
- assert path.endswith('.yaml')
15
- with open(path, 'w') as f:
16
- f.write(yaml.dump(config))
17
- f.close()
18
-
19
-
20
- def write_args(args, path):
21
- args_dict = dict((name, getattr(args, name)) for name in dir(args)
22
- if not name.startswith('_'))
23
- with open(path, 'a') as args_file:
24
- args_file.write('==> torch version: {}\n'.format(torch.__version__))
25
- args_file.write(
26
- '==> cudnn version: {}\n'.format(torch.backends.cudnn.version()))
27
- args_file.write('==> Cmd:\n')
28
- args_file.write(str(sys.argv))
29
- args_file.write('\n==> args:\n')
30
- for k, v in sorted(args_dict.items()):
31
- args_file.write(' %s: %s\n' % (str(k), str(v)))
32
- args_file.close()
 
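A round-trip sketch of the two YAML helpers above (save_config_to_yaml / load_yaml_config); it assumes PyYAML is installed and writes to a temporary file rather than a real config path:

import os
import tempfile

import yaml

config = {"train": {"seed": 1234, "epochs": 300}}
path = os.path.join(tempfile.mkdtemp(), "s1.yaml")

with open(path, 'w') as f:    # mirrors save_config_to_yaml
    f.write(yaml.dump(config))

with open(path) as f:         # mirrors load_yaml_config
    print(yaml.full_load(f))  # {'train': {'seed': 1234, 'epochs': 300}}
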
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/configs/s1.yaml DELETED
@@ -1,31 +0,0 @@
1
- train:
2
- seed: 1234
3
- epochs: 300
4
- batch_size: 8
5
- gradient_accumulation: 4
6
- save_every_n_epoch: 1
7
- precision: 16
8
- gradient_clip: 1.0
9
- optimizer:
10
- lr: 0.01
11
- lr_init: 0.00001
12
- lr_end: 0.0001
13
- warmup_steps: 2000
14
- decay_steps: 40000
15
- data:
16
- max_eval_sample: 8
17
- max_sec: 54
18
- num_workers: 1
19
- pad_val: 1024 # same as the EOS id in the model
20
- model:
21
- vocab_size: 1025
22
- phoneme_vocab_size: 512
23
- embedding_dim: 512
24
- hidden_dim: 512
25
- head: 16
26
- linear_units: 2048
27
- n_layer: 12
28
- dropout: 0
29
- EOS: 1024
30
- inference:
31
- top_k: 5
 
GPT-SoVITS-models/GPT-SoVITS/GPT_SoVITS/configs/s1big.yaml DELETED
@@ -1,31 +0,0 @@
1
- train:
2
- seed: 1234
3
- epochs: 300
4
- batch_size: 8
5
- gradient_accumulation: 4
6
- save_every_n_epoch: 1
7
- precision: 16-mixed
8
- gradient_clip: 1.0
9
- optimizer:
10
- lr: 0.01
11
- lr_init: 0.00001
12
- lr_end: 0.0001
13
- warmup_steps: 2000
14
- decay_steps: 40000
15
- data:
16
- max_eval_sample: 8
17
- max_sec: 54
18
- num_workers: 1
19
- pad_val: 1024 # same as the EOS id in the model
20
- model:
21
- vocab_size: 1025
22
- phoneme_vocab_size: 512
23
- embedding_dim: 1024
24
- hidden_dim: 1024
25
- head: 16
26
- linear_units: 2048
27
- n_layer: 16
28
- dropout: 0
29
- EOS: 1024
30
- inference:
31
- top_k: 5
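
For quick reference, the only fields where s1big.yaml differs from s1.yaml above (all values taken directly from the two configs):

differences = {
    "train.precision":     ("16", "16-mixed"),
    "model.embedding_dim": (512, 1024),
    "model.hidden_dim":    (512, 1024),
    "model.n_layer":       (12, 16),
}
for key, (s1, s1big) in differences.items():
    print(f"{key}: s1={s1}  s1big={s1big}")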