josephchay commited on
Commit
70c3139
·
1 Parent(s): 7849812

Add remaining files

Browse files
Files changed (48) hide show
  1. .gitignore +3 -0
  2. LangSegment/LangSegment.py +1068 -0
  3. LangSegment/__init__.py +9 -0
  4. LangSegment/utils/__init__.py +0 -0
  5. LangSegment/utils/num.py +327 -0
  6. README.md +17 -0
  7. app.py +430 -0
  8. packages.txt +1 -0
  9. pretrained/eval.py +66 -0
  10. pretrained/eval.safetensors +3 -0
  11. pretrained/eval.yaml +6 -0
  12. requirements.txt +35 -0
  13. soundsation/config/config.json +13 -0
  14. soundsation/config/defaults.ini +94 -0
  15. soundsation/g2p/g2p/__init__.py +87 -0
  16. soundsation/g2p/g2p/chinese_model_g2p.py +213 -0
  17. soundsation/g2p/g2p/cleaners.py +31 -0
  18. soundsation/g2p/g2p/english.py +202 -0
  19. soundsation/g2p/g2p/french.py +149 -0
  20. soundsation/g2p/g2p/german.py +94 -0
  21. soundsation/g2p/g2p/japanese.py +816 -0
  22. soundsation/g2p/g2p/korean.py +81 -0
  23. soundsation/g2p/g2p/mandarin.py +600 -0
  24. soundsation/g2p/g2p/text_tokenizers.py +84 -0
  25. soundsation/g2p/g2p/vocab.json +372 -0
  26. soundsation/g2p/g2p_generation.py +133 -0
  27. soundsation/g2p/sources/bpmf_2_pinyin.txt +41 -0
  28. soundsation/g2p/sources/chinese_lexicon.txt +3 -0
  29. soundsation/g2p/sources/g2p_chinese_model/config.json +819 -0
  30. soundsation/g2p/sources/g2p_chinese_model/poly_bert_model.onnx +3 -0
  31. soundsation/g2p/sources/g2p_chinese_model/polychar.txt +159 -0
  32. soundsation/g2p/sources/g2p_chinese_model/polydict.json +393 -0
  33. soundsation/g2p/sources/g2p_chinese_model/polydict_r.json +393 -0
  34. soundsation/g2p/sources/g2p_chinese_model/vocab.txt +0 -0
  35. soundsation/g2p/sources/pinyin_2_bpmf.txt +429 -0
  36. soundsation/g2p/utils/front_utils.py +20 -0
  37. soundsation/g2p/utils/g2p.py +139 -0
  38. soundsation/g2p/utils/log.py +52 -0
  39. soundsation/g2p/utils/mls_en.json +335 -0
  40. soundsation/infer/infer.py +229 -0
  41. soundsation/infer/infer_utils.py +498 -0
  42. soundsation/model/__init__.py +6 -0
  43. soundsation/model/cfm.py +324 -0
  44. soundsation/model/dit.py +221 -0
  45. soundsation/model/modules.py +652 -0
  46. soundsation/model/trainer.py +350 -0
  47. soundsation/model/utils.py +182 -0
  48. src/negative_prompt.npy +3 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .venv/
2
+ .idea/
3
+ __pycache__/
LangSegment/LangSegment.py ADDED
@@ -0,0 +1,1068 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file bundles language identification functions.
3
+
4
+ Modifications (fork): Copyright (c) 2021, Adrien Barbaresi.
5
+
6
+ Original code: Copyright (c) 2011 Marco Lui <[email protected]>.
7
+ Based on research by Marco Lui and Tim Baldwin.
8
+
9
+ See LICENSE file for more info.
10
+ https://github.com/adbar/py3langid
11
+
12
+ Projects:
13
+ https://github.com/juntaosun/LangSegment
14
+ """
15
+
16
+ import os
17
+ import re
18
+ import sys
19
+ import numpy as np
20
+ from collections import Counter
21
+ from collections import defaultdict
22
+
23
+ # import langid
24
+ # import py3langid as langid
25
+ # pip install py3langid==0.2.2
26
+
27
+ # 启用语言预测概率归一化,概率预测的分数。因此,实现重新规范化 产生 0-1 范围内的输出。
28
+ # langid disables probability normalization by default. For command-line usages of , it can be enabled by passing the flag.
29
+ # For probability normalization in library use, the user must instantiate their own . An example of such usage is as follows:
30
+ from py3langid.langid import LanguageIdentifier, MODEL_FILE
31
+ langid = LanguageIdentifier.from_pickled_model(MODEL_FILE, norm_probs=True)
32
+
33
+ # Digital processing
34
+ try:from LangSegment.utils.num import num2str
35
+ except ImportError:
36
+ try:from utils.num import num2str
37
+ except ImportError as e:
38
+ raise e
39
+
40
+ # -----------------------------------
41
+ # 更新日志:新版本分词更加精准。
42
+ # Changelog: The new version of the word segmentation is more accurate.
43
+ # チェンジログ:新しいバージョンの単語セグメンテーションはより正確です。
44
+ # Changelog: 분할이라는 단어의 새로운 버전이 더 정확합니다.
45
+ # -----------------------------------
46
+
47
+
48
+ # Word segmentation function:
49
+ # automatically identify and split the words (Chinese/English/Japanese/Korean) in the article or sentence according to different languages,
50
+ # making it more suitable for TTS processing.
51
+ # This code is designed for front-end text multi-lingual mixed annotation distinction, multi-language mixed training and inference of various TTS projects.
52
+ # This processing result is mainly for (Chinese = zh, Japanese = ja, English = en, Korean = ko), and can actually support up to 97 different language mixing processing.
53
+
54
+ #===========================================================================================================
55
+ #分かち書き機能:文章や文章の中の例えば(中国語/英語/日本語/韓国語)を、異なる言語で自動的に認識して分割し、TTS処理により適したものにします。
56
+ #このコードは、さまざまなTTSプロジェクトのフロントエンドテキストの多言語混合注釈区別、多言語混合トレーニング、および推論のために特別に作成されています。
57
+ #===========================================================================================================
58
+ #(1)自動分詞:「韓国語では何を読むのですかあなたの体育の先生は誰ですか?今回の発表会では、iPhone 15シリーズの4機種が登場しました」
59
+ #(2)手动分词:“あなたの名前は<ja>佐々木ですか?<ja>ですか?”
60
+ #この処理結果は主に(中国語=ja、日本語=ja、英語=en、韓国語=ko)を対象としており、実際には最大97の異なる言語の混合処理をサポートできます。
61
+ #===========================================================================================================
62
+
63
+ #===========================================================================================================
64
+ # 단어 분할 기능: 기사 또는 문장에서 단어(중국어/영어/일본어/한국어)를 다른 언어에 따라 자동으로 식별하고 분할하여 TTS 처리에 더 적합합니다.
65
+ # 이 코드는 프런트 엔드 텍스트 다국어 혼합 주석 분화, 다국어 혼합 교육 및 다양한 TTS 프로젝트의 추론을 위해 설계되었습니다.
66
+ #===========================================================================================================
67
+ # (1) 자동 단어 분할: "한국어로 무엇을 읽습니까? 스포츠 씨? 이 컨퍼런스는 4개의 iPhone 15 시리즈 모델을 제공합니다."
68
+ # (2) 수동 참여: "이름이 <ja>Saki입니까? <ja>?"
69
+ # 이 처리 결과는 주로 (중국어 = zh, 일본어 = ja, 영어 = en, 한국어 = ko)를 위한 것이며 실제로 혼합 처리를 위해 최대 97개의 언어를 지원합니다.
70
+ #===========================================================================================================
71
+
72
+ # ===========================================================================================================
73
+ # 分词功能:将文章或句子里的例如(中/英/日/韩),按不同语言自动识别并拆分,让它更适合TTS处理。
74
+ # 本代码专为各种 TTS 项目的前端文本多语种混合标注区分,多语言混合训练和推理而编写。
75
+ # ===========================================================================================================
76
+ # (1)自动分词:“韩语中的오빠���什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型”
77
+ # (2)手动分词:“你的名字叫<ja>佐々木?<ja>吗?”
78
+ # 本处理结果主要针对(中文=zh , 日文=ja , 英文=en , 韩语=ko), 实际上可支持多达 97 种不同的语言混合处理。
79
+ # ===========================================================================================================
80
+
81
+
82
+ # 手动分词标签规范:<语言标签>文本内容</语言标签>
83
+ # 수동 단어 분할 태그 사양: <언어 태그> 텍스트 내용</언어 태그>
84
+ # Manual word segmentation tag specification: <language tags> text content </language tags>
85
+ # 手動分詞タグ仕様:<言語タグ>テキスト内容</言語タグ>
86
+ # ===========================================================================================================
87
+ # For manual word segmentation, labels need to appear in pairs, such as:
88
+ # 如需手动分词,标签需要成对出现,例如:“<ja>佐々木<ja>” 或者 “<ja>佐々木</ja>”
89
+ # 错误示范:“你的名字叫<ja>佐々木。” 此句子中出现的单个<ja>标签将被忽略,不会处理。
90
+ # Error demonstration: "Your name is <ja>佐々木。" Single <ja> tags that appear in this sentence will be ignored and will not be processed.
91
+ # ===========================================================================================================
92
+
93
+
94
+ # ===========================================================================================================
95
+ # 语音合成标记语言 SSML , 这里只支持它的标签(非 XML)Speech Synthesis Markup Language SSML, only its tags are supported here (not XML)
96
+ # 想支持更多的 SSML 标签?欢迎 PR! Want to support more SSML tags? PRs are welcome!
97
+ # 说明:除了中文以外,它也可改造成支持多语种 SSML ,不仅仅是中文。
98
+ # Note: In addition to Chinese, it can also be modified to support multi-language SSML, not just Chinese.
99
+ # ===========================================================================================================
100
+ # 中文实现:Chinese implementation:
101
+ # 【SSML】<number>=中文大写数字读法(单字)
102
+ # 【SSML】<telephone>=数字转成中文电话号码大写汉字(单字)
103
+ # 【SSML】<currency>=按金额发音。
104
+ # 【SSML】<date>=按日期发音。支持 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 等输入。
105
+ # ===========================================================================================================
106
+ class LangSSML:
107
+
108
+ # 纯数字
109
+ _zh_numerals_number = {
110
+ '0': '零',
111
+ '1': '一',
112
+ '2': '二',
113
+ '3': '三',
114
+ '4': '四',
115
+ '5': '五',
116
+ '6': '六',
117
+ '7': '七',
118
+ '8': '八',
119
+ '9': '九'
120
+ }
121
+
122
+
123
+ # 将2024/8/24, 2024-08, 08-24, 24 标准化“年月日”
124
+ # Standardize 2024/8/24, 2024-08, 08-24, 24 to "year-month-day"
125
+ def _format_chinese_data(date_str:str):
126
+ # 处理日期格式
127
+ input_date = date_str
128
+ if date_str is None or date_str.strip() == "":return ""
129
+ date_str = re.sub(r"[\/\._|年|月]","-",date_str)
130
+ date_str = re.sub(r"日",r"",date_str)
131
+ date_arrs = date_str.split(' ')
132
+ if len(date_arrs) == 1 and ":" in date_arrs[0]:
133
+ time_str = date_arrs[0]
134
+ date_arrs = []
135
+ else:
136
+ time_str = date_arrs[1] if len(date_arrs) >=2 else ""
137
+ def nonZero(num,cn,func=None):
138
+ if func is not None:num=func(num)
139
+ return f"{num}{cn}" if num is not None and num != "" and num != "0" else ""
140
+ f_number = LangSSML.to_chinese_number
141
+ f_currency = LangSSML.to_chinese_currency
142
+ # year, month, day
143
+ year_month_day = ""
144
+ if len(date_arrs) > 0:
145
+ year, month, day = "","",""
146
+ parts = date_arrs[0].split('-')
147
+ if len(parts) == 3: # 格式为 YYYY-MM-DD
148
+ year, month, day = parts
149
+ elif len(parts) == 2: # 格式为 MM-DD 或 YYYY-MM
150
+ if len(parts[0]) == 4: # 年-月
151
+ year, month = parts
152
+ else:month, day = parts # 月-日
153
+ elif len(parts[0]) > 0: # 仅有月-日或年
154
+ if len(parts[0]) == 4:
155
+ year = parts[0]
156
+ else:day = parts[0]
157
+ year,month,day = nonZero(year,"年",f_number),nonZero(month,"月",f_currency),nonZero(day,"日",f_currency)
158
+ year_month_day = re.sub(r"([年|月|日])+",r"\1",f"{year}{month}{day}")
159
+ # hours, minutes, seconds
160
+ time_str = re.sub(r"[\/\.\-:_]",":",time_str)
161
+ time_arrs = time_str.split(":")
162
+ hours, minutes, seconds = "","",""
163
+ if len(time_arrs) == 3: # H/M/S
164
+ hours, minutes, seconds = time_arrs
165
+ elif len(time_arrs) == 2:# H/M
166
+ hours, minutes = time_arrs
167
+ elif len(time_arrs[0]) > 0:hours = f'{time_arrs[0]}点' # H
168
+ if len(time_arrs) > 1:
169
+ hours, minutes, seconds = nonZero(hours,"点",f_currency),nonZero(minutes,"分",f_currency),nonZero(seconds,"秒",f_currency)
170
+ hours_minutes_seconds = re.sub(r"([点|分|秒])+",r"\1",f"{hours}{minutes}{seconds}")
171
+ output_date = f"{year_month_day}{hours_minutes_seconds}"
172
+ return output_date
173
+
174
+ # 【SSML】number=中文大写数字读法(单字)
175
+ # Chinese Numbers(single word)
176
+ def to_chinese_number(num:str):
177
+ pattern = r'(\d+)'
178
+ zh_numerals = LangSSML._zh_numerals_number
179
+ arrs = re.split(pattern, num)
180
+ output = ""
181
+ for item in arrs:
182
+ if re.match(pattern,item):
183
+ output += ''.join(zh_numerals[digit] if digit in zh_numerals else "" for digit in str(item))
184
+ else:output += item
185
+ output = output.replace(".","点")
186
+ return output
187
+
188
+ # 【SSML】telephone=数字转成中文电话号码大写汉字(单字)
189
+ # Convert numbers to Chinese phone numbers in uppercase Chinese characters(single word)
190
+ def to_chinese_telephone(num:str):
191
+ output = LangSSML.to_chinese_number(num.replace("+86","")) # zh +86
192
+ output = output.replace("一","幺")
193
+ return output
194
+
195
+ # 【SSML】currency=按金额发音。
196
+ # Digital processing from GPT_SoVITS num.py (thanks)
197
+ def to_chinese_currency(num:str):
198
+ pattern = r'(\d+)'
199
+ arrs = re.split(pattern, num)
200
+ output = ""
201
+ for item in arrs:
202
+ if re.match(pattern,item):
203
+ output += num2str(item)
204
+ else:output += item
205
+ output = output.replace(".","点")
206
+ return output
207
+
208
+ # 【SSML】date=按日期发音。支持 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 等输入。
209
+ def to_chinese_date(num:str):
210
+ chinese_date = LangSSML._format_chinese_data(num)
211
+ return chinese_date
212
+
213
+
214
+
215
+
216
+ class LangSegment():
217
+
218
+ _text_cache = None
219
+ _text_lasts = None
220
+ _text_langs = None
221
+ _lang_count = None
222
+ _lang_eos = None
223
+
224
+ # 可自定义语言匹配标签:カスタマイズ可能な言語対応タグ:사용자 지정 가능한 언어 일치 태그:
225
+ # Customizable language matching tags: These are supported,이 표현들은 모두 지지합니다
226
+ # <zh>你好<zh> , <ja>佐々木</ja> , <en>OK<en> , <ko>오빠</ko> 这些写法均支持
227
+ SYMBOLS_PATTERN = r'(<([a-zA-Z|-]*)>(.*?)<\/*[a-zA-Z|-]*>)'
228
+
229
+ # 语言过滤组功能, 可以指定保留语言。不在过滤组中的语言将被清除。您可随心搭配TTS语音合成所支持的语言。
230
+ # 언어 필터 그룹 기능을 사용하면 예약된 언어를 지정할 수 있습니다. 필터 그룹에 없는 언어는 지워집니다. TTS 텍스트에서 지원하는 언어를 원하는 대로 일치시킬 수 있습니다.
231
+ # 言語フィルターグループ機能では、予約言語を指定できます。フィルターグループに含まれていない言語はクリアされます。TTS音声合成がサポートする言語を自由に組み合わせることができます。
232
+ # The language filter group function allows you to specify reserved languages.
233
+ # Languages not in the filter group will be cleared. You can match the languages supported by TTS Text To Speech as you like.
234
+ # 排名越前,优先级越高,The higher the ranking, the higher the priority,ランキングが上位になるほど、優先度が高くなります。
235
+
236
+ # 系统默认过滤器。System default filter。(ISO 639-1 codes given)
237
+ # ----------------------------------------------------------------------------------------------------------------------------------
238
+ # "zh"中文=Chinese ,"en"英语=English ,"ja"日语=Japanese ,"ko"韩语=Korean ,"fr"法语=French ,"vi"越南语=Vietnamese , "ru"俄语=Russian
239
+ # "th"泰语=Thai
240
+ # ----------------------------------------------------------------------------------------------------------------------------------
241
+ DEFAULT_FILTERS = ["zh", "ja", "ko", "en"]
242
+
243
+ # 用户可自定义过滤器。User-defined filters
244
+ Langfilters = DEFAULT_FILTERS[:] # 创建副本
245
+
246
+ # 合并文本
247
+ isLangMerge = True
248
+
249
+ # 试验性支持:您可自定义添加:"fr"法语 , "vi"越南语。Experimental: You can customize to add: "fr" French, "vi" Vietnamese.
250
+ # 请使用API启用:LangSegment.setfilters(["zh", "en", "ja", "ko", "fr", "vi" , "ru" , "th"]) # 您可自定义添加,如:"fr"法语 , "vi"越南语。
251
+
252
+ # 预览版功能,自动启用或禁用,无需设置
253
+ # Preview feature, automatically enabled or disabled, no settings required
254
+ EnablePreview = False
255
+
256
+ # 除此以外,它支持简写过滤器,只需按不同语种任意组合即可。
257
+ # In addition to that, it supports abbreviation filters, allowing for any combination of different languages.
258
+ # 示例:您可以任意指定多种组���,进行过滤
259
+ # Example: You can specify any combination to filter
260
+
261
+ # 中/日语言优先级阀值(评分范围为 0 ~ 1):评分低于设定阀值 <0.89 时,启用 filters 中的优先级。\n
262
+ # 중/일본어 우선 순위 임계값(점수 범위 0-1): 점수가 설정된 임계값 <0.89보다 낮을 때 필터에서 우선 순위를 활성화합니다.
263
+ # 中国語/日本語の優先度しきい値(スコア範囲0〜1):スコアが設定されたしきい値<0.89未満の場合、フィルターの優先度が有効になります。\n
264
+ # Chinese and Japanese language priority threshold (score range is 0 ~ 1): The default threshold is 0.89. \n
265
+ # Only the common characters between Chinese and Japanese are processed with confidence and priority. \n
266
+ LangPriorityThreshold = 0.89
267
+
268
+ # Langfilters = ["zh"] # 按中文识别
269
+ # Langfilters = ["en"] # 按英文识别
270
+ # Langfilters = ["ja"] # 按日文识别
271
+ # Langfilters = ["ko"] # 按韩文识别
272
+ # Langfilters = ["zh_ja"] # 中日混合识别
273
+ # Langfilters = ["zh_en"] # 中英混合识别
274
+ # Langfilters = ["ja_en"] # 日英混合识别
275
+ # Langfilters = ["zh_ko"] # 中韩混合识别
276
+ # Langfilters = ["ja_ko"] # 日韩混合识别
277
+ # Langfilters = ["en_ko"] # 英韩混合识别
278
+ # Langfilters = ["zh_ja_en"] # 中日英混合识别
279
+ # Langfilters = ["zh_ja_en_ko"] # 中日英韩混合识别
280
+
281
+ # 更多过滤组合,请您随意。。。For more filter combinations, please feel free to......
282
+ # より多くのフィルターの組み合わせ、お気軽に。。。더 많은 필터 조합을 원하시면 자유롭게 해주세요. .....
283
+
284
+ # 可选保留:支持中文数字拼音格式,更方便前端实现拼音音素修改和推理,默认关闭 False 。
285
+ # 开启后 True ,括号内的数字拼音格式均保留,并识别输出为:"zh"中文。
286
+ keepPinyin = False
287
+
288
+
289
+ # DEFINITION
290
+ PARSE_TAG = re.compile(r'(⑥\$*\d+[\d]{6,}⑥)')
291
+
292
+ @staticmethod
293
+ def _clears():
294
+ LangSegment._text_cache = None
295
+ LangSegment._text_lasts = None
296
+ LangSegment._text_langs = None
297
+ LangSegment._text_waits = None
298
+ LangSegment._lang_count = None
299
+ LangSegment._lang_eos = None
300
+ pass
301
+
302
+ @staticmethod
303
+ def _is_english_word(word):
304
+ return bool(re.match(r'^[a-zA-Z]+$', word))
305
+
306
+ @staticmethod
307
+ def _is_chinese(word):
308
+ for char in word:
309
+ if '\u4e00' <= char <= '\u9fff':
310
+ return True
311
+ return False
312
+
313
+ @staticmethod
314
+ def _is_japanese_kana(word):
315
+ pattern = re.compile(r'[\u3040-\u309F\u30A0-\u30FF]+')
316
+ matches = pattern.findall(word)
317
+ return len(matches) > 0
318
+
319
+ @staticmethod
320
+ def _insert_english_uppercase(word):
321
+ modified_text = re.sub(r'(?<!\b)([A-Z])', r' \1', word)
322
+ modified_text = modified_text.strip('-')
323
+ return modified_text + " "
324
+
325
+ @staticmethod
326
+ def _split_camel_case(word):
327
+ return re.sub(r'(?<!^)(?=[A-Z])', ' ', word)
328
+
329
+ @staticmethod
330
+ def _statistics(language, text):
331
+ # Language word statistics:
332
+ # Chinese characters usually occupy double bytes
333
+ if LangSegment._lang_count is None or not isinstance(LangSegment._lang_count, defaultdict):
334
+ LangSegment._lang_count = defaultdict(int)
335
+ lang_count = LangSegment._lang_count
336
+ if not "|" in language:
337
+ lang_count[language] += int(len(text)*2) if language == "zh" else len(text)
338
+ LangSegment._lang_count = lang_count
339
+ pass
340
+
341
+ @staticmethod
342
+ def _clear_text_number(text):
343
+ if text == "\n":return text,False # Keep Line Breaks
344
+ clear_text = re.sub(r'([^\w\s]+)','',re.sub(r'\n+','',text)).strip()
345
+ is_number = len(re.sub(re.compile(r'(\d+)'),'',clear_text)) == 0
346
+ return clear_text,is_number
347
+
348
+ @staticmethod
349
+ def _saveData(words,language:str,text:str,score:float,symbol=None):
350
+ # Pre-detection
351
+ clear_text , is_number = LangSegment._clear_text_number(text)
352
+ # Merge the same language and save the results
353
+ preData = words[-1] if len(words) > 0 else None
354
+ if symbol is not None:pass
355
+ elif preData is not None and preData["symbol"] is None:
356
+ if len(clear_text) == 0:language = preData["lang"]
357
+ elif is_number == True:language = preData["lang"]
358
+ _ , pre_is_number = LangSegment._clear_text_number(preData["text"])
359
+ if (preData["lang"] == language):
360
+ LangSegment._statistics(preData["lang"],text)
361
+ text = preData["text"] + text
362
+ preData["text"] = text
363
+ return preData
364
+ elif pre_is_number == True:
365
+ text = f'{preData["text"]}{text}'
366
+ words.pop()
367
+ elif is_number == True:
368
+ priority_language = LangSegment._get_filters_string()[:2]
369
+ if priority_language in "ja-zh-en-ko-fr-vi":language = priority_language
370
+ data = {"lang":language,"text": text,"score":score,"symbol":symbol}
371
+ filters = LangSegment.Langfilters
372
+ if filters is None or len(filters) == 0 or "?" in language or \
373
+ language in filters or language in filters[0] or \
374
+ filters[0] == "*" or filters[0] in "alls-mixs-autos":
375
+ words.append(data)
376
+ LangSegment._statistics(data["lang"],data["text"])
377
+ return data
378
+
379
+ @staticmethod
380
+ def _addwords(words,language,text,score,symbol=None):
381
+ if text == "\n":pass # Keep Line Breaks
382
+ elif text is None or len(text.strip()) == 0:return True
383
+ if language is None:language = ""
384
+ language = language.lower()
385
+ if language == 'en':text = LangSegment._insert_english_uppercase(text)
386
+ # text = re.sub(r'[(())]', ',' , text) # Keep it.
387
+ text_waits = LangSegment._text_waits
388
+ ispre_waits = len(text_waits)>0
389
+ preResult = text_waits.pop() if ispre_waits else None
390
+ if preResult is None:preResult = words[-1] if len(words) > 0 else None
391
+ if preResult and ("|" in preResult["lang"]):
392
+ pre_lang = preResult["lang"]
393
+ if language in pre_lang:preResult["lang"] = language = language.split("|")[0]
394
+ else:preResult["lang"]=pre_lang.split("|")[0]
395
+ if ispre_waits:preResult = LangSegment._saveData(words,preResult["lang"],preResult["text"],preResult["score"],preResult["symbol"])
396
+ pre_lang = preResult["lang"] if preResult else None
397
+ if ("|" in language) and (pre_lang and not pre_lang in language and not "…" in language):language = language.split("|")[0]
398
+ if "|" in language:LangSegment._text_waits.append({"lang":language,"text": text,"score":score,"symbol":symbol})
399
+ else:LangSegment._saveData(words,language,text,score,symbol)
400
+ return False
401
+
402
+ @staticmethod
403
+ def _get_prev_data(words):
404
+ data = words[-1] if words and len(words) > 0 else None
405
+ if data:return (data["lang"] , data["text"])
406
+ return (None,"")
407
+
408
+ @staticmethod
409
+ def _match_ending(input , index):
410
+ if input is None or len(input) == 0:return False,None
411
+ input = re.sub(r'\s+', '', input)
412
+ if len(input) == 0 or abs(index) > len(input):return False,None
413
+ ending_pattern = re.compile(r'([「」“”‘’"\'::。.!!?.?])')
414
+ return ending_pattern.match(input[index]),input[index]
415
+
416
+ @staticmethod
417
+ def _cleans_text(cleans_text):
418
+ cleans_text = re.sub(r'(.*?)([^\w]+)', r'\1 ', cleans_text)
419
+ cleans_text = re.sub(r'(.)\1+', r'\1', cleans_text)
420
+ return cleans_text.strip()
421
+
422
+ @staticmethod
423
+ def _mean_processing(text:str):
424
+ if text is None or (text.strip()) == "":return None , 0.0
425
+ arrs = LangSegment._split_camel_case(text).split(" ")
426
+ langs = []
427
+ for t in arrs:
428
+ if len(t.strip()) <= 3:continue
429
+ language, score = langid.classify(t)
430
+ langs.append({"lang":language})
431
+ if len(langs) == 0:return None , 0.0
432
+ return Counter([item['lang'] for item in langs]).most_common(1)[0][0],1.0
433
+
434
+ @staticmethod
435
+ def _lang_classify(cleans_text):
436
+ language, score = langid.classify(cleans_text)
437
+ # fix: Huggingface is np.float32
438
+ if score is not None and isinstance(score, np.generic) and hasattr(score,"item"):
439
+ score = score.item()
440
+ score = round(score , 3)
441
+ return language, score
442
+
443
+ @staticmethod
444
+ def _get_filters_string():
445
+ filters = LangSegment.Langfilters
446
+ return "-".join(filters).lower().strip() if filters is not None else ""
447
+
448
+ @staticmethod
449
+ def _parse_language(words , segment):
450
+ LANG_JA = "ja"
451
+ LANG_ZH = "zh"
452
+ LANG_ZH_JA = f'{LANG_ZH}|{LANG_JA}'
453
+ LANG_JA_ZH = f'{LANG_JA}|{LANG_ZH}'
454
+ language = LANG_ZH
455
+ regex_pattern = re.compile(r'([^\w\s]+)')
456
+ lines = regex_pattern.split(segment)
457
+ lines_max = len(lines)
458
+ LANG_EOS =LangSegment._lang_eos
459
+ for index, text in enumerate(lines):
460
+ if len(text) == 0:continue
461
+ EOS = index >= (lines_max - 1)
462
+ nextId = index + 1
463
+ nextText = lines[nextId] if not EOS else ""
464
+ nextPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',nextText)).strip()) == 0
465
+ textPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',text)).strip()) == 0
466
+ if not EOS and (textPunc == True or ( len(nextText.strip()) >= 0 and nextPunc == True)):
467
+ lines[nextId] = f'{text}{nextText}'
468
+ continue
469
+ number_tags = re.compile(r'(⑥\d{6,}⑥)')
470
+ cleans_text = re.sub(number_tags, '' ,text)
471
+ cleans_text = re.sub(r'\d+', '' ,cleans_text)
472
+ cleans_text = LangSegment._cleans_text(cleans_text)
473
+ # fix:Langid's recognition of short sentences is inaccurate, and it is spliced longer.
474
+ if not EOS and len(cleans_text) <= 2:
475
+ lines[nextId] = f'{text}{nextText}'
476
+ continue
477
+ language,score = LangSegment._lang_classify(cleans_text)
478
+ prev_language , prev_text = LangSegment._get_prev_data(words)
479
+ if language != LANG_ZH and all('\u4e00' <= c <= '\u9fff' for c in re.sub(r'\s','',cleans_text)):language,score = LANG_ZH,1
480
+ if len(cleans_text) <= 5 and LangSegment._is_chinese(cleans_text):
481
+ filters_string = LangSegment._get_filters_string()
482
+ if score < LangSegment.LangPriorityThreshold and len(filters_string) > 0:
483
+ index_ja , index_zh = filters_string.find(LANG_JA) , filters_string.find(LANG_ZH)
484
+ if index_ja != -1 and index_ja < index_zh:language = LANG_JA
485
+ elif index_zh != -1 and index_zh < index_ja:language = LANG_ZH
486
+ if LangSegment._is_japanese_kana(cleans_text):language = LANG_JA
487
+ elif len(cleans_text) > 2 and score > 0.90:pass
488
+ elif EOS and LANG_EOS:language = LANG_ZH if len(cleans_text) <= 1 else language
489
+ else:
490
+ LANG_UNKNOWN = LANG_ZH_JA if language == LANG_ZH or (len(cleans_text) <=2 and prev_language == LANG_ZH) else LANG_JA_ZH
491
+ match_end,match_char = LangSegment._match_ending(text, -1)
492
+ referen = prev_language in LANG_UNKNOWN or LANG_UNKNOWN in prev_language if prev_language else False
493
+ if match_char in "。.": language = prev_language if referen and len(words) > 0 else language
494
+ else:language = f"{LANG_UNKNOWN}|…"
495
+ text,*_ = re.subn(number_tags , LangSegment._restore_number , text )
496
+ LangSegment._addwords(words,language,text,score)
497
+ pass
498
+ pass
499
+
500
+ # ----------------------------------------------------------
501
+ # 【SSML】中文数字处理:Chinese Number Processing (SSML support)
502
+ # 这里默认都是中文,用于处理 SSML 中文标签。当然可以支持任意语言,例如:
503
+ # The default here is Chinese, which is used to process SSML Chinese tags. Of course, any language can be supported, for example:
504
+ # 中文电话号码:<telephone>1234567</telephone>
505
+ # 中文数字号码:<number>1234567</number>
506
+ @staticmethod
507
+ def _process_symbol_SSML(words,data):
508
+ tag , match = data
509
+ language = SSML = match[1]
510
+ text = match[2]
511
+ score = 1.0
512
+ if SSML == "telephone":
513
+ # 中文-电话号码
514
+ language = "zh"
515
+ text = LangSSML.to_chinese_telephone(text)
516
+ pass
517
+ elif SSML == "number":
518
+ # 中文-数字读法
519
+ language = "zh"
520
+ text = LangSSML.to_chinese_number(text)
521
+ pass
522
+ elif SSML == "currency":
523
+ # 中文-按金额发音
524
+ language = "zh"
525
+ text = LangSSML.to_chinese_currency(text)
526
+ pass
527
+ elif SSML == "date":
528
+ # 中文-按金额发音
529
+ language = "zh"
530
+ text = LangSSML.to_chinese_date(text)
531
+ pass
532
+ LangSegment._addwords(words,language,text,score,SSML)
533
+ pass
534
+
535
+ # ----------------------------------------------------------
536
+
537
+ @staticmethod
538
+ def _restore_number(matche):
539
+ value = matche.group(0)
540
+ text_cache = LangSegment._text_cache
541
+ if value in text_cache:
542
+ process , data = text_cache[value]
543
+ tag , match = data
544
+ value = match
545
+ return value
546
+
547
+ @staticmethod
548
+ def _pattern_symbols(item , text):
549
+ if text is None:return text
550
+ tag , pattern , process = item
551
+ matches = pattern.findall(text)
552
+ if len(matches) == 1 and "".join(matches[0]) == text:
553
+ return text
554
+ for i , match in enumerate(matches):
555
+ key = f"⑥{tag}{i:06d}⑥"
556
+ text = re.sub(pattern , key , text , count=1)
557
+ LangSegment._text_cache[key] = (process , (tag , match))
558
+ return text
559
+
560
+ @staticmethod
561
+ def _process_symbol(words,data):
562
+ tag , match = data
563
+ language = match[1]
564
+ text = match[2]
565
+ score = 1.0
566
+ filters = LangSegment._get_filters_string()
567
+ if language not in filters:
568
+ LangSegment._process_symbol_SSML(words,data)
569
+ else:
570
+ LangSegment._addwords(words,language,text,score,True)
571
+ pass
572
+
573
+ @staticmethod
574
+ def _process_english(words,data):
575
+ tag , match = data
576
+ text = match[0]
577
+ filters = LangSegment._get_filters_string()
578
+ priority_language = filters[:2]
579
+ # Preview feature, other language segmentation processing
580
+ enablePreview = LangSegment.EnablePreview
581
+ if enablePreview == True:
582
+ # Experimental: Other language support
583
+ regex_pattern = re.compile(r'(.*?[。.??!!]+[\n]{,1})')
584
+ lines = regex_pattern.split(text)
585
+ for index , text in enumerate(lines):
586
+ if len(text.strip()) == 0:continue
587
+ cleans_text = LangSegment._cleans_text(text)
588
+ language,score = LangSegment._lang_classify(cleans_text)
589
+ if language not in filters:
590
+ language,score = LangSegment._mean_processing(cleans_text)
591
+ if language is None or score <= 0.0:continue
592
+ elif language in filters:pass # pass
593
+ elif score >= 0.95:continue # High score, but not in the filter, excluded.
594
+ elif score <= 0.15 and filters[:2] == "fr":language = priority_language
595
+ else:language = "en"
596
+ LangSegment._addwords(words,language,text,score)
597
+ else:
598
+ # Default is English
599
+ language, score = "en", 1.0
600
+ LangSegment._addwords(words,language,text,score)
601
+ pass
602
+
603
+ @staticmethod
604
+ def _process_Russian(words,data):
605
+ tag , match = data
606
+ text = match[0]
607
+ language = "ru"
608
+ score = 1.0
609
+ LangSegment._addwords(words,language,text,score)
610
+ pass
611
+
612
+ @staticmethod
613
+ def _process_Thai(words,data):
614
+ tag , match = data
615
+ text = match[0]
616
+ language = "th"
617
+ score = 1.0
618
+ LangSegment._addwords(words,language,text,score)
619
+ pass
620
+
621
+ @staticmethod
622
+ def _process_korean(words,data):
623
+ tag , match = data
624
+ text = match[0]
625
+ language = "ko"
626
+ score = 1.0
627
+ LangSegment._addwords(words,language,text,score)
628
+ pass
629
+
630
+ @staticmethod
631
+ def _process_quotes(words,data):
632
+ tag , match = data
633
+ text = "".join(match)
634
+ childs = LangSegment.PARSE_TAG.findall(text)
635
+ if len(childs) > 0:
636
+ LangSegment._process_tags(words , text , False)
637
+ else:
638
+ cleans_text = LangSegment._cleans_text(match[1])
639
+ if len(cleans_text) <= 5:
640
+ LangSegment._parse_language(words,text)
641
+ else:
642
+ language,score = LangSegment._lang_classify(cleans_text)
643
+ LangSegment._addwords(words,language,text,score)
644
+ pass
645
+
646
+
647
+ @staticmethod
648
+ def _process_pinyin(words,data):
649
+ tag , match = data
650
+ text = match
651
+ language = "zh"
652
+ score = 1.0
653
+ LangSegment._addwords(words,language,text,score)
654
+ pass
655
+
656
+ @staticmethod
657
+ def _process_number(words,data): # "$0" process only
658
+ """
659
+ Numbers alone cannot accurately identify language.
660
+ Because numbers are universal in all languages.
661
+ So it won't be executed here, just for testing.
662
+ """
663
+ tag , match = data
664
+ language = words[0]["lang"] if len(words) > 0 else "zh"
665
+ text = match
666
+ score = 0.0
667
+ LangSegment._addwords(words,language,text,score)
668
+ pass
669
+
670
+ @staticmethod
671
+ def _process_tags(words , text , root_tag):
672
+ text_cache = LangSegment._text_cache
673
+ segments = re.split(LangSegment.PARSE_TAG, text)
674
+ segments_len = len(segments) - 1
675
+ for index , text in enumerate(segments):
676
+ if root_tag:LangSegment._lang_eos = index >= segments_len
677
+ if LangSegment.PARSE_TAG.match(text):
678
+ process , data = text_cache[text]
679
+ if process:process(words , data)
680
+ else:
681
+ LangSegment._parse_language(words , text)
682
+ pass
683
+ return words
684
+
685
+ @staticmethod
686
+ def _merge_results(words):
687
+ new_word = []
688
+ for index , cur_data in enumerate(words):
689
+ if "symbol" in cur_data:del cur_data["symbol"]
690
+ if index == 0:new_word.append(cur_data)
691
+ else:
692
+ pre_data = new_word[-1]
693
+ if cur_data["lang"] == pre_data["lang"]:
694
+ pre_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
695
+ else:new_word.append(cur_data)
696
+ return new_word
697
+
698
+ @staticmethod
699
+ def _parse_symbols(text):
700
+ TAG_NUM = "00" # "00" => default channels , "$0" => testing channel
701
+ TAG_S1,TAG_S2,TAG_P1,TAG_P2,TAG_EN,TAG_KO,TAG_RU,TAG_TH = "$1" ,"$2" ,"$3" ,"$4" ,"$5" ,"$6" ,"$7","$8"
702
+ TAG_BASE = re.compile(fr'(([【《((“‘"\']*[LANGUAGE]+[\W\s]*)+)')
703
+ # Get custom language filter
704
+ filters = LangSegment.Langfilters
705
+ filters = filters if filters is not None else ""
706
+ # =======================================================================================================
707
+ # Experimental: Other language support.Thử nghiệm: Hỗ trợ ngôn ngữ khác.Expérimental : prise en charge d’autres langues.
708
+ # 相关语言字符如有缺失,熟悉相关语言的朋友,可以提交把缺失的发音符��补全。
709
+ # If relevant language characters are missing, friends who are familiar with the relevant languages can submit a submission to complete the missing pronunciation symbols.
710
+ # S'il manque des caractères linguistiques pertinents, les amis qui connaissent les langues concernées peuvent soumettre une soumission pour compléter les symboles de prononciation manquants.
711
+ # Nếu thiếu ký tự ngôn ngữ liên quan, những người bạn quen thuộc với ngôn ngữ liên quan có thể gửi bài để hoàn thành các ký hiệu phát âm còn thiếu.
712
+ # -------------------------------------------------------------------------------------------------------
713
+ # Preview feature, other language support
714
+ enablePreview = LangSegment.EnablePreview
715
+ if "fr" in filters or \
716
+ "vi" in filters:enablePreview = True
717
+ LangSegment.EnablePreview = enablePreview
718
+ # 实验性:法语字符支持。Prise en charge des caractères français
719
+ RE_FR = "" if not enablePreview else "àáâãäåæçèéêëìíîïðñòóôõöùúûüýþÿ"
720
+ # 实验性:越南语字符支持。Hỗ trợ ký tự tiếng Việt
721
+ RE_VI = "" if not enablePreview else "đơưăáàảãạắằẳẵặấầẩẫậéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựôâêơưỷỹ"
722
+ # -------------------------------------------------------------------------------------------------------
723
+ # Basic options:
724
+ process_list = [
725
+ ( TAG_S1 , re.compile(LangSegment.SYMBOLS_PATTERN) , LangSegment._process_symbol ), # Symbol Tag
726
+ ( TAG_KO , re.compile(re.sub(r'LANGUAGE',f'\uac00-\ud7a3',TAG_BASE.pattern)) , LangSegment._process_korean ), # Korean words
727
+ ( TAG_TH , re.compile(re.sub(r'LANGUAGE',f'\u0E00-\u0E7F',TAG_BASE.pattern)) , LangSegment._process_Thai ), # Thai words support.
728
+ ( TAG_RU , re.compile(re.sub(r'LANGUAGE',f'А-Яа-яЁё',TAG_BASE.pattern)) , LangSegment._process_Russian ), # Russian words support.
729
+ ( TAG_NUM , re.compile(r'(\W*\d+\W+\d*\W*\d*)') , LangSegment._process_number ), # Number words, Universal in all languages, Ignore it.
730
+ ( TAG_EN , re.compile(re.sub(r'LANGUAGE',f'a-zA-Z{RE_FR}{RE_VI}',TAG_BASE.pattern)) , LangSegment._process_english ), # English words + Other language support.
731
+ ( TAG_P1 , re.compile(r'(["\'])(.*?)(\1)') , LangSegment._process_quotes ), # Regular quotes
732
+ ( TAG_P2 , re.compile(r'([\n]*[【《((“‘])([^【《((“‘’”))》】]{3,})([’”))》】][\W\s]*[\n]{,1})') , LangSegment._process_quotes ), # Special quotes, There are left and right.
733
+ ]
734
+ # Extended options: Default False
735
+ if LangSegment.keepPinyin == True:process_list.insert(1 ,
736
+ ( TAG_S2 , re.compile(r'([\(({](?:\s*\w*\d\w*\s*)+[})\)])') , LangSegment._process_pinyin ), # Chinese Pinyin Tag.
737
+ )
738
+ # -------------------------------------------------------------------------------------------------------
739
+ words = []
740
+ lines = re.findall(r'.*\n*', re.sub(LangSegment.PARSE_TAG, '' ,text))
741
+ for index , text in enumerate(lines):
742
+ if len(text.strip()) == 0:continue
743
+ LangSegment._lang_eos = False
744
+ LangSegment._text_cache = {}
745
+ for item in process_list:
746
+ text = LangSegment._pattern_symbols(item , text)
747
+ cur_word = LangSegment._process_tags([] , text , True)
748
+ if len(cur_word) == 0:continue
749
+ cur_data = cur_word[0] if len(cur_word) > 0 else None
750
+ pre_data = words[-1] if len(words) > 0 else None
751
+ if cur_data and pre_data and cur_data["lang"] == pre_data["lang"] \
752
+ and cur_data["symbol"] == False and pre_data["symbol"] :
753
+ cur_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
754
+ words.pop()
755
+ words += cur_word
756
+ if LangSegment.isLangMerge == True:words = LangSegment._merge_results(words)
757
+ lang_count = LangSegment._lang_count
758
+ if lang_count and len(lang_count) > 0:
759
+ lang_count = dict(sorted(lang_count.items(), key=lambda x: x[1], reverse=True))
760
+ lang_count = list(lang_count.items())
761
+ LangSegment._lang_count = lang_count
762
+ return words
763
+
764
+ @staticmethod
765
+ def setfilters(filters):
766
+ # 当过滤器更改时,清除缓存
767
+ # 필터가 변경되면 캐시를 지웁니다.
768
+ # フィルタが変更されると、キャッシュがクリアされます
769
+ # When the filter changes, clear the cache
770
+ if LangSegment.Langfilters != filters:
771
+ LangSegment._clears()
772
+ LangSegment.Langfilters = filters
773
+ pass
774
+
775
+ @staticmethod
776
+ def getfilters():
777
+ return LangSegment.Langfilters
778
+
779
+ @staticmethod
780
+ def setPriorityThreshold(threshold:float):
781
+ LangSegment.LangPriorityThreshold = threshold
782
+ pass
783
+
784
+ @staticmethod
785
+ def getPriorityThreshold():
786
+ return LangSegment.LangPriorityThreshold
787
+
788
+ @staticmethod
789
+ def getCounts():
790
+ lang_count = LangSegment._lang_count
791
+ if lang_count is not None:return lang_count
792
+ text_langs = LangSegment._text_langs
793
+ if text_langs is None or len(text_langs) == 0:return [("zh",0)]
794
+ lang_counts = defaultdict(int)
795
+ for d in text_langs:lang_counts[d['lang']] += int(len(d['text'])*2) if d['lang'] == "zh" else len(d['text'])
796
+ lang_counts = dict(sorted(lang_counts.items(), key=lambda x: x[1], reverse=True))
797
+ lang_counts = list(lang_counts.items())
798
+ LangSegment._lang_count = lang_counts
799
+ return lang_counts
800
+
801
+ @staticmethod
802
+ def getTexts(text:str):
803
+ if text is None or len(text.strip()) == 0:
804
+ LangSegment._clears()
805
+ return []
806
+ # lasts
807
+ text_langs = LangSegment._text_langs
808
+ if LangSegment._text_lasts == text and text_langs is not None:return text_langs
809
+ # parse
810
+ LangSegment._text_waits = []
811
+ LangSegment._lang_count = None
812
+ LangSegment._text_lasts = text
813
+ text = LangSegment._parse_symbols(text)
814
+ LangSegment._text_langs = text
815
+ return text
816
+
817
+ @staticmethod
818
+ def classify(text:str):
819
+ return LangSegment.getTexts(text)
820
+
821
+
822
+ def setLangMerge(value:bool):
823
+ """是否优化合并结果
824
+ """
825
+ LangSegment.isLangMerge = value
826
+ pass
827
+
828
+ def getLangMerge():
829
+ """是否优化合并结果
830
+ """
831
+ return LangSegment.isLangMerge
832
+
833
+
834
+ def setfilters(filters):
835
+ """
836
+ 功能:语言过滤组功能, 可以指定保留语言。不在过滤组中的语言将被清除。您可随心搭配TTS语音合成所支持的语言。
837
+ 기능: 언어 필터 그룹 기능, 예약된 언어를 지정할 수 있습니다. 필터 그룹에 없는 언어는 지워집니다. TTS 텍스트에서 지원하는 언어를 원하는 대로 일치시킬 수 있습니다.
838
+ 機能:言語フィルターグループ機能で、予約言語を指定できます。フィルターグループに含まれていない言語はクリアされます。TTS音声合成がサポートする言語を自由に組み合わせることができます。
839
+ Function: Language filter group function, you can specify reserved languages. \n
840
+ Languages not in the filter group will be cleared. You can match the languages supported by TTS Text To Speech as you like.\n
841
+ Args:
842
+ filters (list): ["zh", "en", "ja", "ko"] 排名越前,优先级越高
843
+ """
844
+ LangSegment.setfilters(filters)
845
+ pass
846
+
847
+ def getfilters():
848
+ """
849
+ 功能:语言过滤组功能, 可以指定保留语言。不在过滤组中的语言将被清除。您可随心搭配TTS语音合成所支持的语言。
850
+ 기능: 언어 필터 그룹 기능, 예약된 언어를 지정할 수 있습니다. 필터 그룹에 없는 언어는 지워집니다. TTS 텍스트에서 지원하는 언어를 원하는 대로 일치시킬 수 있습니다.
851
+ 機能:言語フィルターグループ機能で、予約言語を指定できます。フィルターグループに含まれていない言語はクリアされます。TTS音声合成がサポートする言語を自由に組み合わせることができます。
852
+ Function: Language filter group function, you can specify reserved languages. \n
853
+ Languages not in the filter group will be cleared. You can match the languages supported by TTS Text To Speech as you like.\n
854
+ Args:
855
+ filters (list): ["zh", "en", "ja", "ko"] 排名越前,优先级越高
856
+ """
857
+ return LangSegment.getfilters()
858
+
859
+ # # @Deprecated:Use shorter setfilters
860
+ # def setLangfilters(filters):
861
+ # """
862
+ # >0.1.9废除:使用更简短的setfilters
863
+ # """
864
+ # setfilters(filters)
865
+ # # @Deprecated:Use shorter getfilters
866
+ # def getLangfilters():
867
+ # """
868
+ # >0.1.9废除:使用更简短的getfilters
869
+ # """
870
+ # return getfilters()
871
+
872
+
873
+ def setKeepPinyin(value:bool):
874
+ """
875
+ 可选保留:支持中文数字拼音格式,更方便前端实现拼音音素修改和推理,默认关闭 False 。\n
876
+ 开启后 True ,括号内的数字拼音格式均保留,并识别输出为:"zh"中文。
877
+ """
878
+ LangSegment.keepPinyin = value
879
+ pass
880
+
881
+ def getKeepPinyin():
882
+ """
883
+ 可选保留:支持中文数字拼音格式,更方便前端实现拼音音素修改和推理,默认关闭 False 。\n
884
+ 开启后 True ,括号内的数字拼音格式均保留,并识别输出为:"zh"中文。
885
+ """
886
+ return LangSegment.keepPinyin
887
+
888
+ def setEnablePreview(value:bool):
889
+ """
890
+ 启用预览版功能(默认关闭)
891
+ Enable preview functionality (off by default)
892
+ Args:
893
+ value (bool): True=开启, False=��闭
894
+ """
895
+ LangSegment.EnablePreview = (value == True)
896
+ pass
897
+
898
+ def getEnablePreview():
899
+ """
900
+ 启用预览版功能(默认关闭)
901
+ Enable preview functionality (off by default)
902
+ Args:
903
+ value (bool): True=开启, False=关闭
904
+ """
905
+ return LangSegment.EnablePreview == True
906
+
907
+ def setPriorityThreshold(threshold:float):
908
+ """
909
+ 中/日语言优先级阀值(评分范围为 0 ~ 1):评分低于设定阀值 <0.89 时,启用 filters 中的优先级。\n
910
+ 中国語/日本語の優先度しきい値(スコア範囲0〜1):スコアが設定されたしきい値<0.89未満の場合、フィルターの優先度が有効になります。\n
911
+ 중/일본어 우선 순위 임계값(점수 범위 0-1): 점수가 설정된 임계값 <0.89보다 낮을 때 필터에서 우선 순위를 활성화합니다.
912
+ Chinese and Japanese language priority threshold (score range is 0 ~ 1): The default threshold is 0.89. \n
913
+ Only the common characters between Chinese and Japanese are processed with confidence and priority. \n
914
+ Args:
915
+ threshold:float (score range is 0 ~ 1)
916
+ """
917
+ LangSegment.setPriorityThreshold(threshold)
918
+ pass
919
+
920
+ def getPriorityThreshold():
921
+ """
922
+ 中/日语言优先级阀值(评分范围为 0 ~ 1):评分低于设定阀值 <0.89 时,启用 filters 中的优先级。\n
923
+ 中国語/日本語の優先度しきい値(スコア範囲0〜1):スコアが設定されたしきい値<0.89未満の場合、フィルターの優先度が有効になります。\n
924
+ 중/일본어 우선 순위 임계값(점수 범위 0-1): 점수가 설정된 임계값 <0.89보다 낮을 때 필터에서 우선 순위를 활성화합니다.
925
+ Chinese and Japanese language priority threshold (score range is 0 ~ 1): The default threshold is 0.89. \n
926
+ Only the common characters between Chinese and Japanese are processed with confidence and priority. \n
927
+ Args:
928
+ threshold:float (score range is 0 ~ 1)
929
+ """
930
+ return LangSegment.getPriorityThreshold()
931
+
932
+ def getTexts(text:str):
933
+ """
934
+ 功能:对输入的文本进行多语种分词\n
935
+ 기능: 입력 텍스트의 다국어 분할 \n
936
+ 機能:入力されたテキストの多言語セグメンテーション\n
937
+ Feature: Tokenizing multilingual text input.\n
938
+ 参数-Args:
939
+ text (str): Text content,文本内容\n
940
+ 返回-Returns:
941
+ list: 示例结果:[{'lang':'zh','text':'?'},...]\n
942
+ lang=语种 , text=内容\n
943
+ """
944
+ return LangSegment.getTexts(text)
945
+
946
+ def getCounts():
947
+ """
948
+ 功能:分词结果统计,按语种字数降序,用于确定其主要语言\n
949
+ 기능: 주요 언어를 결정하는 데 사용되는 언어별 단어 수 내림차순으로 단어 분할 결과의 통계 \n
950
+ 機能:主な言語を決定するために使用される、言語の単語数の降順による単語分割結果の統計\n
951
+ Function: Tokenizing multilingual text input.\n
952
+ 返回-Returns:
953
+ list: 示例结果:[('zh', 5), ('ja', 2), ('en', 1)] = [(语种,字数含标点)]\n
954
+ """
955
+ return LangSegment.getCounts()
956
+
957
+ def classify(text:str):
958
+ """
959
+ 功能:兼容接口实现
960
+ Function: Compatible interface implementation
961
+ """
962
+ return LangSegment.classify(text)
963
+
964
+ def printList(langlist):
965
+ """
966
+ 功能:打印数组结果
967
+ 기능: 어레이 결과 인쇄
968
+ 機能:配列結果を印刷
969
+ Function: Print array results
970
+ """
971
+ print("\n===================【打印结果】===================")
972
+ if langlist is None or len(langlist) == 0:
973
+ print("无内容结果,No content result")
974
+ return
975
+ for line in langlist:
976
+ print(line)
977
+ pass
978
+
979
+
980
+
981
+ def main():
982
+
983
+ # -----------------------------------
984
+ # 更新日志:新版本分词更加精准。
985
+ # Changelog: The new version of the word segmentation is more accurate.
986
+ # チェンジログ:新しいバージョンの単語セグメンテーションはより正確です。
987
+ # Changelog: 분할이라는 단어의 새로운 버전이 더 정확합니다.
988
+ # -----------------------------------
989
+
990
+ # 输入示例1:(包含日文,中文)Input Example 1: (including Japanese, Chinese)
991
+ # text = "“昨日は雨が降った,音楽、映画。。。”你今天学习日语了吗?春は桜の季節です。语种分词是语音合成必不可少的环节。言語分詞は音声合成に欠かせない環節である!"
992
+
993
+ # 输入示例2:(包含日文,中文)Input Example 1: (including Japanese, Chinese)
994
+ # text = "欢迎来玩。東京,は日本の首都です。欢迎来玩. 太好了!"
995
+
996
+ # 输入示例3:(包含日文,中文)Input Example 1: (including Japanese, Chinese)
997
+ # text = "明日、私たちは海辺にバカンスに行きます。你会说日语吗:“中国語、話せますか” 你的日语真好啊!"
998
+
999
+
1000
+ # 输入示例4:(包含日文,中文,韩语,英文)Input Example 4: (including Japanese, Chinese, Korean, English)
1001
+ # text = "你的名字叫<ja>佐々木?<ja>吗?韩语中的안녕 오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型和三款Apple Watch等一系列新品,这次的iPad Air采用了LCD屏幕"
1002
+
1003
+
1004
+ # 试验性支持:"fr"法语 , "vi"越南语 , "ru"俄语 , "th"泰语。Experimental: Other language support.
1005
+ LangSegment.setfilters(["fr", "vi" , "ja", "zh", "ko", "en" , "ru" , "th"])
1006
+ text = """
1007
+ 我喜欢在雨天里听音乐。
1008
+ I enjoy listening to music on rainy days.
1009
+ 雨の日に音楽を聴くのが好きです。
1010
+ 비 오는 날에 음악을 듣는 것을 즐깁니다。
1011
+ J'aime écouter de la musique les jours de pluie.
1012
+ Tôi thích nghe nhạc vào những ngày mưa.
1013
+ Мне нравится слушать музыку в дождливую погоду.
1014
+ ฉันชอบฟังเพลงในวันที่ฝนตก
1015
+ """
1016
+
1017
+
1018
+
1019
+ # 进行分词:(接入TTS项目仅需一行代码调用)Segmentation: (Only one line of code is required to access the TTS project)
1020
+ langlist = LangSegment.getTexts(text)
1021
+ printList(langlist)
1022
+
1023
+
1024
+ # 语种统计:Language statistics:
1025
+ print("\n===================【语种统计】===================")
1026
+ # 获取所有语种数组结果,根据内容字数降序排列
1027
+ # Get the array results in all languages, sorted in descending order according to the number of content words
1028
+ langCounts = LangSegment.getCounts()
1029
+ print(langCounts , "\n")
1030
+
1031
+ # 根据结果获取内容的主要语种 (语言,字数含标点)
1032
+ # Get the main language of content based on the results (language, word count including punctuation)
1033
+ lang , count = langCounts[0]
1034
+ print(f"输入内容的主要语言为 = {lang} ,字数 = {count}")
1035
+ print("==================================================\n")
1036
+
1037
+
1038
+ # 分词输出:lang=语言,text=内容。Word output: lang = language, text = content
1039
+ # ===================【打印结果】===================
1040
+ # {'lang': 'zh', 'text': '你的名字叫'}
1041
+ # {'lang': 'ja', 'text': '佐々木?'}
1042
+ # {'lang': 'zh', 'text': '吗?韩语中的'}
1043
+ # {'lang': 'ko', 'text': '안녕 오빠'}
1044
+ # {'lang': 'zh', 'text': '读什么呢?'}
1045
+ # {'lang': 'ja', 'text': 'あなたの体育の先生は誰ですか?'}
1046
+ # {'lang': 'zh', 'text': ' 此次发布会带来了四款'}
1047
+ # {'lang': 'en', 'text': 'i Phone '}
1048
+ # {'lang': 'zh', 'text': '15系列机型和三款'}
1049
+ # {'lang': 'en', 'text': 'Apple Watch '}
1050
+ # {'lang': 'zh', 'text': '等一系列新品,这次的'}
1051
+ # {'lang': 'en', 'text': 'i Pad Air '}
1052
+ # {'lang': 'zh', 'text': '采用了'}
1053
+ # {'lang': 'en', 'text': 'L C D '}
1054
+ # {'lang': 'zh', 'text': '屏幕'}
1055
+ # ===================【语种统计】===================
1056
+
1057
+ # ===================【语种统计】===================
1058
+ # [('zh', 51), ('ja', 19), ('en', 18), ('ko', 5)]
1059
+
1060
+ # 输入内容的主要语言为 = zh ,字数 = 51
1061
+ # ==================================================
1062
+ # The main language of the input content is = zh, word count = 51
1063
+
1064
+
1065
+ if __name__ == "__main__":
1066
+ main()
1067
+
1068
+
LangSegment/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .LangSegment import LangSegment,getTexts,classify,getCounts,printList,setfilters,getfilters,setPriorityThreshold,getPriorityThreshold,setEnablePreview,getEnablePreview,setKeepPinyin,getKeepPinyin,setLangMerge,getLangMerge
2
+
3
+
4
+ # release
5
+ __version__ = '0.3.5'
6
+
7
+
8
+ # develop
9
+ __develop__ = 'dev-0.0.1'
LangSegment/utils/__init__.py ADDED
File without changes
LangSegment/utils/num.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Digital processing from GPT_SoVITS num.py (thanks)
15
+ """
16
+ Rules to verbalize numbers into Chinese characters.
17
+ https://zh.wikipedia.org/wiki/中文数字#現代中文
18
+ """
19
+
20
+ import re
21
+ from collections import OrderedDict
22
+ from typing import List
23
+
24
+ DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
25
+ UNITS = OrderedDict({
26
+ 1: '十',
27
+ 2: '百',
28
+ 3: '千',
29
+ 4: '万',
30
+ 8: '亿',
31
+ })
32
+
33
+ COM_QUANTIFIERS = '(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'
34
+
35
+ # 分数表达式
36
+ RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
37
+
38
+
39
+ def replace_frac(match) -> str:
40
+ """
41
+ Args:
42
+ match (re.Match)
43
+ Returns:
44
+ str
45
+ """
46
+ sign = match.group(1)
47
+ nominator = match.group(2)
48
+ denominator = match.group(3)
49
+ sign: str = "负" if sign else ""
50
+ nominator: str = num2str(nominator)
51
+ denominator: str = num2str(denominator)
52
+ result = f"{sign}{denominator}分之{nominator}"
53
+ return result
54
+
55
+
56
+ # 百分数表达式
57
+ RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')
58
+
59
+
60
+ def replace_percentage(match) -> str:
61
+ """
62
+ Args:
63
+ match (re.Match)
64
+ Returns:
65
+ str
66
+ """
67
+ sign = match.group(1)
68
+ percent = match.group(2)
69
+ sign: str = "负" if sign else ""
70
+ percent: str = num2str(percent)
71
+ result = f"{sign}百分之{percent}"
72
+ return result
73
+
74
+
75
+ # 整数表达式
76
+ # 带负号的整数 -10
77
+ RE_INTEGER = re.compile(r'(-)' r'(\d+)')
78
+
79
+
80
+ def replace_negative_num(match) -> str:
81
+ """
82
+ Args:
83
+ match (re.Match)
84
+ Returns:
85
+ str
86
+ """
87
+ sign = match.group(1)
88
+ number = match.group(2)
89
+ sign: str = "负" if sign else ""
90
+ number: str = num2str(number)
91
+ result = f"{sign}{number}"
92
+ return result
93
+
94
+
95
+ # 编号-无符号整形
96
+ # 00078
97
+ RE_DEFAULT_NUM = re.compile(r'\d{3}\d*')
98
+
99
+
100
+ def replace_default_num(match):
101
+ """
102
+ Args:
103
+ match (re.Match)
104
+ Returns:
105
+ str
106
+ """
107
+ number = match.group(0)
108
+ return verbalize_digit(number, alt_one=True)
109
+
110
+
111
+ # 加减乘除
112
+ # RE_ASMD = re.compile(
113
+ # r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
114
+ RE_ASMD = re.compile(
115
+ r'((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))')
116
+
117
+ asmd_map = {
118
+ '+': '加',
119
+ '-': '减',
120
+ '×': '乘',
121
+ '÷': '除',
122
+ '=': '等于'
123
+ }
124
+
125
+ def replace_asmd(match) -> str:
126
+ """
127
+ Args:
128
+ match (re.Match)
129
+ Returns:
130
+ str
131
+ """
132
+ result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
133
+ return result
134
+
135
+
136
+ # 次方专项
137
+ RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')
138
+
139
+ power_map = {
140
+ '⁰': '0',
141
+ '¹': '1',
142
+ '²': '2',
143
+ '³': '3',
144
+ '⁴': '4',
145
+ '⁵': '5',
146
+ '⁶': '6',
147
+ '⁷': '7',
148
+ '⁸': '8',
149
+ '⁹': '9',
150
+ 'ˣ': 'x',
151
+ 'ʸ': 'y',
152
+ 'ⁿ': 'n'
153
+ }
154
+
155
+ def replace_power(match) -> str:
156
+ """
157
+ Args:
158
+ match (re.Match)
159
+ Returns:
160
+ str
161
+ """
162
+ power_num = ""
163
+ for m in match.group(0):
164
+ power_num += power_map[m]
165
+ result = "的" + power_num + "次方"
166
+ return result
167
+
168
+
169
+ # 数字表达式
170
+ # 纯小数
171
+ RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
172
+ # 正整数 + 量词
173
+ RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
174
+ RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')
175
+
176
+
177
+ def replace_positive_quantifier(match) -> str:
178
+ """
179
+ Args:
180
+ match (re.Match)
181
+ Returns:
182
+ str
183
+ """
184
+ number = match.group(1)
185
+ match_2 = match.group(2)
186
+ if match_2 == "+":
187
+ match_2 = "多"
188
+ match_2: str = match_2 if match_2 else ""
189
+ quantifiers: str = match.group(3)
190
+ number: str = num2str(number)
191
+ result = f"{number}{match_2}{quantifiers}"
192
+ return result
193
+
194
+
195
+ def replace_number(match) -> str:
196
+ """
197
+ Args:
198
+ match (re.Match)
199
+ Returns:
200
+ str
201
+ """
202
+ sign = match.group(1)
203
+ number = match.group(2)
204
+ pure_decimal = match.group(5)
205
+ if pure_decimal:
206
+ result = num2str(pure_decimal)
207
+ else:
208
+ sign: str = "负" if sign else ""
209
+ number: str = num2str(number)
210
+ result = f"{sign}{number}"
211
+ return result
212
+
213
+
214
+ # 范围表达式
215
+ # match.group(1) and match.group(8) are copy from RE_NUMBER
216
+
217
+ RE_RANGE = re.compile(
218
+ r"""
219
+ (?<![\d\+\-\×÷=]) # 使用反向前瞻以确保数字范围之前没有其他数字和操作符
220
+ ((-?)((\d+)(\.\d+)?)) # 匹配范围起始的负数或正数(整数或小数)
221
+ [-~] # 匹配范围分隔符
222
+ ((-?)((\d+)(\.\d+)?)) # 匹配范围结束的负数或正数(整数或小数)
223
+ (?![\d\+\-\×÷=]) # 使用正向前瞻以确保数字范围之后没有其他数字和操作符
224
+ """, re.VERBOSE)
225
+
226
+
227
+ def replace_range(match) -> str:
228
+ """
229
+ Args:
230
+ match (re.Match)
231
+ Returns:
232
+ str
233
+ """
234
+ first, second = match.group(1), match.group(6)
235
+ first = RE_NUMBER.sub(replace_number, first)
236
+ second = RE_NUMBER.sub(replace_number, second)
237
+ result = f"{first}到{second}"
238
+ return result
239
+
240
+
241
+ # ~至表达式
242
+ RE_TO_RANGE = re.compile(
243
+ r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)')
244
+
245
+ def replace_to_range(match) -> str:
246
+ """
247
+ Args:
248
+ match (re.Match)
249
+ Returns:
250
+ str
251
+ """
252
+ result = match.group(0).replace('~', '至')
253
+ return result
254
+
255
+
256
+ def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
257
+ stripped = value_string.lstrip('0')
258
+ if len(stripped) == 0:
259
+ return []
260
+ elif len(stripped) == 1:
261
+ if use_zero and len(stripped) < len(value_string):
262
+ return [DIGITS['0'], DIGITS[stripped]]
263
+ else:
264
+ return [DIGITS[stripped]]
265
+ else:
266
+ largest_unit = next(
267
+ power for power in reversed(UNITS.keys()) if power < len(stripped))
268
+ first_part = value_string[:-largest_unit]
269
+ second_part = value_string[-largest_unit:]
270
+ return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(
271
+ second_part)
272
+
273
+
274
+ def verbalize_cardinal(value_string: str) -> str:
275
+ if not value_string:
276
+ return ''
277
+
278
+ # 000 -> '零' , 0 -> '零'
279
+ value_string = value_string.lstrip('0')
280
+ if len(value_string) == 0:
281
+ return DIGITS['0']
282
+
283
+ result_symbols = _get_value(value_string)
284
+ # verbalized number starting with '一十*' is abbreviated as `十*`
285
+ if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[
286
+ '1'] and result_symbols[1] == UNITS[1]:
287
+ result_symbols = result_symbols[1:]
288
+ return ''.join(result_symbols)
289
+
290
+
291
+ def verbalize_digit(value_string: str, alt_one=False) -> str:
292
+ result_symbols = [DIGITS[digit] for digit in value_string]
293
+ result = ''.join(result_symbols)
294
+ if alt_one:
295
+ result = result.replace("一", "幺")
296
+ return result
297
+
298
+
299
+ def num2str(value_string: str) -> str:
300
+ integer_decimal = value_string.split('.')
301
+ if len(integer_decimal) == 1:
302
+ integer = integer_decimal[0]
303
+ decimal = ''
304
+ elif len(integer_decimal) == 2:
305
+ integer, decimal = integer_decimal
306
+ else:
307
+ raise ValueError(
308
+ f"The value string: '${value_string}' has more than one point in it."
309
+ )
310
+
311
+ result = verbalize_cardinal(integer)
312
+
313
+ decimal = decimal.rstrip('0')
314
+ if decimal:
315
+ # '.22' is verbalized as '零点二二'
316
+ # '3.20' is verbalized as '三点二
317
+ result = result if result else "零"
318
+ result += '点' + verbalize_digit(decimal)
319
+ return result
320
+
321
+
322
+ if __name__ == "__main__":
323
+
324
+ text = ""
325
+ text = num2str(text)
326
+ print(text)
327
+ pass
README.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Soundsation
3
+ tags:
4
+ - music generation
5
+ - diffusion
6
+ emoji: 🎶
7
+ colorFrom: red
8
+ colorTo: purple
9
+ sdk: gradio
10
+ sdk_version: 5.20.0
11
+ app_file: app.py
12
+ short_description: Blazingly Fast and Embarrassingly Simple Song Generation
13
+ pinned: false
14
+ license: apache-2.0
15
+ ---
16
+
17
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from openai import OpenAI
3
+ import requests
4
+ import json
5
+ # from volcenginesdkarkruntime import Ark
6
+ import torch
7
+ import torchaudio
8
+ from einops import rearrange
9
+ import argparse
10
+ import json
11
+ import os
12
+ import spaces
13
+ from tqdm import tqdm
14
+ import random
15
+ import numpy as np
16
+ import sys
17
+ import base64
18
+ from soundsation.infer.infer_utils import (
19
+ get_reference_latent,
20
+ get_lrc_token,
21
+ get_audio_style_prompt,
22
+ get_text_style_prompt,
23
+ prepare_model,
24
+ get_negative_style_prompt
25
+ )
26
+ from soundsation.infer.infer import inference
27
+
28
+ MAX_SEED = np.iinfo(np.int32).max
29
+ device='cuda'
30
+ cfm, tokenizer, muq, vae, eval_model, eval_muq = prepare_model(device)
31
+ cfm = torch.compile(cfm)
32
+
33
+ @spaces.GPU
34
+ def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42, randomize_seed=False, steps=32, cfg_strength=4.0, file_type='wav', odeint_method='euler', preference_infer="quality first", edit=False, edit_segments=None, device='cuda'):
35
+ max_frames = 2048
36
+ sway_sampling_coef = -1 if steps < 32 else None
37
+ if randomize_seed:
38
+ seed = random.randint(0, MAX_SEED)
39
+ torch.manual_seed(seed)
40
+ vocal_flag = False
41
+ try:
42
+ lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
43
+ if current_prompt_type == 'audio':
44
+ style_prompt, vocal_flag = get_audio_style_prompt(muq, ref_audio_path)
45
+ else:
46
+ style_prompt = get_text_style_prompt(muq, text_prompt)
47
+ except Exception as e:
48
+ raise gr.Error(f"Error: {str(e)}")
49
+ negative_style_prompt = get_negative_style_prompt(device)
50
+ latent_prompt, pred_frames = get_reference_latent(device, max_frames, edit, edit_segments, ref_audio_path, vae)
51
+
52
+ if preference_infer == "quality first":
53
+ batch_infer_num = 5
54
+ else:
55
+ batch_infer_num = 1
56
+ generated_song = inference(cfm_model=cfm,
57
+ vae_model=vae,
58
+ eval_model=eval_model,
59
+ eval_muq=eval_muq,
60
+ cond=latent_prompt,
61
+ text=lrc_prompt,
62
+ duration=max_frames,
63
+ style_prompt=style_prompt,
64
+ negative_style_prompt=negative_style_prompt,
65
+ steps=steps,
66
+ cfg_strength=cfg_strength,
67
+ sway_sampling_coef=sway_sampling_coef,
68
+ start_time=start_time,
69
+ file_type=file_type,
70
+ vocal_flag=vocal_flag,
71
+ odeint_method=odeint_method,
72
+ pred_frames=pred_frames,
73
+ batch_infer_num=batch_infer_num,
74
+ )
75
+ return generated_song
76
+
77
+ def R1_infer1(theme, tags_gen, language):
78
+ try:
79
+ client = OpenAI(api_key=os.getenv('HS_DP_API'), base_url = "https://ark.cn-beijing.volces.com/api/v3")
80
+
81
+ llm_prompt = """
82
+ 请围绕"{theme}"主题生成一首符合"{tags}"风格的语言为{language}的完整歌词。严格遵循以下要求:
83
+
84
+ ### **强制格式规则**
85
+ 1. **仅输出时间戳和歌词**,禁止任何括号、旁白、段落标记(如副歌、间奏、尾奏等注释)。
86
+ 2. 每行格式必须为 `[mm:ss.xx]歌词内容`,时间戳与歌词间无空格,歌词内容需完整连贯。
87
+ 3. 时间戳需自然分布,**第一句歌词起始时间不得为 [00:00.00]**,需考虑前奏空白。
88
+
89
+ ### **内容与结构要求**
90
+ 1. 歌词应富有变化,使情绪递进,整体连贯有层次感。**每行歌词长度应自然变化**,切勿长度一致,导致很格式化。
91
+ 2. **时间戳分配应根据歌曲的标签、歌词的情感、节奏来合理推测**,而非机械地按照歌词长度分配。
92
+ 3. 间奏/尾奏仅通过时间空白体现(如从 [02:30.00] 直接跳至 [02:50.00]),**无需文字描述**。
93
+
94
+ ### **负面示例(禁止出现)**
95
+ - 错误:[01:30.00](钢琴间奏)
96
+ - 错误:[02:00.00][副歌]
97
+ - 错误:空行、换行符、注释
98
+ """
99
+
100
+ response = client.chat.completions.create(
101
+ model="ep-20250304144033-nr9wl",
102
+ messages=[
103
+ {"role": "system", "content": "You are a professional musician who has been invited to make music-related comments."},
104
+ {"role": "user", "content": llm_prompt.format(theme=theme, tags=tags_gen, language=language)},
105
+ ],
106
+ stream=False
107
+ )
108
+
109
+ info = response.choices[0].message.content
110
+
111
+ return info
112
+
113
+ except requests.exceptions.RequestException as e:
114
+ print(f'请求出错: {e}')
115
+ return {}
116
+
117
+
118
+
119
+ def R1_infer2(tags_lyrics, lyrics_input):
120
+ client = OpenAI(api_key=os.getenv('HS_DP_API'), base_url = "https://ark.cn-beijing.volces.com/api/v3")
121
+
122
+ llm_prompt = """
123
+ {lyrics_input}这是一首歌的歌词,每一行是一句歌词,{tags_lyrics}是我希望这首歌的风格,我现在想要给这首歌的每一句歌词打时间戳得到LRC,我希望时间戳分配应根据歌曲的标签、歌词的情感、节奏来合理推测,而非机械地按照歌词长度分配。第一句歌词的时间戳应考虑前奏长度,避免歌词从 `[00:00.00]` 直接开始。严格按照 LRC 格式输出歌词,每行格式为 `[mm:ss.xx]歌词内容`。最后的结果只输出LRC,不需要其他的解释。
124
+ """
125
+
126
+ response = client.chat.completions.create(
127
+ model="ep-20250304144033-nr9wl",
128
+ messages=[
129
+ {"role": "system", "content": "You are a professional musician who has been invited to make music-related comments."},
130
+ {"role": "user", "content": llm_prompt.format(lyrics_input=lyrics_input, tags_lyrics=tags_lyrics)},
131
+ ],
132
+ stream=False
133
+ )
134
+
135
+ info = response.choices[0].message.content
136
+
137
+ return info
138
+
139
+ css = """
140
+ /* 固定文本域高度并强制滚动条 */
141
+ .lyrics-scroll-box textarea {
142
+ height: 405px !important; /* 固定高度 */
143
+ max-height: 500px !important; /* 最大高度 */
144
+ overflow-y: auto !important; /* 垂直滚动 */
145
+ white-space: pre-wrap; /* 保留换行 */
146
+ line-height: 1.5; /* 行高优化 */
147
+ }
148
+
149
+ .gr-examples {
150
+ background: transparent !important;
151
+ border: 1px solid #e0e0e0 !important;
152
+ border-radius: 8px;
153
+ margin: 1rem 0 !important;
154
+ padding: 1rem !important;
155
+ }
156
+
157
+ """
158
+
159
+
160
+ with gr.Blocks(css=css) as demo:
161
+ gr.HTML(f"""
162
+ <div style="display: flex; align-items: center;">
163
+ <img src='src/brand.png'
164
+ style='width: 200px; height: 40%; display: block; margin: 0 auto 20px;'>
165
+ </div>
166
+
167
+ <div style="flex: 1; text-align: center;">
168
+ <div style="font-size: 2em; font-weight: bold; text-align: center; margin-bottom: 5px">
169
+ Soundsation
170
+ </div>
171
+ <div style="display:flex; justify-content: center; column-gap:4px;">
172
+ <a href="https://github.com/Soundsation/pipeline">
173
+ <img src='https://img.shields.io/badge/GitHub-Repo-green'>
174
+ </a>
175
+ </div>
176
+ </div>
177
+ """)
178
+
179
+ with gr.Tabs() as tabs:
180
+
181
+ # page 1
182
+ with gr.Tab("Music Generate", id=0):
183
+ with gr.Row():
184
+ with gr.Column():
185
+ lrc = gr.Textbox(
186
+ label="Lyrics",
187
+ placeholder="Input the full lyrics",
188
+ lines=12,
189
+ max_lines=50,
190
+ elem_classes="lyrics-scroll-box",
191
+ value="""[00:04.34]Tell me that I'm special\n[00:06.57]Tell me I look pretty\n[00:08.46]Tell me I'm a little angel\n[00:10.58]Sweetheart of your city\n[00:13.64]Say what I'm dying to hear\n[00:17.35]Cause I'm dying to hear you\n[00:20.86]Tell me I'm that new thing\n[00:22.93]Tell me that I'm relevant\n[00:24.96]Tell me that I got a big heart\n[00:27.04]Then back it up with evidence\n[00:29.94]I need it and I don't know why\n[00:34.28]This late at night\n[00:36.32]Isn't it lonely\n[00:39.24]I'd do anything to make you want me\n[00:43.40]I'd give it all up if you told me\n[00:47.42]That I'd be\n[00:49.43]The number one girl in your eyes\n[00:52.85]Your one and only\n[00:55.74]So what's it gon' take for you to want me\n[00:59.78]I'd give it all up if you told me\n[01:03.89]That I'd be\n[01:05.94]The number one girl in your eyes\n[01:11.34]Tell me I'm going real big places\n[01:14.32]Down to earth so friendly\n[01:16.30]And even through all the phases\n[01:18.46]Tell me you accept me\n[01:21.56]Well that's all I'm dying to hear\n[01:25.30]Yeah I'm dying to hear you\n[01:28.91]Tell me that you need me\n[01:30.85]Tell me that I'm loved\n[01:32.90]Tell me that I'm worth it\n"""
192
+ )
193
+
194
+ current_prompt_type = gr.State(value="audio")
195
+ with gr.Tabs() as inside_tabs:
196
+ with gr.Tab("Audio Prompt"):
197
+ audio_prompt = gr.Audio(label="Audio Prompt", type="filepath", value="./src/prompt/default.wav")
198
+ with gr.Tab("Text Prompt"):
199
+ text_prompt = gr.Textbox(
200
+ label="Text Prompt",
201
+ placeholder="Enter the Text Prompt, eg: emotional piano pop",
202
+ )
203
+ def update_prompt_type(evt: gr.SelectData):
204
+ return "audio" if evt.index == 0 else "text"
205
+
206
+ inside_tabs.select(
207
+ fn=update_prompt_type,
208
+ outputs=current_prompt_type
209
+ )
210
+
211
+ with gr.Column():
212
+ with gr.Accordion("Best Practices Guide", open=True):
213
+ gr.Markdown("""
214
+ 1. **Lyrics Format Requirements**
215
+ - Each line must follow: `[mm:ss.xx]Lyric content`
216
+ - Example of valid format:
217
+ ```
218
+ [00:10.00]Moonlight spills through broken blinds
219
+ [00:13.20]Your shadow dances on the dashboard shrine
220
+ ```
221
+
222
+ 2. **Audio Prompt Requirements**
223
+ - Reference audio should be ≥ 1 second, audio >10 seconds will be randomly clipped into 10 seconds.
224
+ - For optimal results, the 10-second clips should be carefully selected.
225
+ - Shorter clips may lead to incoherent generation.
226
+ 3. **Supported Languages**
227
+ - **English**
228
+ - More languages comming soon.
229
+
230
+ 4. **Editing Function in Advanced Settings**
231
+ - Using full-length audio as reference is recommended for best results.
232
+ - Use -1 to represent the start/end of audio (e.g. [[-1,25], [50,-1]] means "from start to 25s" and "from 50s to end").
233
+
234
+ 5. **Generate Preference**
235
+ - Quality First: Higher quality , slightly slower.
236
+ - Speed First: Faster generation with slightly reduced quality.
237
+
238
+ 6. **Others**
239
+ - If loading audio result is slow, you can select Output Format as mp3 in Advanced Settings.
240
+
241
+ """)
242
+ # Music_Duration = gr.Radio(["95s", "285s"], label="Music Duration", value="95s")
243
+ preference_infer = gr.Radio(["quality first", "speed first"], label="Preference", value="quality first")
244
+ lyrics_btn = gr.Button("Generate", variant="primary")
245
+ audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")
246
+ with gr.Accordion("Advanced Settings", open=False):
247
+ seed = gr.Slider(
248
+ label="Seed",
249
+ minimum=0,
250
+ maximum=MAX_SEED,
251
+ step=1,
252
+ value=0,
253
+ )
254
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
255
+
256
+ steps = gr.Slider(
257
+ minimum=10,
258
+ maximum=100,
259
+ value=32,
260
+ step=1,
261
+ label="Diffusion Steps",
262
+ interactive=True,
263
+ elem_id="step_slider"
264
+ )
265
+ cfg_strength = gr.Slider(
266
+ minimum=1,
267
+ maximum=10,
268
+ value=4.0,
269
+ step=0.5,
270
+ label="CFG Strength",
271
+ interactive=True,
272
+ elem_id="step_slider"
273
+ )
274
+ edit = gr.Checkbox(label="edit", value=False)
275
+ edit_segments = gr.Textbox(
276
+ label="Edit Segments",
277
+ placeholder="Time segments to edit (in seconds). Format: `[[start1,end1],...]",
278
+ )
279
+
280
+ odeint_method = gr.Radio(["euler", "midpoint", "rk4","implicit_adams"], label="ODE Solver", value="euler")
281
+ file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="mp3")
282
+
283
+
284
+ gr.Examples(
285
+ examples=[
286
+ ["./src/prompt/pop_cn.wav"],
287
+ ["./src/prompt/pop_en.wav"],
288
+ ["./src/prompt/rock_cn.wav"],
289
+ ["./src/prompt/rock_en.wav"],
290
+ ["./src/prompt/country_cn.wav"],
291
+ ["./src/prompt/country_en.wav"],
292
+ ["./src/prompt/classic_cn.wav"],
293
+ ["./src/prompt/classic_en.wav"],
294
+ ["./src/prompt/jazz_cn.wav"],
295
+ ["./src/prompt/jazz_en.wav"],
296
+ ["./src/prompt/rap_cn.wav"],
297
+ ["./src/prompt/rap_en.wav"],
298
+ ["./src/prompt/default.wav"]
299
+ ],
300
+ inputs=[audio_prompt],
301
+ label="Audio Examples",
302
+ examples_per_page=13,
303
+ elem_id="audio-examples-container"
304
+ )
305
+
306
+ gr.Examples(
307
+ examples=[
308
+ ["Pop Emotional Piano"],
309
+ ["流行 情感 钢琴"],
310
+ ["Indie folk ballad, coming-of-age themes, acoustic guitar picking with harmonica interludes"],
311
+ ["独立民谣, 成长主题, 原声吉他弹奏与口琴间奏"]
312
+ ],
313
+ inputs=[text_prompt],
314
+ label="Text Examples",
315
+ examples_per_page=4,
316
+ elem_id="text-examples-container"
317
+ )
318
+
319
+ gr.Examples(
320
+ examples=[
321
+ ["""[00:04.34]Tell me that I'm special\n[00:06.57]Tell me I look pretty\n[00:08.46]Tell me I'm a little angel\n[00:10.58]Sweetheart of your city\n[00:13.64]Say what I'm dying to hear\n[00:17.35]Cause I'm dying to hear you\n[00:20.86]Tell me I'm that new thing\n[00:22.93]Tell me that I'm relevant\n[00:24.96]Tell me that I got a big heart\n[00:27.04]Then back it up with evidence\n[00:29.94]I need it and I don't know why\n[00:34.28]This late at night\n[00:36.32]Isn't it lonely\n[00:39.24]I'd do anything to make you want me\n[00:43.40]I'd give it all up if you told me\n[00:47.42]That I'd be\n[00:49.43]The number one girl in your eyes\n[00:52.85]Your one and only\n[00:55.74]So what's it gon' take for you to want me\n[00:59.78]I'd give it all up if you told me\n[01:03.89]That I'd be\n[01:05.94]The number one girl in your eyes\n[01:11.34]Tell me I'm going real big places\n[01:14.32]Down to earth so friendly\n[01:16.30]And even through all the phases\n[01:18.46]Tell me you accept me\n[01:21.56]Well that's all I'm dying to hear\n[01:25.30]Yeah I'm dying to hear you\n[01:28.91]Tell me that you need me\n[01:30.85]Tell me that I'm loved\n[01:32.90]Tell me that I'm worth it\n"""],
322
+ ["""[00:00.52]Abracadabra abracadabra\n[00:03.97]Ha\n[00:04.66]Abracadabra abracadabra\n[00:12.02]Yeah\n[00:15.80]Pay the toll to the angels\n[00:19.08]Drawin' circles in the clouds\n[00:23.31]Keep your mind on the distance\n[00:26.67]When the devil turns around\n[00:30.95]Hold me in your heart tonight\n[00:34.11]In the magic of the dark moonlight\n[00:38.44]Save me from this empty fight\n[00:43.83]In the game of life\n[00:45.84]Like a poem said by a lady in red\n[00:49.45]You hear the last few words of your life\n[00:53.15]With a haunting dance now you're both in a trance\n[00:56.90]It's time to cast your spell on the night\n[01:01.40]Abracadabra ama-ooh-na-na\n[01:04.88]Abracadabra porta-ooh-ga-ga\n[01:08.92]Abracadabra abra-ooh-na-na\n[01:12.30]In her tongue she's sayin'\n[01:14.76]Death or love tonight\n[01:18.61]Abracadabra abracadabra\n[01:22.18]Abracadabra abracadabra\n[01:26.08]Feel the beat under your feet\n[01:27.82]The floor's on fire\n[01:29.90]Abracadabra abracadabra\n"""],
323
+ ["""[00:00.27]只因你太美 baby 只因你太美 baby\n[00:08.95]只因你实在是太美 baby\n[00:13.99]只因你太美 baby\n[00:18.89]迎面走来的你让我如此蠢蠢欲动\n[00:20.88]这种感觉我从未有\n[00:21.79]Cause I got a crush on you who you\n[00:25.74]你是我的我是你的谁\n[00:28.09]再多一眼看一眼就会爆炸\n[00:30.31]再近一点靠近点快被融化\n[00:32.49]想要把你占为己有 baby bae\n[00:34.60]不管走到哪里\n[00:35.44]都会想起的人是你 you you\n[00:38.12]我应该拿你怎样\n[00:39.61]Uh 所有人都在看着你\n[00:42.36]我的心总是不安\n[00:44.18]Oh 我现在已病入膏肓\n[00:46.63]Eh oh\n[00:47.84]难道真的因你而疯狂吗\n[00:51.57]我本来不是这种人\n[00:53.59]因你变成奇怪的人\n[00:55.77]第一次呀变成这样的我\n[01:01.23]不管我怎么去否认\n[01:03.21]只因你太美 baby 只因你太美 baby\n[01:11.46]只因你实在是太美 baby\n[01:16.75]只因你太美 baby\n[01:21.09]Oh eh oh\n[01:22.82]现在确认地告诉我\n[01:25.26]Oh eh oh\n[01:27.31]你到底属于谁\n[01:29.98]Oh eh oh\n[01:31.70]现在确认地告诉我\n[01:34.45]Oh eh oh\n"""]
324
+ ],
325
+
326
+ inputs=[lrc],
327
+ label="Lrc Examples",
328
+ examples_per_page=3,
329
+ elem_id="lrc-examples-container",
330
+ )
331
+
332
+
333
+ # page 2
334
+ with gr.Tab("Lyrics Generate", id=1):
335
+ with gr.Row():
336
+ with gr.Column():
337
+ with gr.Accordion("Notice", open=False):
338
+ gr.Markdown("**Two Generation Modes:**\n1. Generate from theme & tags\n2. Add timestamps to existing lyrics")
339
+
340
+ with gr.Group():
341
+ gr.Markdown("### Method 1: Generate from Theme")
342
+ theme = gr.Textbox(label="theme", placeholder="Enter song theme, e.g: Love and Heartbreak")
343
+ tags_gen = gr.Textbox(label="tags", placeholder="Enter song tags, e.g: pop confidence healing")
344
+ language = gr.Radio(["cn", "en"], label="Language", value="en")
345
+ gen_from_theme_btn = gr.Button("Generate LRC (From Theme)", variant="primary")
346
+
347
+ gr.Examples(
348
+ examples=[
349
+ [
350
+ "Love and Heartbreak",
351
+ "vocal emotional piano pop",
352
+ "en"
353
+ ],
354
+ [
355
+ "Heroic Epic",
356
+ "choir orchestral powerful",
357
+ "cn"
358
+ ]
359
+ ],
360
+ inputs=[theme, tags_gen, language],
361
+ label="Examples: Generate from Theme"
362
+ )
363
+
364
+ with gr.Group(visible=True):
365
+ gr.Markdown("### Method 2: Add Timestamps to Lyrics")
366
+ tags_lyrics = gr.Textbox(label="tags", placeholder="Enter song tags, e.g: ballad piano slow")
367
+ lyrics_input = gr.Textbox(
368
+ label="Raw Lyrics (without timestamps)",
369
+ placeholder="Enter plain lyrics (without timestamps), e.g:\nYesterday\nAll my troubles...",
370
+ lines=10,
371
+ max_lines=50,
372
+ elem_classes="lyrics-scroll-box"
373
+ )
374
+
375
+ gen_from_lyrics_btn = gr.Button("Generate LRC (From Lyrics)", variant="primary")
376
+
377
+ gr.Examples(
378
+ examples=[
379
+ [
380
+ "acoustic folk happy",
381
+ """I'm sitting here in the boring room\nIt's just another rainy Sunday afternoon"""
382
+ ],
383
+ [
384
+ "electronic dance energetic",
385
+ """We're living in a material world\nAnd I am a material girl"""
386
+ ]
387
+ ],
388
+ inputs=[tags_lyrics, lyrics_input],
389
+ label="Examples: Generate from Lyrics"
390
+ )
391
+
392
+
393
+ with gr.Column():
394
+ lrc_output = gr.Textbox(
395
+ label="Generated LRC",
396
+ placeholder="Timed lyrics will appear here",
397
+ lines=57,
398
+ elem_classes="lrc-output",
399
+ show_copy_button=True
400
+ )
401
+
402
+ # Bind functions
403
+ gen_from_theme_btn.click(
404
+ fn=R1_infer1,
405
+ inputs=[theme, tags_gen, language],
406
+ outputs=lrc_output
407
+ )
408
+
409
+ gen_from_lyrics_btn.click(
410
+ fn=R1_infer2,
411
+ inputs=[tags_lyrics, lyrics_input],
412
+ outputs=lrc_output
413
+ )
414
+
415
+ tabs.select(
416
+ lambda s: None,
417
+ None,
418
+ None
419
+ )
420
+
421
+ lyrics_btn.click(
422
+ fn=infer_music,
423
+ inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, cfg_strength, file_type, odeint_method, preference_infer, edit, edit_segments],
424
+ outputs=audio_output
425
+ )
426
+
427
+
428
+ if __name__ == "__main__":
429
+ demo.launch()
430
+
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ espeak-ng
pretrained/eval.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from einops import rearrange
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class Generator(nn.Module):
8
+
9
+ def __init__(self,
10
+ in_features,
11
+ ffd_hidden_size,
12
+ num_classes,
13
+ attn_layer_num,
14
+
15
+ ):
16
+ super(Generator, self).__init__()
17
+
18
+ self.attn = nn.ModuleList(
19
+ [
20
+ nn.MultiheadAttention(
21
+ embed_dim=in_features,
22
+ num_heads=8,
23
+ dropout=0.2,
24
+ batch_first=True,
25
+ )
26
+ for _ in range(attn_layer_num)
27
+ ]
28
+ )
29
+
30
+ self.ffd = nn.Sequential(
31
+ nn.Linear(in_features, ffd_hidden_size),
32
+ nn.ReLU(),
33
+ nn.Linear(ffd_hidden_size, in_features)
34
+ )
35
+
36
+ self.dropout = nn.Dropout(0.2)
37
+
38
+ self.fc = nn.Linear(in_features * 2, num_classes)
39
+
40
+ self.proj = nn.Tanh()
41
+
42
+
43
+ def forward(self, ssl_feature, judge_id=None):
44
+ '''
45
+ ssl_feature: [B, T, D]
46
+ output: [B, num_classes]
47
+ '''
48
+
49
+ B, T, D = ssl_feature.shape
50
+
51
+ ssl_feature = self.ffd(ssl_feature)
52
+
53
+ tmp_ssl_feature = ssl_feature
54
+
55
+ for attn in self.attn:
56
+ tmp_ssl_feature, _ = attn(tmp_ssl_feature, tmp_ssl_feature, tmp_ssl_feature)
57
+
58
+ ssl_feature = self.dropout(torch.concat([torch.mean(tmp_ssl_feature, dim=1), torch.max(ssl_feature, dim=1)[0]], dim=1)) # B, 2D
59
+
60
+ x = self.fc(ssl_feature) # B, num_classes
61
+
62
+ x = self.proj(x) * 2.0 + 3
63
+
64
+ return x
65
+
66
+
pretrained/eval.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81cbd54af8b103251e425fcbd8f5313975cb742e760c3dae1e10f99969933fd6
3
+ size 100792276
pretrained/eval.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ generator:
2
+ _target_: pretrained.eval.Generator
3
+ in_features: 1024
4
+ ffd_hidden_size: 4096
5
+ num_classes: 5
6
+ attn_layer_num: 4
requirements.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==5.20.0
2
+ accelerate==1.4.0
3
+ inflect==7.5.0
4
+ hydra-core==1.3.2
5
+ torchdiffeq==0.2.5
6
+ torchaudio==2.6.0
7
+ x-transformers==2.1.2
8
+ transformers==4.49.0
9
+ librosa==0.10.2.post1
10
+ pyarrow==19.0.1
11
+ pandas==2.2.3
12
+ pylance==0.23.2
13
+ ema-pytorch==0.7.7
14
+ prefigure==0.0.10
15
+ bitsandbytes==0.45.3
16
+ muq==0.1.0
17
+ mutagen==1.47.0
18
+ pyopenjtalk
19
+ pykakasi==2.3.0
20
+ jieba==0.42.1
21
+ cn2an==0.5.23
22
+ pypinyin==0.53.0
23
+ onnxruntime==1.20.1
24
+ Unidecode==1.3.8
25
+ phonemizer==3.3.0
26
+ # LangSegment==0.3.5
27
+ liger_kernel==0.5.4
28
+ openai==1.65.2
29
+ pydantic==2.10.6
30
+ einops==0.8.1
31
+ lazy_loader==0.4
32
+ scipy==1.15.2
33
+ ftfy==6.3.1
34
+ py3langid==0.3.0
35
+ torchdiffeq==0.2.5
soundsation/config/config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "soundsation",
3
+ "model": {
4
+ "dim": 2048,
5
+ "depth": 16,
6
+ "heads": 32,
7
+ "ff_mult": 4,
8
+ "text_dim": 512,
9
+ "conv_layers": 4,
10
+ "mel_dim": 64,
11
+ "text_num_embeds": 363
12
+ }
13
+ }
soundsation/config/defaults.ini ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ [DEFAULTS]
3
+
4
+ #name of the run
5
+ exp_name = F5
6
+
7
+ # the batch size
8
+ batch_size = 8
9
+
10
+ # the chunk size
11
+ max_frames = 3000
12
+ min_frames = 10
13
+
14
+ # number of CPU workers for the DataLoader
15
+ num_workers = 4
16
+
17
+ # the random seed
18
+ seed = 42
19
+
20
+ # Batches for gradient accumulation
21
+ accum_batches = 1
22
+
23
+ # Number of steps between checkpoints
24
+ checkpoint_every = 10000
25
+
26
+ # trainer checkpoint file to restart training from
27
+ ckpt_path = ''
28
+
29
+ # model checkpoint file to start a new training run from
30
+ pretrained_ckpt_path = ''
31
+
32
+ # Checkpoint path for the pretransform model if needed
33
+ pretransform_ckpt_path = ''
34
+
35
+ # configuration model specifying model hyperparameters
36
+ model_config = ''
37
+
38
+ # configuration for datasets
39
+ dataset_config = ''
40
+
41
+ # directory to save the checkpoints in
42
+ save_dir = ''
43
+
44
+ # grad norm
45
+ max_grad_norm = 1.0
46
+
47
+ # grad accu
48
+ grad_accumulation_steps = 1
49
+
50
+ # lr
51
+ learning_rate = 7.5e-5
52
+
53
+ # epoch
54
+ epochs = 110
55
+
56
+ # warmup steps
57
+ num_warmup_updates = 2000
58
+
59
+ # save checkpoint per steps
60
+ save_per_updates = 5000
61
+
62
+ # save last checkpoint per steps
63
+ last_per_steps = 5000
64
+
65
+ prompt_path = "/mnt/sfs/music/lance/style-lance-full|/mnt/sfs/music/lance/style-lance-cnen-music-second"
66
+ lrc_path = "/mnt/sfs/music/lance/lrc-lance-emb-full|/mnt/sfs/music/lance/lrc-lance-cnen-second"
67
+ latent_path = "/mnt/sfs/music/lance/latent-lance|/mnt/sfs/music/lance/latent-lance-cnen-music-second-1|/mnt/sfs/music/lance/latent-lance-cnen-music-second-2"
68
+
69
+ audio_drop_prob = 0.3
70
+ cond_drop_prob = 0.0
71
+ style_drop_prob = 0.1
72
+ lrc_drop_prob = 0.1
73
+
74
+ align_lyrics = 0
75
+ lyrics_slice = 0
76
+ parse_lyrics = 1
77
+ skip_empty_lyrics = 0
78
+ lyrics_shift = -1
79
+
80
+ use_style_prompt = 1
81
+
82
+ tokenizer_type = gpt2
83
+
84
+ reset_lr = 0
85
+
86
+ resumable_with_seed = 666
87
+
88
+ downsample_rate = 2048
89
+
90
+ grad_ckpt = 0
91
+
92
+ dataset_path = "/mnt/sfs/music/hkchen/workspace/F5-TTS-HW/filelists/music123latent_asred_bpmstyle_cnen_pure1"
93
+
94
+ pure_prob = 0.0
soundsation/g2p/g2p/__init__.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from soundsation.g2p.g2p import cleaners
7
+ from tokenizers import Tokenizer
8
+ from soundsation.g2p.g2p.text_tokenizers import TextTokenizer
9
+ import LangSegment
10
+ import json
11
+ import re
12
+
13
+
14
+ class PhonemeBpeTokenizer:
15
+
16
+ def __init__(self, vacab_path="./soundsation/g2p/g2p/vocab.json"):
17
+ self.lang2backend = {
18
+ "zh": "cmn",
19
+ "ja": "ja",
20
+ "en": "en-us",
21
+ "fr": "fr-fr",
22
+ "ko": "ko",
23
+ "de": "de",
24
+ }
25
+ self.text_tokenizers = {}
26
+ self.int_text_tokenizers()
27
+
28
+ with open(vacab_path, "r") as f:
29
+ json_data = f.read()
30
+ data = json.loads(json_data)
31
+ self.vocab = data["vocab"]
32
+ LangSegment.setfilters(["en", "zh", "ja", "ko", "fr", "de"])
33
+
34
+ def int_text_tokenizers(self):
35
+ for key, value in self.lang2backend.items():
36
+ self.text_tokenizers[key] = TextTokenizer(language=value)
37
+
38
+ def tokenize(self, text, sentence, language):
39
+
40
+ # 1. convert text to phoneme
41
+ phonemes = []
42
+ if language == "auto":
43
+ seglist = LangSegment.getTexts(text)
44
+ tmp_ph = []
45
+ for seg in seglist:
46
+ tmp_ph.append(
47
+ self._clean_text(
48
+ seg["text"], sentence, seg["lang"], ["cjekfd_cleaners"]
49
+ )
50
+ )
51
+ phonemes = "|_|".join(tmp_ph)
52
+ else:
53
+ phonemes = self._clean_text(text, sentence, language, ["cjekfd_cleaners"])
54
+ # print('clean text: ', phonemes)
55
+
56
+ # 2. tokenize phonemes
57
+ phoneme_tokens = self.phoneme2token(phonemes)
58
+ # print('encode: ', phoneme_tokens)
59
+
60
+ # # 3. decode tokens [optional]
61
+ # decoded_text = self.tokenizer.decode(phoneme_tokens)
62
+ # print('decoded: ', decoded_text)
63
+
64
+ return phonemes, phoneme_tokens
65
+
66
+ def _clean_text(self, text, sentence, language, cleaner_names):
67
+ for name in cleaner_names:
68
+ cleaner = getattr(cleaners, name)
69
+ if not cleaner:
70
+ raise Exception("Unknown cleaner: %s" % name)
71
+ text = cleaner(text, sentence, language, self.text_tokenizers)
72
+ return text
73
+
74
+ def phoneme2token(self, phonemes):
75
+ tokens = []
76
+ if isinstance(phonemes, list):
77
+ for phone in phonemes:
78
+ phone = phone.split("\t")[0]
79
+ phonemes_split = phone.split("|")
80
+ tokens.append(
81
+ [self.vocab[p] for p in phonemes_split if p in self.vocab]
82
+ )
83
+ else:
84
+ phonemes = phonemes.split("\t")[0]
85
+ phonemes_split = phonemes.split("|")
86
+ tokens = [self.vocab[p] for p in phonemes_split if p in self.vocab]
87
+ return tokens
soundsation/g2p/g2p/chinese_model_g2p.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import DataLoader
10
+ import json
11
+ from transformers import BertTokenizer
12
+ from torch.utils.data import Dataset
13
+ from transformers.models.bert.modeling_bert import *
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from onnxruntime import InferenceSession, GraphOptimizationLevel, SessionOptions
17
+
18
+
19
+ class PolyDataset(Dataset):
20
+ def __init__(self, words, labels, word_pad_idx=0, label_pad_idx=-1):
21
+ self.dataset = self.preprocess(words, labels)
22
+ self.word_pad_idx = word_pad_idx
23
+ self.label_pad_idx = label_pad_idx
24
+
25
+ def preprocess(self, origin_sentences, origin_labels):
26
+ """
27
+ Maps tokens and tags to their indices and stores them in the dict data.
28
+ examples:
29
+ word:['[CLS]', '浙', '商', '银', '行', '企', '业', '信', '贷', '部']
30
+ sentence:([101, 3851, 1555, 7213, 6121, 821, 689, 928, 6587, 6956],
31
+ array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
32
+ label:[3, 13, 13, 13, 0, 0, 0, 0, 0]
33
+ """
34
+ data = []
35
+ labels = []
36
+ sentences = []
37
+ # tokenize
38
+ for line in origin_sentences:
39
+ # replace each token by its index
40
+ # we can not use encode_plus because our sentences are aligned to labels in list type
41
+ words = []
42
+ word_lens = []
43
+ for token in line:
44
+ words.append(token)
45
+ word_lens.append(1)
46
+ token_start_idxs = 1 + np.cumsum([0] + word_lens[:-1])
47
+ sentences.append(((words, token_start_idxs), 0))
48
+ ###
49
+ for tag in origin_labels:
50
+ labels.append(tag)
51
+
52
+ for sentence, label in zip(sentences, labels):
53
+ data.append((sentence, label))
54
+ return data
55
+
56
+ def __getitem__(self, idx):
57
+ """sample data to get batch"""
58
+ word = self.dataset[idx][0]
59
+ label = self.dataset[idx][1]
60
+ return [word, label]
61
+
62
+ def __len__(self):
63
+ """get dataset size"""
64
+ return len(self.dataset)
65
+
66
+ def collate_fn(self, batch):
67
+
68
+ sentences = [x[0][0] for x in batch]
69
+ ori_sents = [x[0][1] for x in batch]
70
+ labels = [x[1] for x in batch]
71
+ batch_len = len(sentences)
72
+
73
+ # compute length of longest sentence in batch
74
+ max_len = max([len(s[0]) for s in sentences])
75
+ max_label_len = 0
76
+ batch_data = np.ones((batch_len, max_len))
77
+ batch_label_starts = []
78
+
79
+ # padding and aligning
80
+ for j in range(batch_len):
81
+ cur_len = len(sentences[j][0])
82
+ batch_data[j][:cur_len] = sentences[j][0]
83
+ label_start_idx = sentences[j][-1]
84
+ label_starts = np.zeros(max_len)
85
+ label_starts[[idx for idx in label_start_idx if idx < max_len]] = 1
86
+ batch_label_starts.append(label_starts)
87
+ max_label_len = max(int(sum(label_starts)), max_label_len)
88
+
89
+ # padding label
90
+ batch_labels = self.label_pad_idx * np.ones((batch_len, max_label_len))
91
+ batch_pmasks = self.label_pad_idx * np.ones((batch_len, max_label_len))
92
+ for j in range(batch_len):
93
+ cur_tags_len = len(labels[j])
94
+ batch_labels[j][:cur_tags_len] = labels[j]
95
+ batch_pmasks[j][:cur_tags_len] = [
96
+ 1 if item > 0 else 0 for item in labels[j]
97
+ ]
98
+
99
+ # convert data to torch LongTensors
100
+ batch_data = torch.tensor(batch_data, dtype=torch.long)
101
+ batch_label_starts = torch.tensor(batch_label_starts, dtype=torch.long)
102
+ batch_labels = torch.tensor(batch_labels, dtype=torch.long)
103
+ batch_pmasks = torch.tensor(batch_pmasks, dtype=torch.long)
104
+ return [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
105
+
106
+
107
+ class BertPolyPredict:
108
+ def __init__(self, bert_model, jsonr_file, json_file):
109
+ self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)
110
+ with open(jsonr_file, "r", encoding="utf8") as fp:
111
+ self.pron_dict = json.load(fp)
112
+ with open(json_file, "r", encoding="utf8") as fp:
113
+ self.pron_dict_id_2_pinyin = json.load(fp)
114
+ self.num_polyphone = len(self.pron_dict)
115
+ self.device = "cpu"
116
+ self.polydataset = PolyDataset
117
+ options = SessionOptions() # initialize session options
118
+ options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
119
+ print(os.path.join(bert_model, "poly_bert_model.onnx"))
120
+ self.session = InferenceSession(
121
+ os.path.join(bert_model, "poly_bert_model.onnx"),
122
+ sess_options=options,
123
+ providers=[
124
+ "CUDAExecutionProvider",
125
+ "CPUExecutionProvider",
126
+ ], # CPUExecutionProvider #CUDAExecutionProvider
127
+ )
128
+ # self.session.set_providers(['CUDAExecutionProvider', "CPUExecutionProvider"], [ {'device_id': 0}])
129
+
130
+ # disable session.run() fallback mechanism, it prevents for a reset of the execution provider
131
+ self.session.disable_fallback()
132
+
133
+ def predict_process(self, txt_list):
134
+ word_test, label_test, texts_test = self.get_examples_po(txt_list)
135
+ data = self.polydataset(word_test, label_test)
136
+ predict_loader = DataLoader(
137
+ data, batch_size=1, shuffle=False, collate_fn=data.collate_fn
138
+ )
139
+ pred_tags = self.predict_onnx(predict_loader)
140
+ return pred_tags
141
+
142
+ def predict_onnx(self, dev_loader):
143
+ pred_tags = []
144
+ with torch.no_grad():
145
+ for idx, batch_samples in enumerate(dev_loader):
146
+ # [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
147
+ batch_data, batch_label_starts, batch_labels, batch_pmasks, _ = (
148
+ batch_samples
149
+ )
150
+ # shift tensors to GPU if available
151
+ batch_data = batch_data.to(self.device)
152
+ batch_label_starts = batch_label_starts.to(self.device)
153
+ batch_labels = batch_labels.to(self.device)
154
+ batch_pmasks = batch_pmasks.to(self.device)
155
+ batch_data = np.asarray(batch_data, dtype=np.int32)
156
+ batch_pmasks = np.asarray(batch_pmasks, dtype=np.int32)
157
+ # batch_output = self.session.run(output_names=['outputs'], input_feed={"input_ids":batch_data, "input_pmasks": batch_pmasks})[0][0]
158
+ batch_output = self.session.run(
159
+ output_names=["outputs"], input_feed={"input_ids": batch_data}
160
+ )[0]
161
+ label_masks = batch_pmasks == 1
162
+ batch_labels = batch_labels.to("cpu").numpy()
163
+ for i, indices in enumerate(np.argmax(batch_output, axis=2)):
164
+ for j, idx in enumerate(indices):
165
+ if label_masks[i][j]:
166
+ # pred_tag.append(idx)
167
+ pred_tags.append(self.pron_dict_id_2_pinyin[str(idx + 1)])
168
+ return pred_tags
169
+
170
+ def get_examples_po(self, text_list):
171
+
172
+ word_list = []
173
+ label_list = []
174
+ sentence_list = []
175
+ id = 0
176
+ for line in [text_list]:
177
+ sentence = line[0]
178
+ words = []
179
+ tokens = line[0]
180
+ index = line[-1]
181
+ front = index
182
+ back = len(tokens) - index - 1
183
+ labels = [0] * front + [1] + [0] * back
184
+ words = ["[CLS]"] + [item for item in sentence]
185
+ words = self.tokenizer.convert_tokens_to_ids(words)
186
+ word_list.append(words)
187
+ label_list.append(labels)
188
+ sentence_list.append(sentence)
189
+
190
+ id += 1
191
+ # mask_list.append(masks)
192
+ assert len(labels) + 1 == len(words), print(
193
+ (
194
+ poly,
195
+ sentence,
196
+ words,
197
+ labels,
198
+ sentence,
199
+ len(sentence),
200
+ len(words),
201
+ len(labels),
202
+ )
203
+ )
204
+ assert len(labels) + 1 == len(
205
+ words
206
+ ), "Number of labels does not match number of words"
207
+ assert len(labels) == len(
208
+ sentence
209
+ ), "Number of labels does not match number of sentences"
210
+ assert len(word_list) == len(
211
+ label_list
212
+ ), "Number of label sentences does not match number of word sentences"
213
+ return word_list, label_list, text_list
soundsation/g2p/g2p/cleaners.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ from soundsation.g2p.g2p.japanese import japanese_to_ipa
8
+ from soundsation.g2p.g2p.mandarin import chinese_to_ipa
9
+ from soundsation.g2p.g2p.english import english_to_ipa
10
+ from soundsation.g2p.g2p.french import french_to_ipa
11
+ from soundsation.g2p.g2p.korean import korean_to_ipa
12
+ from soundsation.g2p.g2p.german import german_to_ipa
13
+
14
+
15
+ def cjekfd_cleaners(text, sentence, language, text_tokenizers):
16
+
17
+ if language == "zh":
18
+ return chinese_to_ipa(text, sentence, text_tokenizers["zh"])
19
+ elif language == "ja":
20
+ return japanese_to_ipa(text, text_tokenizers["ja"])
21
+ elif language == "en":
22
+ return english_to_ipa(text, text_tokenizers["en"])
23
+ elif language == "fr":
24
+ return french_to_ipa(text, text_tokenizers["fr"])
25
+ elif language == "ko":
26
+ return korean_to_ipa(text, text_tokenizers["ko"])
27
+ elif language == "de":
28
+ return german_to_ipa(text, text_tokenizers["de"])
29
+ else:
30
+ raise Exception("Unknown language: %s" % language)
31
+ return None
soundsation/g2p/g2p/english.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ from unidecode import unidecode
8
+ import inflect
9
+
10
+ """
11
+ Text clean time
12
+ """
13
+ _inflect = inflect.engine()
14
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
15
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
16
+ _percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
17
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
18
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
19
+ _fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
20
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
21
+ _number_re = re.compile(r"[0-9]+")
22
+
23
+ # List of (regular expression, replacement) pairs for abbreviations:
24
+ _abbreviations = [
25
+ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
26
+ for x in [
27
+ ("mrs", "misess"),
28
+ ("mr", "mister"),
29
+ ("dr", "doctor"),
30
+ ("st", "saint"),
31
+ ("co", "company"),
32
+ ("jr", "junior"),
33
+ ("maj", "major"),
34
+ ("gen", "general"),
35
+ ("drs", "doctors"),
36
+ ("rev", "reverend"),
37
+ ("lt", "lieutenant"),
38
+ ("hon", "honorable"),
39
+ ("sgt", "sergeant"),
40
+ ("capt", "captain"),
41
+ ("esq", "esquire"),
42
+ ("ltd", "limited"),
43
+ ("col", "colonel"),
44
+ ("ft", "fort"),
45
+ ("etc", "et cetera"),
46
+ ("btw", "by the way"),
47
+ ]
48
+ ]
49
+
50
+ _special_map = [
51
+ ("t|ɹ", "tɹ"),
52
+ ("d|ɹ", "dɹ"),
53
+ ("t|s", "ts"),
54
+ ("d|z", "dz"),
55
+ ("ɪ|ɹ", "ɪɹ"),
56
+ ("ɐ", "ɚ"),
57
+ ("ᵻ", "ɪ"),
58
+ ("əl", "l"),
59
+ ("x", "k"),
60
+ ("ɬ", "l"),
61
+ ("ʔ", "t"),
62
+ ("n̩", "n"),
63
+ ("oː|ɹ", "oːɹ"),
64
+ ]
65
+
66
+
67
+ def expand_abbreviations(text):
68
+ for regex, replacement in _abbreviations:
69
+ text = re.sub(regex, replacement, text)
70
+ return text
71
+
72
+
73
+ def _remove_commas(m):
74
+ return m.group(1).replace(",", "")
75
+
76
+
77
+ def _expand_decimal_point(m):
78
+ return m.group(1).replace(".", " point ")
79
+
80
+
81
+ def _expand_percent(m):
82
+ return m.group(1).replace("%", " percent ")
83
+
84
+
85
+ def _expand_dollars(m):
86
+ match = m.group(1)
87
+ parts = match.split(".")
88
+ if len(parts) > 2:
89
+ return " " + match + " dollars " # Unexpected format
90
+ dollars = int(parts[0]) if parts[0] else 0
91
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
92
+ if dollars and cents:
93
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
94
+ cent_unit = "cent" if cents == 1 else "cents"
95
+ return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
96
+ elif dollars:
97
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
98
+ return " %s %s " % (dollars, dollar_unit)
99
+ elif cents:
100
+ cent_unit = "cent" if cents == 1 else "cents"
101
+ return " %s %s " % (cents, cent_unit)
102
+ else:
103
+ return " zero dollars "
104
+
105
+
106
+ def fraction_to_words(numerator, denominator):
107
+ if numerator == 1 and denominator == 2:
108
+ return " one half "
109
+ if numerator == 1 and denominator == 4:
110
+ return " one quarter "
111
+ if denominator == 2:
112
+ return " " + _inflect.number_to_words(numerator) + " halves "
113
+ if denominator == 4:
114
+ return " " + _inflect.number_to_words(numerator) + " quarters "
115
+ return (
116
+ " "
117
+ + _inflect.number_to_words(numerator)
118
+ + " "
119
+ + _inflect.ordinal(_inflect.number_to_words(denominator))
120
+ + " "
121
+ )
122
+
123
+
124
+ def _expand_fraction(m):
125
+ numerator = int(m.group(1))
126
+ denominator = int(m.group(2))
127
+ return fraction_to_words(numerator, denominator)
128
+
129
+
130
+ def _expand_ordinal(m):
131
+ return " " + _inflect.number_to_words(m.group(0)) + " "
132
+
133
+
134
+ def _expand_number(m):
135
+ num = int(m.group(0))
136
+ if num > 1000 and num < 3000:
137
+ if num == 2000:
138
+ return " two thousand "
139
+ elif num > 2000 and num < 2010:
140
+ return " two thousand " + _inflect.number_to_words(num % 100) + " "
141
+ elif num % 100 == 0:
142
+ return " " + _inflect.number_to_words(num // 100) + " hundred "
143
+ else:
144
+ return (
145
+ " "
146
+ + _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(
147
+ ", ", " "
148
+ )
149
+ + " "
150
+ )
151
+ else:
152
+ return " " + _inflect.number_to_words(num, andword="") + " "
153
+
154
+
155
+ # Normalize numbers pronunciation
156
+ def normalize_numbers(text):
157
+ text = re.sub(_comma_number_re, _remove_commas, text)
158
+ text = re.sub(_pounds_re, r"\1 pounds", text)
159
+ text = re.sub(_dollars_re, _expand_dollars, text)
160
+ text = re.sub(_fraction_re, _expand_fraction, text)
161
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
162
+ text = re.sub(_percent_number_re, _expand_percent, text)
163
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
164
+ text = re.sub(_number_re, _expand_number, text)
165
+ return text
166
+
167
+
168
+ def _english_to_ipa(text):
169
+ # text = unidecode(text).lower()
170
+ text = expand_abbreviations(text)
171
+ text = normalize_numbers(text)
172
+ return text
173
+
174
+
175
+ # special map
176
+ def special_map(text):
177
+ for regex, replacement in _special_map:
178
+ regex = regex.replace("|", "\|")
179
+ while re.search(r"(^|[_|]){}([_|]|$)".format(regex), text):
180
+ text = re.sub(
181
+ r"(^|[_|]){}([_|]|$)".format(regex), r"\1{}\2".format(replacement), text
182
+ )
183
+ # text = re.sub(r'([,.!?])', r'|\1', text)
184
+ return text
185
+
186
+
187
+ # Add some special operation
188
+ def english_to_ipa(text, text_tokenizer):
189
+ if type(text) == str:
190
+ text = _english_to_ipa(text)
191
+ else:
192
+ text = [_english_to_ipa(t) for t in text]
193
+ phonemes = text_tokenizer(text)
194
+ if phonemes[-1] in "p⁼ʰmftnlkxʃs`ɹaoəɛɪeɑʊŋiuɥwæjː":
195
+ phonemes += "|_"
196
+ if type(text) == str:
197
+ return special_map(phonemes)
198
+ else:
199
+ result_ph = []
200
+ for phone in phonemes:
201
+ result_ph.append(special_map(phone))
202
+ return result_ph
soundsation/g2p/g2p/french.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+ """
9
+ Text clean time
10
+ """
11
+ # List of (regular expression, replacement) pairs for abbreviations in french:
12
+ _abbreviations = [
13
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
14
+ for x in [
15
+ ("M", "monsieur"),
16
+ ("Mlle", "mademoiselle"),
17
+ ("Mlles", "mesdemoiselles"),
18
+ ("Mme", "Madame"),
19
+ ("Mmes", "Mesdames"),
20
+ ("N.B", "nota bene"),
21
+ ("M", "monsieur"),
22
+ ("p.c.q", "parce que"),
23
+ ("Pr", "professeur"),
24
+ ("qqch", "quelque chose"),
25
+ ("rdv", "rendez-vous"),
26
+ ("max", "maximum"),
27
+ ("min", "minimum"),
28
+ ("no", "numéro"),
29
+ ("adr", "adresse"),
30
+ ("dr", "docteur"),
31
+ ("st", "saint"),
32
+ ("co", "companie"),
33
+ ("jr", "junior"),
34
+ ("sgt", "sergent"),
35
+ ("capt", "capitain"),
36
+ ("col", "colonel"),
37
+ ("av", "avenue"),
38
+ ("av. J.-C", "avant Jésus-Christ"),
39
+ ("apr. J.-C", "après Jésus-Christ"),
40
+ ("art", "article"),
41
+ ("boul", "boulevard"),
42
+ ("c.-à-d", "c’est-à-dire"),
43
+ ("etc", "et cetera"),
44
+ ("ex", "exemple"),
45
+ ("excl", "exclusivement"),
46
+ ("boul", "boulevard"),
47
+ ]
48
+ ] + [
49
+ (re.compile("\\b%s" % x[0]), x[1])
50
+ for x in [
51
+ ("Mlle", "mademoiselle"),
52
+ ("Mlles", "mesdemoiselles"),
53
+ ("Mme", "Madame"),
54
+ ("Mmes", "Mesdames"),
55
+ ]
56
+ ]
57
+
58
+ rep_map = {
59
+ ":": ",",
60
+ ";": ",",
61
+ ",": ",",
62
+ "。": ".",
63
+ "!": "!",
64
+ "?": "?",
65
+ "\n": ".",
66
+ "·": ",",
67
+ "、": ",",
68
+ "...": ".",
69
+ "…": ".",
70
+ "$": ".",
71
+ "“": "",
72
+ "”": "",
73
+ "‘": "",
74
+ "’": "",
75
+ "(": "",
76
+ ")": "",
77
+ "(": "",
78
+ ")": "",
79
+ "《": "",
80
+ "》": "",
81
+ "【": "",
82
+ "】": "",
83
+ "[": "",
84
+ "]": "",
85
+ "—": "",
86
+ "~": "-",
87
+ "~": "-",
88
+ "「": "",
89
+ "」": "",
90
+ "¿": "",
91
+ "¡": "",
92
+ }
93
+
94
+
95
+ def collapse_whitespace(text):
96
+ # Regular expression matching whitespace:
97
+ _whitespace_re = re.compile(r"\s+")
98
+ return re.sub(_whitespace_re, " ", text).strip()
99
+
100
+
101
+ def remove_punctuation_at_begin(text):
102
+ return re.sub(r"^[,.!?]+", "", text)
103
+
104
+
105
+ def remove_aux_symbols(text):
106
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
107
+ return text
108
+
109
+
110
+ def replace_symbols(text):
111
+ text = text.replace(";", ",")
112
+ text = text.replace("-", " ")
113
+ text = text.replace(":", ",")
114
+ text = text.replace("&", " et ")
115
+ return text
116
+
117
+
118
+ def expand_abbreviations(text):
119
+ for regex, replacement in _abbreviations:
120
+ text = re.sub(regex, replacement, text)
121
+ return text
122
+
123
+
124
+ def replace_punctuation(text):
125
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
126
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
127
+ return replaced_text
128
+
129
+
130
+ def text_normalize(text):
131
+ text = expand_abbreviations(text)
132
+ text = replace_punctuation(text)
133
+ text = replace_symbols(text)
134
+ text = remove_aux_symbols(text)
135
+ text = remove_punctuation_at_begin(text)
136
+ text = collapse_whitespace(text)
137
+ text = re.sub(r"([^\.,!\?\-…])$", r"\1", text)
138
+ return text
139
+
140
+
141
+ def french_to_ipa(text, text_tokenizer):
142
+ if type(text) == str:
143
+ text = text_normalize(text)
144
+ phonemes = text_tokenizer(text)
145
+ return phonemes
146
+ else:
147
+ for i, t in enumerate(text):
148
+ text[i] = text_normalize(t)
149
+ return text_tokenizer(text)
soundsation/g2p/g2p/german.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+ """
9
+ Text clean time
10
+ """
11
+ rep_map = {
12
+ ":": ",",
13
+ ";": ",",
14
+ ",": ",",
15
+ "。": ".",
16
+ "!": "!",
17
+ "?": "?",
18
+ "\n": ".",
19
+ "·": ",",
20
+ "、": ",",
21
+ "...": ".",
22
+ "…": ".",
23
+ "$": ".",
24
+ "“": "",
25
+ "”": "",
26
+ "‘": "",
27
+ "’": "",
28
+ "(": "",
29
+ ")": "",
30
+ "(": "",
31
+ ")": "",
32
+ "《": "",
33
+ "》": "",
34
+ "【": "",
35
+ "】": "",
36
+ "[": "",
37
+ "]": "",
38
+ "—": "",
39
+ "~": "-",
40
+ "~": "-",
41
+ "「": "",
42
+ "」": "",
43
+ "¿": "",
44
+ "¡": "",
45
+ }
46
+
47
+
48
+ def collapse_whitespace(text):
49
+ # Regular expression matching whitespace:
50
+ _whitespace_re = re.compile(r"\s+")
51
+ return re.sub(_whitespace_re, " ", text).strip()
52
+
53
+
54
+ def remove_punctuation_at_begin(text):
55
+ return re.sub(r"^[,.!?]+", "", text)
56
+
57
+
58
+ def remove_aux_symbols(text):
59
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
60
+ return text
61
+
62
+
63
+ def replace_symbols(text):
64
+ text = text.replace(";", ",")
65
+ text = text.replace("-", " ")
66
+ text = text.replace(":", ",")
67
+ return text
68
+
69
+
70
+ def replace_punctuation(text):
71
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
72
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
73
+ return replaced_text
74
+
75
+
76
+ def text_normalize(text):
77
+ text = replace_punctuation(text)
78
+ text = replace_symbols(text)
79
+ text = remove_aux_symbols(text)
80
+ text = remove_punctuation_at_begin(text)
81
+ text = collapse_whitespace(text)
82
+ text = re.sub(r"([^\.,!\?\-…])$", r"\1", text)
83
+ return text
84
+
85
+
86
+ def german_to_ipa(text, text_tokenizer):
87
+ if type(text) == str:
88
+ text = text_normalize(text)
89
+ phonemes = text_tokenizer(text)
90
+ return phonemes
91
+ else:
92
+ for i, t in enumerate(text):
93
+ text[i] = text_normalize(t)
94
+ return text_tokenizer(text)
soundsation/g2p/g2p/japanese.py ADDED
@@ -0,0 +1,816 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import io, re, os, sys, time, argparse, pdb, json
7
+ from io import StringIO
8
+ from typing import Optional
9
+ import numpy as np
10
+ import traceback
11
+ import pyopenjtalk
12
+ from pykakasi import kakasi
13
+
14
+ punctuation = [",", ".", "!", "?", ":", ";", "'", "…"]
15
+
16
+ jp_xphone2ipa = [
17
+ " a a",
18
+ " i i",
19
+ " u ɯ",
20
+ " e e",
21
+ " o o",
22
+ " a: aː",
23
+ " i: iː",
24
+ " u: ɯː",
25
+ " e: eː",
26
+ " o: oː",
27
+ " k k",
28
+ " s s",
29
+ " t t",
30
+ " n n",
31
+ " h ç",
32
+ " f ɸ",
33
+ " m m",
34
+ " y j",
35
+ " r ɾ",
36
+ " w ɰᵝ",
37
+ " N ɴ",
38
+ " g g",
39
+ " j d ʑ",
40
+ " z z",
41
+ " d d",
42
+ " b b",
43
+ " p p",
44
+ " q q",
45
+ " v v",
46
+ " : :",
47
+ " by b j",
48
+ " ch t ɕ",
49
+ " dy d e j",
50
+ " ty t e j",
51
+ " gy g j",
52
+ " gw g ɯ",
53
+ " hy ç j",
54
+ " ky k j",
55
+ " kw k ɯ",
56
+ " my m j",
57
+ " ny n j",
58
+ " py p j",
59
+ " ry ɾ j",
60
+ " sh ɕ",
61
+ " ts t s ɯ",
62
+ ]
63
+
64
+ _mora_list_minimum: list[tuple[str, Optional[str], str]] = [
65
+ ("ヴォ", "v", "o"),
66
+ ("ヴェ", "v", "e"),
67
+ ("ヴィ", "v", "i"),
68
+ ("ヴァ", "v", "a"),
69
+ ("ヴ", "v", "u"),
70
+ ("ン", None, "N"),
71
+ ("ワ", "w", "a"),
72
+ ("ロ", "r", "o"),
73
+ ("レ", "r", "e"),
74
+ ("ル", "r", "u"),
75
+ ("リョ", "ry", "o"),
76
+ ("リュ", "ry", "u"),
77
+ ("リャ", "ry", "a"),
78
+ ("リェ", "ry", "e"),
79
+ ("リ", "r", "i"),
80
+ ("ラ", "r", "a"),
81
+ ("ヨ", "y", "o"),
82
+ ("ユ", "y", "u"),
83
+ ("ヤ", "y", "a"),
84
+ ("モ", "m", "o"),
85
+ ("メ", "m", "e"),
86
+ ("ム", "m", "u"),
87
+ ("ミョ", "my", "o"),
88
+ ("ミュ", "my", "u"),
89
+ ("ミャ", "my", "a"),
90
+ ("ミェ", "my", "e"),
91
+ ("ミ", "m", "i"),
92
+ ("マ", "m", "a"),
93
+ ("ポ", "p", "o"),
94
+ ("ボ", "b", "o"),
95
+ ("ホ", "h", "o"),
96
+ ("ペ", "p", "e"),
97
+ ("ベ", "b", "e"),
98
+ ("ヘ", "h", "e"),
99
+ ("プ", "p", "u"),
100
+ ("ブ", "b", "u"),
101
+ ("フォ", "f", "o"),
102
+ ("フェ", "f", "e"),
103
+ ("フィ", "f", "i"),
104
+ ("ファ", "f", "a"),
105
+ ("フ", "f", "u"),
106
+ ("ピョ", "py", "o"),
107
+ ("ピュ", "py", "u"),
108
+ ("ピャ", "py", "a"),
109
+ ("ピェ", "py", "e"),
110
+ ("ピ", "p", "i"),
111
+ ("ビョ", "by", "o"),
112
+ ("ビュ", "by", "u"),
113
+ ("ビャ", "by", "a"),
114
+ ("ビェ", "by", "e"),
115
+ ("ビ", "b", "i"),
116
+ ("ヒョ", "hy", "o"),
117
+ ("ヒュ", "hy", "u"),
118
+ ("ヒャ", "hy", "a"),
119
+ ("ヒェ", "hy", "e"),
120
+ ("ヒ", "h", "i"),
121
+ ("パ", "p", "a"),
122
+ ("バ", "b", "a"),
123
+ ("ハ", "h", "a"),
124
+ ("ノ", "n", "o"),
125
+ ("ネ", "n", "e"),
126
+ ("ヌ", "n", "u"),
127
+ ("ニョ", "ny", "o"),
128
+ ("ニュ", "ny", "u"),
129
+ ("ニャ", "ny", "a"),
130
+ ("ニェ", "ny", "e"),
131
+ ("ニ", "n", "i"),
132
+ ("ナ", "n", "a"),
133
+ ("ドゥ", "d", "u"),
134
+ ("ド", "d", "o"),
135
+ ("トゥ", "t", "u"),
136
+ ("ト", "t", "o"),
137
+ ("デョ", "dy", "o"),
138
+ ("デュ", "dy", "u"),
139
+ ("デャ", "dy", "a"),
140
+ # ("デェ", "dy", "e"),
141
+ ("ディ", "d", "i"),
142
+ ("デ", "d", "e"),
143
+ ("テョ", "ty", "o"),
144
+ ("テュ", "ty", "u"),
145
+ ("テャ", "ty", "a"),
146
+ ("ティ", "t", "i"),
147
+ ("テ", "t", "e"),
148
+ ("ツォ", "ts", "o"),
149
+ ("ツェ", "ts", "e"),
150
+ ("ツィ", "ts", "i"),
151
+ ("ツァ", "ts", "a"),
152
+ ("ツ", "ts", "u"),
153
+ ("ッ", None, "q"), # 「cl」から「q」に変更
154
+ ("チョ", "ch", "o"),
155
+ ("チュ", "ch", "u"),
156
+ ("チャ", "ch", "a"),
157
+ ("チェ", "ch", "e"),
158
+ ("チ", "ch", "i"),
159
+ ("ダ", "d", "a"),
160
+ ("タ", "t", "a"),
161
+ ("ゾ", "z", "o"),
162
+ ("ソ", "s", "o"),
163
+ ("ゼ", "z", "e"),
164
+ ("セ", "s", "e"),
165
+ ("ズィ", "z", "i"),
166
+ ("ズ", "z", "u"),
167
+ ("スィ", "s", "i"),
168
+ ("ス", "s", "u"),
169
+ ("ジョ", "j", "o"),
170
+ ("ジュ", "j", "u"),
171
+ ("ジャ", "j", "a"),
172
+ ("ジェ", "j", "e"),
173
+ ("ジ", "j", "i"),
174
+ ("ショ", "sh", "o"),
175
+ ("シュ", "sh", "u"),
176
+ ("シャ", "sh", "a"),
177
+ ("シェ", "sh", "e"),
178
+ ("シ", "sh", "i"),
179
+ ("ザ", "z", "a"),
180
+ ("サ", "s", "a"),
181
+ ("ゴ", "g", "o"),
182
+ ("コ", "k", "o"),
183
+ ("ゲ", "g", "e"),
184
+ ("ケ", "k", "e"),
185
+ ("グヮ", "gw", "a"),
186
+ ("グ", "g", "u"),
187
+ ("クヮ", "kw", "a"),
188
+ ("ク", "k", "u"),
189
+ ("ギョ", "gy", "o"),
190
+ ("ギュ", "gy", "u"),
191
+ ("ギャ", "gy", "a"),
192
+ ("ギェ", "gy", "e"),
193
+ ("ギ", "g", "i"),
194
+ ("キョ", "ky", "o"),
195
+ ("キュ", "ky", "u"),
196
+ ("キャ", "ky", "a"),
197
+ ("キェ", "ky", "e"),
198
+ ("キ", "k", "i"),
199
+ ("ガ", "g", "a"),
200
+ ("カ", "k", "a"),
201
+ ("オ", None, "o"),
202
+ ("エ", None, "e"),
203
+ ("ウォ", "w", "o"),
204
+ ("ウェ", "w", "e"),
205
+ ("ウィ", "w", "i"),
206
+ ("ウ", None, "u"),
207
+ ("イェ", "y", "e"),
208
+ ("イ", None, "i"),
209
+ ("ア", None, "a"),
210
+ ]
211
+
212
+ _mora_list_additional: list[tuple[str, Optional[str], str]] = [
213
+ ("ヴョ", "by", "o"),
214
+ ("ヴュ", "by", "u"),
215
+ ("ヴャ", "by", "a"),
216
+ ("ヲ", None, "o"),
217
+ ("ヱ", None, "e"),
218
+ ("ヰ", None, "i"),
219
+ ("ヮ", "w", "a"),
220
+ ("ョ", "y", "o"),
221
+ ("ュ", "y", "u"),
222
+ ("ヅ", "z", "u"),
223
+ ("ヂ", "j", "i"),
224
+ ("ヶ", "k", "e"),
225
+ ("ャ", "y", "a"),
226
+ ("ォ", None, "o"),
227
+ ("ェ", None, "e"),
228
+ ("ゥ", None, "u"),
229
+ ("ィ", None, "i"),
230
+ ("ァ", None, "a"),
231
+ ]
232
+
233
+ # 例: "vo" -> "ヴォ", "a" -> "ア"
234
+ mora_phonemes_to_mora_kata: dict[str, str] = {
235
+ (consonant or "") + vowel: kana for [kana, consonant, vowel] in _mora_list_minimum
236
+ }
237
+
238
+ # 例: "ヴォ" -> ("v", "o"), "ア" -> (None, "a")
239
+ mora_kata_to_mora_phonemes: dict[str, tuple[Optional[str], str]] = {
240
+ kana: (consonant, vowel)
241
+ for [kana, consonant, vowel] in _mora_list_minimum + _mora_list_additional
242
+ }
243
+
244
+
245
+ # 正規化で記号を変換するための辞書
246
+ rep_map = {
247
+ ":": ":",
248
+ ";": ";",
249
+ ",": ",",
250
+ "。": ".",
251
+ "!": "!",
252
+ "?": "?",
253
+ "\n": ".",
254
+ ".": ".",
255
+ "⋯": "…",
256
+ "···": "…",
257
+ "・・・": "…",
258
+ "·": ",",
259
+ "・": ",",
260
+ "•": ",",
261
+ "、": ",",
262
+ "$": ".",
263
+ # "“": "'",
264
+ # "”": "'",
265
+ # '"': "'",
266
+ "‘": "'",
267
+ "’": "'",
268
+ # "(": "'",
269
+ # ")": "'",
270
+ # "(": "'",
271
+ # ")": "'",
272
+ # "《": "'",
273
+ # "》": "'",
274
+ # "【": "'",
275
+ # "】": "'",
276
+ # "[": "'",
277
+ # "]": "'",
278
+ # "——": "-",
279
+ # "−": "-",
280
+ # "-": "-",
281
+ # "『": "'",
282
+ # "』": "'",
283
+ # "〈": "'",
284
+ # "〉": "'",
285
+ # "«": "'",
286
+ # "»": "'",
287
+ # # "~": "-", # これは長音記号「ー」として扱うよう変更
288
+ # # "~": "-", # これは長音記号「ー」として扱うよう変更
289
+ # "「": "'",
290
+ # "」": "'",
291
+ }
292
+
293
+
294
+ def _numeric_feature_by_regex(regex, s):
295
+ match = re.search(regex, s)
296
+ if match is None:
297
+ return -50
298
+ return int(match.group(1))
299
+
300
+
301
+ def replace_punctuation(text: str) -> str:
302
+ """句読点等を「.」「,」「!」「?」「'」「-」に正規化し、OpenJTalkで読みが取得できるもののみ残す:
303
+ 漢字・平仮名・カタカナ、アルファベット、ギリシャ文字
304
+ """
305
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
306
+ # print("before: ", text)
307
+ # 句読点を辞書で置換
308
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
309
+
310
+ replaced_text = re.sub(
311
+ # ↓ ひらがな、カタカナ、漢字
312
+ r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
313
+ # ↓ 半角アルファベット(大文字と小文字)
314
+ + r"\u0041-\u005A\u0061-\u007A"
315
+ # ↓ 全角アルファベット(大文字と小文字)
316
+ + r"\uFF21-\uFF3A\uFF41-\uFF5A"
317
+ # ↓ ギリシャ文字
318
+ + r"\u0370-\u03FF\u1F00-\u1FFF"
319
+ # ↓ "!", "?", "…", ",", ".", "'", "-", 但し`…`はすでに`...`に変換されている
320
+ + "".join(punctuation) + r"]+",
321
+ # 上述以外の文字を削除
322
+ "",
323
+ replaced_text,
324
+ )
325
+ # print("after: ", replaced_text)
326
+ return replaced_text
327
+
328
+
329
+ def fix_phone_tone(phone_tone_list: list[tuple[str, int]]) -> list[tuple[str, int]]:
330
+ """
331
+ `phone_tone_list`のtone(アクセントの値)を0か1の範囲に修正する。
332
+ 例: [(a, 0), (i, -1), (u, -1)] → [(a, 1), (i, 0), (u, 0)]
333
+ """
334
+ tone_values = set(tone for _, tone in phone_tone_list)
335
+ if len(tone_values) == 1:
336
+ assert tone_values == {0}, tone_values
337
+ return phone_tone_list
338
+ elif len(tone_values) == 2:
339
+ if tone_values == {0, 1}:
340
+ return phone_tone_list
341
+ elif tone_values == {-1, 0}:
342
+ return [
343
+ (letter, 0 if tone == -1 else 1) for letter, tone in phone_tone_list
344
+ ]
345
+ else:
346
+ raise ValueError(f"Unexpected tone values: {tone_values}")
347
+ else:
348
+ raise ValueError(f"Unexpected tone values: {tone_values}")
349
+
350
+
351
+ def fix_phone_tone_wplen(phone_tone_list, word_phone_length_list):
352
+ phones = []
353
+ tones = []
354
+ w_p_len = []
355
+ p_len = len(phone_tone_list)
356
+ idx = 0
357
+ w_idx = 0
358
+ while idx < p_len:
359
+ offset = 0
360
+ if phone_tone_list[idx] == "▁":
361
+ w_p_len.append(w_idx + 1)
362
+
363
+ curr_w_p_len = word_phone_length_list[w_idx]
364
+ for i in range(curr_w_p_len):
365
+ p, t = phone_tone_list[idx]
366
+ if p == ":" and len(phones) > 0:
367
+ if phones[-1][-1] != ":":
368
+ phones[-1] += ":"
369
+ offset -= 1
370
+ else:
371
+ phones.append(p)
372
+ tones.append(str(t))
373
+ idx += 1
374
+ if idx >= p_len:
375
+ break
376
+ w_p_len.append(curr_w_p_len + offset)
377
+ w_idx += 1
378
+ # print(w_p_len)
379
+ return phones, tones, w_p_len
380
+
381
+
382
+ def g2phone_tone_wo_punct(prosodies) -> list[tuple[str, int]]:
383
+ """
384
+ テキストに対して、音素とアクセント(0か1)のペアのリストを返す。
385
+ ただし「!」「.」「?」等の非音素記号(punctuation)は全て消える(ポーズ記号も残さない)。
386
+ 非音素記号を含める処理は`align_tones()`で行われる。
387
+ また「っ」は「cl」でなく「q」に変換される(「ん」は「N」のまま)。
388
+ 例: "こんにちは、世界ー。。元気?!" →
389
+ [('k', 0), ('o', 0), ('N', 1), ('n', 1), ('i', 1), ('ch', 1), ('i', 1), ('w', 1), ('a', 1), ('s', 1), ('e', 1), ('k', 0), ('a', 0), ('i', 0), ('i', 0), ('g', 1), ('e', 1), ('N', 0), ('k', 0), ('i', 0)]
390
+ """
391
+ result: list[tuple[str, int]] = []
392
+ current_phrase: list[tuple[str, int]] = []
393
+ current_tone = 0
394
+ last_accent = ""
395
+ for i, letter in enumerate(prosodies):
396
+ # 特殊記号の処理
397
+
398
+ # 文頭記号、無視する
399
+ if letter == "^":
400
+ assert i == 0, "Unexpected ^"
401
+ # アクセント句の終わりに来る記号
402
+ elif letter in ("$", "?", "_", "#"):
403
+ # 保持しているフレーズを、アクセント数値を0-1に修正し結果に追加
404
+ result.extend(fix_phone_tone(current_phrase))
405
+ # 末尾に来る終了記号、無視(文中の疑問文は`_`になる)
406
+ if letter in ("$", "?"):
407
+ assert i == len(prosodies) - 1, f"Unexpected {letter}"
408
+ # あとは"_"(ポーズ)と"#"(アクセント句の境界)のみ
409
+ # これらは残さず、次のアクセント句に備える。
410
+
411
+ current_phrase = []
412
+ # 0を基準点にしてそこから上昇・下降する(負の場合は上の`fix_phone_tone`で直る)
413
+ current_tone = 0
414
+ last_accent = ""
415
+ # アクセント上昇記号
416
+ elif letter == "[":
417
+ if last_accent != letter:
418
+ current_tone = current_tone + 1
419
+ last_accent = letter
420
+ # アクセント下降記号
421
+ elif letter == "]":
422
+ if last_accent != letter:
423
+ current_tone = current_tone - 1
424
+ last_accent = letter
425
+ # それ以外は通常の音素
426
+ else:
427
+ if letter == "cl": # 「っ」の処理
428
+ letter = "q"
429
+ current_phrase.append((letter, current_tone))
430
+ return result
431
+
432
+
433
+ def handle_long(sep_phonemes: list[list[str]]) -> list[list[str]]:
434
+ for i in range(len(sep_phonemes)):
435
+ if sep_phonemes[i][0] == "ー":
436
+ # sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
437
+ sep_phonemes[i][0] = ":"
438
+ if "ー" in sep_phonemes[i]:
439
+ for j in range(len(sep_phonemes[i])):
440
+ if sep_phonemes[i][j] == "ー":
441
+ # sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
442
+ sep_phonemes[i][j] = ":"
443
+ return sep_phonemes
444
+
445
+
446
+ def handle_long_word(sep_phonemes: list[list[str]]) -> list[list[str]]:
447
+ res = []
448
+ for i in range(len(sep_phonemes)):
449
+ if sep_phonemes[i][0] == "ー":
450
+ sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
451
+ # sep_phonemes[i][0] = ':'
452
+ if "ー" in sep_phonemes[i]:
453
+ for j in range(len(sep_phonemes[i])):
454
+ if sep_phonemes[i][j] == "ー":
455
+ sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
456
+ # sep_phonemes[i][j] = ':'
457
+ res.append(sep_phonemes[i])
458
+ res.append("▁")
459
+ return res
460
+
461
+
462
+ def align_tones(
463
+ phones_with_punct: list[str], phone_tone_list: list[tuple[str, int]]
464
+ ) -> list[tuple[str, int]]:
465
+ """
466
+ 例:
467
+ …私は、、そう思う。
468
+ phones_with_punct:
469
+ [".", ".", ".", "w", "a", "t", "a", "sh", "i", "w", "a", ",", ",", "s", "o", "o", "o", "m", "o", "u", "."]
470
+ phone_tone_list:
471
+ [("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0))]
472
+ Return:
473
+ [(".", 0), (".", 0), (".", 0), ("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), (",", 0), (",", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0), (".", 0)]
474
+ """
475
+ result: list[tuple[str, int]] = []
476
+ tone_index = 0
477
+ for phone in phones_with_punct:
478
+ if tone_index >= len(phone_tone_list):
479
+ # 余ったpunctuationがある場合 → (punctuation, 0)を追加
480
+ result.append((phone, 0))
481
+ elif phone == phone_tone_list[tone_index][0]:
482
+ # phone_tone_listの現在の音素と一致する場合 → toneをそこから取得、(phone, tone)を追加
483
+ result.append((phone, phone_tone_list[tone_index][1]))
484
+ # 探すindexを1つ進める
485
+ tone_index += 1
486
+ elif phone in punctuation or phone == "▁":
487
+ # phoneがpunctuationの場合 → (phone, 0)を追加
488
+ result.append((phone, 0))
489
+ else:
490
+ print(f"phones: {phones_with_punct}")
491
+ print(f"phone_tone_list: {phone_tone_list}")
492
+ print(f"result: {result}")
493
+ print(f"tone_index: {tone_index}")
494
+ print(f"phone: {phone}")
495
+ raise ValueError(f"Unexpected phone: {phone}")
496
+ return result
497
+
498
+
499
+ def kata2phoneme_list(text: str) -> list[str]:
500
+ """
501
+ 原則カタカナの`text`を受け取り、それをそのままいじらずに音素記号のリストに変換。
502
+ 注意点:
503
+ - punctuationが来た場合(punctuationが1文字の場合がありうる)、処理せず1文字のリストを返す
504
+ - 冒頭に続く「ー」はそのまま「ー」のままにする(`handle_long()`で処理される)
505
+ - 文中の「ー」は前の音素記号の最後の音素記号に変換される。
506
+ 例:
507
+ `ーーソーナノカーー` → ["ー", "ー", "s", "o", "o", "n", "a", "n", "o", "k", "a", "a", "a"]
508
+ `?` → ["?"]
509
+ """
510
+ if text in punctuation:
511
+ return [text]
512
+ # `text`がカタカナ(`ー`含む)のみからなるかどうかをチェック
513
+ if re.fullmatch(r"[\u30A0-\u30FF]+", text) is None:
514
+ raise ValueError(f"Input must be katakana only: {text}")
515
+ sorted_keys = sorted(mora_kata_to_mora_phonemes.keys(), key=len, reverse=True)
516
+ pattern = "|".join(map(re.escape, sorted_keys))
517
+
518
+ def mora2phonemes(mora: str) -> str:
519
+ cosonant, vowel = mora_kata_to_mora_phonemes[mora]
520
+ if cosonant is None:
521
+ return f" {vowel}"
522
+ return f" {cosonant} {vowel}"
523
+
524
+ spaced_phonemes = re.sub(pattern, lambda m: mora2phonemes(m.group()), text)
525
+
526
+ # 長音記号「ー」の処理
527
+ long_pattern = r"(\w)(ー*)"
528
+ long_replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2))
529
+ spaced_phonemes = re.sub(long_pattern, long_replacement, spaced_phonemes)
530
+ # spaced_phonemes += ' ▁'
531
+ return spaced_phonemes.strip().split(" ")
532
+
533
+
534
+ def frontend2phoneme(labels, drop_unvoiced_vowels=False):
535
+ N = len(labels)
536
+
537
+ phones = []
538
+ for n in range(N):
539
+ lab_curr = labels[n]
540
+ # print(lab_curr)
541
+ # current phoneme
542
+ p3 = re.search(r"\-(.*?)\+", lab_curr).group(1)
543
+
544
+ # deal unvoiced vowels as normal vowels
545
+ if drop_unvoiced_vowels and p3 in "AEIOU":
546
+ p3 = p3.lower()
547
+
548
+ # deal with sil at the beginning and the end of text
549
+ if p3 == "sil":
550
+ # assert n == 0 or n == N - 1
551
+ # if n == 0:
552
+ # phones.append("^")
553
+ # elif n == N - 1:
554
+ # # check question form or not
555
+ # e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr)
556
+ # if e3 == 0:
557
+ # phones.append("$")
558
+ # elif e3 == 1:
559
+ # phones.append("?")
560
+ continue
561
+ elif p3 == "pau":
562
+ phones.append("_")
563
+ continue
564
+ else:
565
+ phones.append(p3)
566
+
567
+ # accent type and position info (forward or backward)
568
+ a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr)
569
+ a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr)
570
+ a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr)
571
+
572
+ # number of mora in accent phrase
573
+ f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr)
574
+
575
+ a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1])
576
+ # accent phrase border
577
+ # print(p3, a1, a2, a3, f1, a2_next, lab_curr)
578
+ if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl":
579
+ phones.append("#")
580
+ # pitch falling
581
+ elif a1 == 0 and a2_next == a2 + 1 and a2 != f1:
582
+ phones.append("]")
583
+ # pitch rising
584
+ elif a2 == 1 and a2_next == 2:
585
+ phones.append("[")
586
+
587
+ # phones = ' '.join(phones)
588
+ return phones
589
+
590
+
591
+ class JapanesePhoneConverter(object):
592
+ def __init__(self, lexicon_path=None, ipa_dict_path=None):
593
+ # lexicon_lines = open(lexicon_path, 'r', encoding='utf-8').readlines()
594
+ # self.lexicon = {}
595
+ # self.single_dict = {}
596
+ # self.double_dict = {}
597
+ # for curr_line in lexicon_lines:
598
+ # k,v = curr_line.strip().split('+',1)
599
+ # self.lexicon[k] = v
600
+ # if len(k) == 2:
601
+ # self.double_dict[k] = v
602
+ # elif len(k) == 1:
603
+ # self.single_dict[k] = v
604
+ self.ipa_dict = {}
605
+ for curr_line in jp_xphone2ipa:
606
+ k, v = curr_line.strip().split(" ", 1)
607
+ self.ipa_dict[k] = re.sub("\s", "", v)
608
+ # kakasi1 = kakasi()
609
+ # kakasi1.setMode("H","K")
610
+ # kakasi1.setMode("J","K")
611
+ # kakasi1.setMode("r","Hepburn")
612
+ self.japan_JH2K = kakasi()
613
+ self.table = {ord(f): ord(t) for f, t in zip("67", "_¯")}
614
+
615
+ def text2sep_kata(self, parsed) -> tuple[list[str], list[str]]:
616
+ """
617
+ `text_normalize`で正規化済みの`norm_text`を受け取り、それを単語分割し、
618
+ 分割された単語リストとその読み(カタカナor記号1文字)のリス���のタプルを返す。
619
+ 単語分割結果は、`g2p()`の`word2ph`で1文字あたりに割り振る音素記号の数を決めるために使う。
620
+ 例:
621
+ `私はそう思う!って感じ?` →
622
+ ["私", "は", "そう", "思う", "!", "って", "感じ", "?"], ["ワタシ", "ワ", "ソー", "オモウ", "!", "ッテ", "カンジ", "?"]
623
+ """
624
+ # parsed: OpenJTalkの解析結果
625
+ sep_text: list[str] = []
626
+ sep_kata: list[str] = []
627
+ fix_parsed = []
628
+ i = 0
629
+ while i <= len(parsed) - 1:
630
+ # word: 実際の単語の文字列
631
+ # yomi: その読み、但し無声化サインの`’`は除去
632
+ # print(parsed)
633
+ yomi = parsed[i]["pron"]
634
+ tmp_parsed = parsed[i]
635
+ if i != len(parsed) - 1 and parsed[i + 1]["string"] in [
636
+ "々",
637
+ "ゝ",
638
+ "ヽ",
639
+ "ゞ",
640
+ "ヾ",
641
+ "゛",
642
+ ]:
643
+ word = parsed[i]["string"] + parsed[i + 1]["string"]
644
+ i += 1
645
+ else:
646
+ word = parsed[i]["string"]
647
+ word, yomi = replace_punctuation(word), yomi.replace("’", "")
648
+ """
649
+ ここで`yomi`の取りうる値は以下の通りのはず。
650
+ - `word`が通常単語 → 通常の読み(カタカナ)
651
+ (カタカナからなり、長音記号も含みうる、`アー` 等)
652
+ - `word`が`ー` から始まる → `ーラー` や `ーーー` など
653
+ - `word`が句読点や空白等 → `、`
654
+ - `word`が`?` → `?`(全角になる)
655
+ 他にも`word`が読めないキリル文字アラビア文字等が来ると`、`になるが、正規化でこの場合は起きないはず。
656
+ また元のコードでは`yomi`が空白の場合の処理があったが、これは起きないはず。
657
+ 処理すべきは`yomi`が`、`の場合のみのはず。
658
+ """
659
+ assert yomi != "", f"Empty yomi: {word}"
660
+ if yomi == "、":
661
+ # wordは正規化されているので、`.`, `,`, `!`, `'`, `-`のいずれか
662
+ if word not in (
663
+ ".",
664
+ ",",
665
+ "!",
666
+ "'",
667
+ "-",
668
+ "?",
669
+ ":",
670
+ ";",
671
+ "…",
672
+ "",
673
+ ):
674
+ # ここはpyopenjtalkが読めない文字等のときに起こる
675
+ #print(
676
+ # "{}Cannot read:{}, yomi:{}, new_word:{};".format(
677
+ # parsed, word, yomi, self.japan_JH2K.convert(word)[0]["kana"]
678
+ # )
679
+ #)
680
+ # raise ValueError(word)
681
+ word = self.japan_JH2K.convert(word)[0]["kana"]
682
+ # print(word, self.japan_JH2K.convert(word)[0]['kana'], kata2phoneme_list(self.japan_JH2K.convert(word)[0]['kana']))
683
+ tmp_parsed["pron"] = word
684
+ # yomi = "-"
685
+ # word = ','
686
+ # yomiは元の記号のままに変更
687
+ # else:
688
+ # parsed[i]['pron'] = parsed[i]["string"]
689
+ yomi = word
690
+ elif yomi == "?":
691
+ assert word == "?", f"yomi `?` comes from: {word}"
692
+ yomi = "?"
693
+ if word == "":
694
+ i += 1
695
+ continue
696
+ sep_text.append(word)
697
+ sep_kata.append(yomi)
698
+ # print(word, yomi, parts)
699
+ fix_parsed.append(tmp_parsed)
700
+ i += 1
701
+ # print(sep_text, sep_kata)
702
+ return sep_text, sep_kata, fix_parsed
703
+
704
+ def getSentencePhone(self, sentence, blank_mode=True, phoneme_mode=False):
705
+ # print("origin:", sentence)
706
+ words = []
707
+ words_phone_len = []
708
+ short_char_flag = False
709
+ output_duration_flag = []
710
+ output_before_sil_flag = []
711
+ normed_text = []
712
+ sentence = sentence.strip().strip("'")
713
+ sentence = re.sub(r"\s+", "", sentence)
714
+ output_res = []
715
+ failed_words = []
716
+ last_long_pause = 4
717
+ last_word = None
718
+ frontend_text = pyopenjtalk.run_frontend(sentence)
719
+ # print("frontend_text: ", frontend_text)
720
+ try:
721
+ frontend_text = pyopenjtalk.estimate_accent(frontend_text)
722
+ except:
723
+ pass
724
+ # print("estimate_accent: ", frontend_text)
725
+ # sep_text: 単語単位の単語のリスト
726
+ # sep_kata: 単語単位の単語のカタカナ読みのリスト
727
+ sep_text, sep_kata, frontend_text = self.text2sep_kata(frontend_text)
728
+ # print("sep_text: ", sep_text)
729
+ # print("sep_kata: ", sep_kata)
730
+ # print("frontend_text: ", frontend_text)
731
+ # sep_phonemes: 各単語ご���の音素のリストのリスト
732
+ sep_phonemes = handle_long_word([kata2phoneme_list(i) for i in sep_kata])
733
+ # print("sep_phonemes: ", sep_phonemes)
734
+
735
+ pron_text = [x["pron"].strip().replace("’", "") for x in frontend_text]
736
+ # pdb.set_trace()
737
+ prosodys = pyopenjtalk.make_label(frontend_text)
738
+ prosodys = frontend2phoneme(prosodys, drop_unvoiced_vowels=True)
739
+ # print("prosodys: ", ' '.join(prosodys))
740
+ # print("pron_text: ", pron_text)
741
+ normed_text = [x["string"].strip() for x in frontend_text]
742
+ # punctuationがすべて消えた、音素とアクセントのタプルのリスト
743
+ phone_tone_list_wo_punct = g2phone_tone_wo_punct(prosodys)
744
+ # print("phone_tone_list_wo_punct: ", phone_tone_list_wo_punct)
745
+
746
+ # phone_w_punct: sep_phonemesを結合した、punctuationを元のまま保持した音素列
747
+ phone_w_punct: list[str] = []
748
+ w_p_len = []
749
+ for i in sep_phonemes:
750
+ phone_w_punct += i
751
+ w_p_len.append(len(i))
752
+ phone_w_punct = phone_w_punct[:-1]
753
+ # punctuation無しのアクセント情報を使って、punctuationを含めたアクセント情報を作る
754
+ # print("phone_w_punct: ", phone_w_punct)
755
+ # print("phone_tone_list_wo_punct: ", phone_tone_list_wo_punct)
756
+ phone_tone_list = align_tones(phone_w_punct, phone_tone_list_wo_punct)
757
+
758
+ jp_item = {}
759
+ jp_p = ""
760
+ jp_t = ""
761
+ # mye rye pye bye nye
762
+ # je she
763
+ # print(phone_tone_list)
764
+ for p, t in phone_tone_list:
765
+ if p in self.ipa_dict:
766
+ curr_p = self.ipa_dict[p]
767
+ jp_p += curr_p
768
+ jp_t += str(t + 6) * len(curr_p)
769
+ elif p in punctuation:
770
+ jp_p += p
771
+ jp_t += "0"
772
+ elif p == "▁":
773
+ jp_p += p
774
+ jp_t += " "
775
+ else:
776
+ print(p, t)
777
+ jp_p += "|"
778
+ jp_t += "0"
779
+ # return phones, tones, w_p_len
780
+ jp_p = jp_p.replace("▁", " ")
781
+ jp_t = jp_t.translate(self.table)
782
+ jp_l = ""
783
+ for t in jp_t:
784
+ if t == " ":
785
+ jp_l += " "
786
+ else:
787
+ jp_l += "2"
788
+ # print(jp_p)
789
+ # print(jp_t)
790
+ # print(jp_l)
791
+ # print(len(jp_p_len), sum(w_p_len), len(jp_p), sum(jp_p_len))
792
+ assert len(jp_p) == len(jp_t) and len(jp_p) == len(jp_l)
793
+
794
+ jp_item["jp_p"] = jp_p.replace("| |", "|").rstrip("|")
795
+ jp_item["jp_t"] = jp_t
796
+ jp_item["jp_l"] = jp_l
797
+ jp_item["jp_normed_text"] = " ".join(normed_text)
798
+ jp_item["jp_pron_text"] = " ".join(pron_text)
799
+ # jp_item['jp_ruoma'] = sep_phonemes
800
+ # print(len(normed_text), len(sep_phonemes))
801
+ # print(normed_text)
802
+ return jp_item
803
+
804
+
805
+ jpc = JapanesePhoneConverter()
806
+
807
+
808
+ def japanese_to_ipa(text, text_tokenizer):
809
+ # phonemes = text_tokenizer(text)
810
+ if type(text) == str:
811
+ return jpc.getSentencePhone(text)["jp_p"]
812
+ else:
813
+ result_ph = []
814
+ for t in text:
815
+ result_ph.append(jpc.getSentencePhone(t)["jp_p"])
816
+ return result_ph
soundsation/g2p/g2p/korean.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+ """
9
+ Text clean time
10
+ """
11
+ english_dictionary = {
12
+ "KOREA": "코리아",
13
+ "IDOL": "아이돌",
14
+ "IT": "아이티",
15
+ "IQ": "아이큐",
16
+ "UP": "업",
17
+ "DOWN": "다운",
18
+ "PC": "피씨",
19
+ "CCTV": "씨씨티비",
20
+ "SNS": "에스엔에스",
21
+ "AI": "에이아이",
22
+ "CEO": "씨이오",
23
+ "A": "에이",
24
+ "B": "비",
25
+ "C": "씨",
26
+ "D": "디",
27
+ "E": "이",
28
+ "F": "에프",
29
+ "G": "지",
30
+ "H": "에이치",
31
+ "I": "아이",
32
+ "J": "제이",
33
+ "K": "케이",
34
+ "L": "엘",
35
+ "M": "엠",
36
+ "N": "엔",
37
+ "O": "오",
38
+ "P": "피",
39
+ "Q": "큐",
40
+ "R": "알",
41
+ "S": "에스",
42
+ "T": "티",
43
+ "U": "유",
44
+ "V": "브이",
45
+ "W": "더블유",
46
+ "X": "엑스",
47
+ "Y": "와이",
48
+ "Z": "제트",
49
+ }
50
+
51
+
52
+ def normalize(text):
53
+ text = text.strip()
54
+ text = re.sub(
55
+ "[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]", "", text
56
+ )
57
+ text = normalize_english(text)
58
+ text = text.lower()
59
+ return text
60
+
61
+
62
+ def normalize_english(text):
63
+ def fn(m):
64
+ word = m.group()
65
+ if word in english_dictionary:
66
+ return english_dictionary.get(word)
67
+ return word
68
+
69
+ text = re.sub("([A-Za-z]+)", fn, text)
70
+ return text
71
+
72
+
73
+ def korean_to_ipa(text, text_tokenizer):
74
+ if type(text) == str:
75
+ text = normalize(text)
76
+ phonemes = text_tokenizer(text)
77
+ return phonemes
78
+ else:
79
+ for i, t in enumerate(text):
80
+ text[i] = normalize(t)
81
+ return text_tokenizer(text)
soundsation/g2p/g2p/mandarin.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ import jieba
8
+ import cn2an
9
+ from pypinyin import lazy_pinyin, BOPOMOFO
10
+ from typing import List
11
+ from soundsation.g2p.g2p.chinese_model_g2p import BertPolyPredict
12
+ from soundsation.g2p.utils.front_utils import *
13
+ import os
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ # from g2pw import G2PWConverter
17
+
18
+
19
+ # set blank level, {0:"none",1:"char", 2:"word"}
20
+ BLANK_LEVEL = 0
21
+
22
+ # conv = G2PWConverter(style='pinyin', enable_non_tradional_chinese=True)
23
+ resource_path = r"./soundsation/g2p"
24
+ poly_all_class_path = os.path.join(
25
+ resource_path, "sources", "g2p_chinese_model", "polychar.txt"
26
+ )
27
+ if not os.path.exists(poly_all_class_path):
28
+ print(
29
+ "Incorrect path for polyphonic character class dictionary: {}, please check...".format(
30
+ poly_all_class_path
31
+ )
32
+ )
33
+ exit()
34
+ poly_dict = generate_poly_lexicon(poly_all_class_path)
35
+
36
+ # Set up G2PW model parameters
37
+ g2pw_poly_model_path = os.path.join(resource_path, "sources", "g2p_chinese_model")
38
+ if not os.path.exists(g2pw_poly_model_path):
39
+ print(
40
+ "Incorrect path for g2pw polyphonic character model: {}, please check...".format(
41
+ g2pw_poly_model_path
42
+ )
43
+ )
44
+ exit()
45
+
46
+ json_file_path = os.path.join(
47
+ resource_path, "sources", "g2p_chinese_model", "polydict.json"
48
+ )
49
+ if not os.path.exists(json_file_path):
50
+ print(
51
+ "Incorrect path for g2pw id to pinyin dictionary: {}, please check...".format(
52
+ json_file_path
53
+ )
54
+ )
55
+ exit()
56
+
57
+ jsonr_file_path = os.path.join(
58
+ resource_path, "sources", "g2p_chinese_model", "polydict_r.json"
59
+ )
60
+ if not os.path.exists(jsonr_file_path):
61
+ print(
62
+ "Incorrect path for g2pw pinyin to id dictionary: {}, please check...".format(
63
+ jsonr_file_path
64
+ )
65
+ )
66
+ exit()
67
+
68
+ g2pw_poly_predict = BertPolyPredict(
69
+ g2pw_poly_model_path, jsonr_file_path, json_file_path
70
+ )
71
+
72
+
73
+ """
74
+ Text clean time
75
+ """
76
+ # List of (Latin alphabet, bopomofo) pairs:
77
+ _latin_to_bopomofo = [
78
+ (re.compile("%s" % x[0], re.IGNORECASE), x[1])
79
+ for x in [
80
+ ("a", "ㄟˉ"),
81
+ ("b", "ㄅㄧˋ"),
82
+ ("c", "ㄙㄧˉ"),
83
+ ("d", "ㄉㄧˋ"),
84
+ ("e", "ㄧˋ"),
85
+ ("f", "ㄝˊㄈㄨˋ"),
86
+ ("g", "ㄐㄧˋ"),
87
+ ("h", "ㄝˇㄑㄩˋ"),
88
+ ("i", "ㄞˋ"),
89
+ ("j", "ㄐㄟˋ"),
90
+ ("k", "ㄎㄟˋ"),
91
+ ("l", "ㄝˊㄛˋ"),
92
+ ("m", "ㄝˊㄇㄨˋ"),
93
+ ("n", "ㄣˉ"),
94
+ ("o", "ㄡˉ"),
95
+ ("p", "ㄆㄧˉ"),
96
+ ("q", "ㄎㄧㄡˉ"),
97
+ ("r", "ㄚˋ"),
98
+ ("s", "ㄝˊㄙˋ"),
99
+ ("t", "ㄊㄧˋ"),
100
+ ("u", "ㄧㄡˉ"),
101
+ ("v", "ㄨㄧˉ"),
102
+ ("w", "ㄉㄚˋㄅㄨˋㄌㄧㄡˋ"),
103
+ ("x", "ㄝˉㄎㄨˋㄙˋ"),
104
+ ("y", "ㄨㄞˋ"),
105
+ ("z", "ㄗㄟˋ"),
106
+ ]
107
+ ]
108
+
109
+ # List of (bopomofo, ipa) pairs:
110
+ _bopomofo_to_ipa = [
111
+ (re.compile("%s" % x[0]), x[1])
112
+ for x in [
113
+ ("ㄅㄛ", "p⁼wo"),
114
+ ("ㄆㄛ", "pʰwo"),
115
+ ("ㄇㄛ", "mwo"),
116
+ ("ㄈㄛ", "fwo"),
117
+ ("ㄧㄢ", "|jɛn"),
118
+ ("ㄩㄢ", "|ɥæn"),
119
+ ("ㄧㄣ", "|in"),
120
+ ("ㄩㄣ", "|ɥn"),
121
+ ("ㄧㄥ", "|iŋ"),
122
+ ("ㄨㄥ", "|ʊŋ"),
123
+ ("ㄩㄥ", "|jʊŋ"),
124
+ # Add
125
+ ("ㄧㄚ", "|ia"),
126
+ ("ㄧㄝ", "|iɛ"),
127
+ ("ㄧㄠ", "|iɑʊ"),
128
+ ("ㄧㄡ", "|ioʊ"),
129
+ ("ㄧㄤ", "|iɑŋ"),
130
+ ("ㄨㄚ", "|ua"),
131
+ ("ㄨㄛ", "|uo"),
132
+ ("ㄨㄞ", "|uaɪ"),
133
+ ("ㄨㄟ", "|ueɪ"),
134
+ ("ㄨㄢ", "|uan"),
135
+ ("ㄨㄣ", "|uən"),
136
+ ("ㄨㄤ", "|uɑŋ"),
137
+ ("ㄩㄝ", "|ɥɛ"),
138
+ # End
139
+ ("ㄅ", "p⁼"),
140
+ ("ㄆ", "pʰ"),
141
+ ("ㄇ", "m"),
142
+ ("ㄈ", "f"),
143
+ ("ㄉ", "t⁼"),
144
+ ("ㄊ", "tʰ"),
145
+ ("ㄋ", "n"),
146
+ ("ㄌ", "l"),
147
+ ("ㄍ", "k⁼"),
148
+ ("ㄎ", "kʰ"),
149
+ ("ㄏ", "x"),
150
+ ("ㄐ", "tʃ⁼"),
151
+ ("ㄑ", "tʃʰ"),
152
+ ("ㄒ", "ʃ"),
153
+ ("ㄓ", "ts`⁼"),
154
+ ("ㄔ", "ts`ʰ"),
155
+ ("ㄕ", "s`"),
156
+ ("ㄖ", "ɹ`"),
157
+ ("ㄗ", "ts⁼"),
158
+ ("ㄘ", "tsʰ"),
159
+ ("ㄙ", "|s"),
160
+ ("ㄚ", "|a"),
161
+ ("ㄛ", "|o"),
162
+ ("ㄜ", "|ə"),
163
+ ("ㄝ", "|ɛ"),
164
+ ("ㄞ", "|aɪ"),
165
+ ("ㄟ", "|eɪ"),
166
+ ("ㄠ", "|ɑʊ"),
167
+ ("ㄡ", "|oʊ"),
168
+ ("ㄢ", "|an"),
169
+ ("ㄣ", "|ən"),
170
+ ("ㄤ", "|ɑŋ"),
171
+ ("ㄥ", "|əŋ"),
172
+ ("ㄦ", "əɹ"),
173
+ ("ㄧ", "|i"),
174
+ ("ㄨ", "|u"),
175
+ ("ㄩ", "|ɥ"),
176
+ ("ˉ", "→|"),
177
+ ("ˊ", "↑|"),
178
+ ("ˇ", "↓↑|"),
179
+ ("ˋ", "↓|"),
180
+ ("˙", "|"),
181
+ ]
182
+ ]
183
+ must_not_er_words = {"女儿", "老儿", "男儿", "少儿", "小儿"}
184
+
185
+
186
+ chinese_lexicon_path = hf_hub_download(
187
+ repo_id="josephchay/Soundsation",
188
+ filename="soundsation/g2p/sources/chinese_lexicon.txt",
189
+ repo_type="space"
190
+ )
191
+ word_pinyin_dict = {}
192
+ with open(chinese_lexicon_path, "r", encoding="utf-8") as fread:
193
+ txt_list = fread.readlines()
194
+ for txt in txt_list:
195
+ word, pinyin = txt.strip().split("\t")
196
+ word_pinyin_dict[word] = pinyin
197
+ fread.close()
198
+
199
+ pinyin_2_bopomofo_dict = {}
200
+ with open(
201
+ r"./soundsation/g2p/sources/pinyin_2_bpmf.txt", "r", encoding="utf-8"
202
+ ) as fread:
203
+ txt_list = fread.readlines()
204
+ for txt in txt_list:
205
+ pinyin, bopomofo = txt.strip().split("\t")
206
+ pinyin_2_bopomofo_dict[pinyin] = bopomofo
207
+ fread.close()
208
+
209
+ tone_dict = {
210
+ "0": "˙",
211
+ "5": "˙",
212
+ "1": "",
213
+ "2": "ˊ",
214
+ "3": "ˇ",
215
+ "4": "ˋ",
216
+ }
217
+
218
+ bopomofos2pinyin_dict = {}
219
+ with open(
220
+ r"./soundsation/g2p/sources/bpmf_2_pinyin.txt", "r", encoding="utf-8"
221
+ ) as fread:
222
+ txt_list = fread.readlines()
223
+ for txt in txt_list:
224
+ v, k = txt.strip().split("\t")
225
+ bopomofos2pinyin_dict[k] = v
226
+ fread.close()
227
+
228
+
229
+ def bpmf_to_pinyin(text):
230
+ bopomofo_list = text.split("|")
231
+ pinyin_list = []
232
+ for info in bopomofo_list:
233
+ pinyin = ""
234
+ for c in info:
235
+ if c in bopomofos2pinyin_dict:
236
+ pinyin += bopomofos2pinyin_dict[c]
237
+ if len(pinyin) == 0:
238
+ continue
239
+ if pinyin[-1] not in "01234":
240
+ pinyin += "1"
241
+ if pinyin[:-1] == "ve":
242
+ pinyin = "y" + pinyin
243
+ if pinyin[:-1] == "sh":
244
+ pinyin = pinyin[:-1] + "i" + pinyin[-1]
245
+ if pinyin == "sh":
246
+ pinyin = pinyin[:-1] + "i"
247
+ if pinyin[:-1] == "s":
248
+ pinyin = "si" + pinyin[-1]
249
+ if pinyin[:-1] == "c":
250
+ pinyin = "ci" + pinyin[-1]
251
+ if pinyin[:-1] == "i":
252
+ pinyin = "yi" + pinyin[-1]
253
+ if pinyin[:-1] == "iou":
254
+ pinyin = "you" + pinyin[-1]
255
+ if pinyin[:-1] == "ien":
256
+ pinyin = "yin" + pinyin[-1]
257
+ if "iou" in pinyin and pinyin[-4:-1] == "iou":
258
+ pinyin = pinyin[:-4] + "iu" + pinyin[-1]
259
+ if "uei" in pinyin:
260
+ if pinyin[:-1] == "uei":
261
+ pinyin = "wei" + pinyin[-1]
262
+ elif pinyin[-4:-1] == "uei":
263
+ pinyin = pinyin[:-4] + "ui" + pinyin[-1]
264
+ if "uen" in pinyin and pinyin[-4:-1] == "uen":
265
+ if pinyin[:-1] == "uen":
266
+ pinyin = "wen" + pinyin[-1]
267
+ elif pinyin[-4:-1] == "uei":
268
+ pinyin = pinyin[:-4] + "un" + pinyin[-1]
269
+ if "van" in pinyin and pinyin[-4:-1] == "van":
270
+ if pinyin[:-1] == "van":
271
+ pinyin = "yuan" + pinyin[-1]
272
+ elif pinyin[-4:-1] == "van":
273
+ pinyin = pinyin[:-4] + "uan" + pinyin[-1]
274
+ if "ueng" in pinyin and pinyin[-5:-1] == "ueng":
275
+ pinyin = pinyin[:-5] + "ong" + pinyin[-1]
276
+ if pinyin[:-1] == "veng":
277
+ pinyin = "yong" + pinyin[-1]
278
+ if "veng" in pinyin and pinyin[-5:-1] == "veng":
279
+ pinyin = pinyin[:-5] + "iong" + pinyin[-1]
280
+ if pinyin[:-1] == "ieng":
281
+ pinyin = "ying" + pinyin[-1]
282
+ if pinyin[:-1] == "u":
283
+ pinyin = "wu" + pinyin[-1]
284
+ if pinyin[:-1] == "v":
285
+ pinyin = "yv" + pinyin[-1]
286
+ if pinyin[:-1] == "ing":
287
+ pinyin = "ying" + pinyin[-1]
288
+ if pinyin[:-1] == "z":
289
+ pinyin = "zi" + pinyin[-1]
290
+ if pinyin[:-1] == "zh":
291
+ pinyin = "zhi" + pinyin[-1]
292
+ if pinyin[0] == "u":
293
+ pinyin = "w" + pinyin[1:]
294
+ if pinyin[0] == "i":
295
+ pinyin = "y" + pinyin[1:]
296
+ pinyin = pinyin.replace("ien", "in")
297
+
298
+ pinyin_list.append(pinyin)
299
+ return " ".join(pinyin_list)
300
+
301
+
302
+ # Convert numbers to Chinese pronunciation
303
+ def number_to_chinese(text):
304
+ # numbers = re.findall(r'\d+(?:\.?\d+)?', text)
305
+ # for number in numbers:
306
+ # text = text.replace(number, cn2an.an2cn(number), 1)
307
+ text = cn2an.transform(text, "an2cn")
308
+ return text
309
+
310
+
311
+ def normalization(text):
312
+ text = text.replace(",", ",")
313
+ text = text.replace("。", ".")
314
+ text = text.replace("!", "!")
315
+ text = text.replace("?", "?")
316
+ text = text.replace(";", ";")
317
+ text = text.replace(":", ":")
318
+ text = text.replace("、", ",")
319
+ text = text.replace("‘", "'")
320
+ text = text.replace("’", "'")
321
+ text = text.replace("⋯", "…")
322
+ text = text.replace("···", "…")
323
+ text = text.replace("・・・", "…")
324
+ text = text.replace("...", "…")
325
+ text = re.sub(r"\s+", "", text)
326
+ text = re.sub(r"[^\u4e00-\u9fff\s_,\.\?!;:\'…]", "", text)
327
+ text = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", text)
328
+ return text
329
+
330
+
331
+ def change_tone(bopomofo: str, tone: str) -> str:
332
+ if bopomofo[-1] not in "˙ˊˇˋ":
333
+ bopomofo = bopomofo + tone
334
+ else:
335
+ bopomofo = bopomofo[:-1] + tone
336
+ return bopomofo
337
+
338
+
339
+ def er_sandhi(word: str, bopomofos: List[str]) -> List[str]:
340
+ if len(word) > 1 and word[-1] == "儿" and word not in must_not_er_words:
341
+ bopomofos[-1] = change_tone(bopomofos[-1], "˙")
342
+ return bopomofos
343
+
344
+
345
+ def bu_sandhi(word: str, bopomofos: List[str]) -> List[str]:
346
+ valid_char = set(word)
347
+ if len(valid_char) == 1 and "不" in valid_char:
348
+ pass
349
+ elif word in ["不字"]:
350
+ pass
351
+ elif len(word) == 3 and word[1] == "不" and bopomofos[1][:-1] == "ㄅㄨ":
352
+ bopomofos[1] = bopomofos[1][:-1] + "˙"
353
+ else:
354
+ for i, char in enumerate(word):
355
+ if (
356
+ i + 1 < len(bopomofos)
357
+ and char == "不"
358
+ and i + 1 < len(word)
359
+ and 0 < len(bopomofos[i + 1])
360
+ and bopomofos[i + 1][-1] == "ˋ"
361
+ ):
362
+ bopomofos[i] = bopomofos[i][:-1] + "ˊ"
363
+ return bopomofos
364
+
365
+
366
+ def yi_sandhi(word: str, bopomofos: List[str]) -> List[str]:
367
+ punc = ":,;。?!“”‘’':,;.?!()(){}【】[]-~`、 "
368
+ if word.find("一") != -1 and any(
369
+ [item.isnumeric() for item in word if item != "一"]
370
+ ):
371
+ for i in range(len(word)):
372
+ if (
373
+ i == 0
374
+ and word[0] == "一"
375
+ and len(word) > 1
376
+ and word[1]
377
+ not in [
378
+ "零",
379
+ "一",
380
+ "二",
381
+ "三",
382
+ "四",
383
+ "五",
384
+ "六",
385
+ "七",
386
+ "八",
387
+ "九",
388
+ "十",
389
+ ]
390
+ ):
391
+ if len(bopomofos[0]) > 0 and bopomofos[1][-1] in ["ˋ", "˙"]:
392
+ bopomofos[0] = change_tone(bopomofos[0], "ˊ")
393
+ else:
394
+ bopomofos[0] = change_tone(bopomofos[0], "ˋ")
395
+ elif word[i] == "一":
396
+ bopomofos[i] = change_tone(bopomofos[i], "")
397
+ return bopomofos
398
+ elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
399
+ bopomofos[1] = change_tone(bopomofos[1], "˙")
400
+ elif word.startswith("第一"):
401
+ bopomofos[1] = change_tone(bopomofos[1], "")
402
+ elif word.startswith("一月") or word.startswith("一日") or word.startswith("一号"):
403
+ bopomofos[0] = change_tone(bopomofos[0], "")
404
+ else:
405
+ for i, char in enumerate(word):
406
+ if char == "一" and i + 1 < len(word):
407
+ if (
408
+ len(bopomofos) > i + 1
409
+ and len(bopomofos[i + 1]) > 0
410
+ and bopomofos[i + 1][-1] in {"ˋ"}
411
+ ):
412
+ bopomofos[i] = change_tone(bopomofos[i], "ˊ")
413
+ else:
414
+ if word[i + 1] not in punc:
415
+ bopomofos[i] = change_tone(bopomofos[i], "ˋ")
416
+ else:
417
+ pass
418
+ return bopomofos
419
+
420
+
421
+ def merge_bu(seg: List) -> List:
422
+ new_seg = []
423
+ last_word = ""
424
+ for word in seg:
425
+ if word != "不":
426
+ if last_word == "不":
427
+ word = last_word + word
428
+ new_seg.append(word)
429
+ last_word = word
430
+ return new_seg
431
+
432
+
433
+ def merge_er(seg: List) -> List:
434
+ new_seg = []
435
+ for i, word in enumerate(seg):
436
+ if i - 1 >= 0 and word == "儿":
437
+ new_seg[-1] = new_seg[-1] + seg[i]
438
+ else:
439
+ new_seg.append(word)
440
+ return new_seg
441
+
442
+
443
+ def merge_yi(seg: List) -> List:
444
+ new_seg = []
445
+ # function 1
446
+ for i, word in enumerate(seg):
447
+ if (
448
+ i - 1 >= 0
449
+ and word == "一"
450
+ and i + 1 < len(seg)
451
+ and seg[i - 1] == seg[i + 1]
452
+ ):
453
+ if i - 1 < len(new_seg):
454
+ new_seg[i - 1] = new_seg[i - 1] + "一" + new_seg[i - 1]
455
+ else:
456
+ new_seg.append(word)
457
+ new_seg.append(seg[i + 1])
458
+ else:
459
+ if i - 2 >= 0 and seg[i - 1] == "一" and seg[i - 2] == word:
460
+ continue
461
+ else:
462
+ new_seg.append(word)
463
+ seg = new_seg
464
+ new_seg = []
465
+ isnumeric_flag = False
466
+ for i, word in enumerate(seg):
467
+ if all([item.isnumeric() for item in word]) and not isnumeric_flag:
468
+ isnumeric_flag = True
469
+ new_seg.append(word)
470
+ else:
471
+ new_seg.append(word)
472
+ seg = new_seg
473
+ new_seg = []
474
+ # function 2
475
+ for i, word in enumerate(seg):
476
+ if new_seg and new_seg[-1] == "一":
477
+ new_seg[-1] = new_seg[-1] + word
478
+ else:
479
+ new_seg.append(word)
480
+ return new_seg
481
+
482
+
483
+ # Word Segmentation, and convert Chinese pronunciation to pinyin (bopomofo)
484
+ def chinese_to_bopomofo(text_short, sentence):
485
+ # bopomofos = conv(text_short)
486
+ words = jieba.lcut(text_short, cut_all=False)
487
+ words = merge_yi(words)
488
+ words = merge_bu(words)
489
+ words = merge_er(words)
490
+ text = ""
491
+
492
+ char_index = 0
493
+ for word in words:
494
+ bopomofos = []
495
+ if word in word_pinyin_dict and word not in poly_dict:
496
+ pinyin = word_pinyin_dict[word]
497
+ for py in pinyin.split(" "):
498
+ if py[:-1] in pinyin_2_bopomofo_dict and py[-1] in tone_dict:
499
+ bopomofos.append(
500
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
501
+ )
502
+ if BLANK_LEVEL == 1:
503
+ bopomofos.append("_")
504
+ else:
505
+ bopomofos_lazy = lazy_pinyin(word, BOPOMOFO)
506
+ bopomofos += bopomofos_lazy
507
+ if BLANK_LEVEL == 1:
508
+ bopomofos.append("_")
509
+ else:
510
+ for i in range(len(word)):
511
+ c = word[i]
512
+ if c in poly_dict:
513
+ poly_pinyin = g2pw_poly_predict.predict_process(
514
+ [text_short, char_index + i]
515
+ )[0]
516
+ py = poly_pinyin[2:-1]
517
+ bopomofos.append(
518
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
519
+ )
520
+ if BLANK_LEVEL == 1:
521
+ bopomofos.append("_")
522
+ elif c in word_pinyin_dict:
523
+ py = word_pinyin_dict[c]
524
+ bopomofos.append(
525
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
526
+ )
527
+ if BLANK_LEVEL == 1:
528
+ bopomofos.append("_")
529
+ else:
530
+ bopomofos.append(c)
531
+ if BLANK_LEVEL == 1:
532
+ bopomofos.append("_")
533
+ if BLANK_LEVEL == 2:
534
+ bopomofos.append("_")
535
+ char_index += len(word)
536
+
537
+ if (
538
+ len(word) == 3
539
+ and bopomofos[0][-1] == "ˇ"
540
+ and bopomofos[1][-1] == "ˇ"
541
+ and bopomofos[-1][-1] == "ˇ"
542
+ ):
543
+ bopomofos[0] = bopomofos[0] + "ˊ"
544
+ bopomofos[1] = bopomofos[1] + "ˊ"
545
+ if len(word) == 2 and bopomofos[0][-1] == "ˇ" and bopomofos[-1][-1] == "ˇ":
546
+ bopomofos[0] = bopomofos[0][:-1] + "ˊ"
547
+ bopomofos = bu_sandhi(word, bopomofos)
548
+ bopomofos = yi_sandhi(word, bopomofos)
549
+ bopomofos = er_sandhi(word, bopomofos)
550
+ if not re.search("[\u4e00-\u9fff]", word):
551
+ text += "|" + word
552
+ continue
553
+ for i in range(len(bopomofos)):
554
+ bopomofos[i] = re.sub(r"([\u3105-\u3129])$", r"\1ˉ", bopomofos[i])
555
+ if text != "":
556
+ text += "|"
557
+ text += "|".join(bopomofos)
558
+ return text
559
+
560
+
561
+ # Convert latin pronunciation to pinyin (bopomofo)
562
+ def latin_to_bopomofo(text):
563
+ for regex, replacement in _latin_to_bopomofo:
564
+ text = re.sub(regex, replacement, text)
565
+ return text
566
+
567
+
568
+ # Convert pinyin (bopomofo) to IPA
569
+ def bopomofo_to_ipa(text):
570
+ for regex, replacement in _bopomofo_to_ipa:
571
+ text = re.sub(regex, replacement, text)
572
+ return text
573
+
574
+
575
+ def _chinese_to_ipa(text, sentence):
576
+ text = number_to_chinese(text.strip())
577
+ text = normalization(text)
578
+ text = chinese_to_bopomofo(text, sentence)
579
+ # pinyin = bpmf_to_pinyin(text)
580
+ text = latin_to_bopomofo(text)
581
+ text = bopomofo_to_ipa(text)
582
+ text = re.sub("([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
583
+ text = re.sub("([s][⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
584
+ text = re.sub(r"^\||[^\w\s_,\.\?!;:\'…\|→↓↑⁼ʰ`]", "", text)
585
+ text = re.sub(r"([,\.\?!;:\'…])", r"|\1|", text)
586
+ text = re.sub(r"\|+", "|", text)
587
+ text = text.rstrip("|")
588
+ return text
589
+
590
+
591
+ # Convert Chinese to IPA
592
+ def chinese_to_ipa(text, sentence, text_tokenizer):
593
+ # phonemes = text_tokenizer(text.strip())
594
+ if type(text) == str:
595
+ return _chinese_to_ipa(text, sentence)
596
+ else:
597
+ result_ph = []
598
+ for t in text:
599
+ result_ph.append(_chinese_to_ipa(t, sentence))
600
+ return result_ph
soundsation/g2p/g2p/text_tokenizers.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ import os
8
+ from typing import List, Pattern, Union
9
+ from phonemizer.utils import list2str, str2list
10
+ from phonemizer.backend import EspeakBackend
11
+ from phonemizer.backend.espeak.language_switch import LanguageSwitch
12
+ from phonemizer.backend.espeak.words_mismatch import WordMismatch
13
+ from phonemizer.punctuation import Punctuation
14
+ from phonemizer.separator import Separator
15
+
16
+
17
+ class TextTokenizer:
18
+ """Phonemize Text."""
19
+
20
+ def __init__(
21
+ self,
22
+ language="en-us",
23
+ backend="espeak",
24
+ separator=Separator(word="|_|", syllable="-", phone="|"),
25
+ preserve_punctuation=True,
26
+ with_stress: bool = False,
27
+ tie: Union[bool, str] = False,
28
+ language_switch: LanguageSwitch = "remove-flags",
29
+ words_mismatch: WordMismatch = "ignore",
30
+ ) -> None:
31
+ self.preserve_punctuation_marks = ",.?!;:'…"
32
+ self.backend = EspeakBackend(
33
+ language,
34
+ punctuation_marks=self.preserve_punctuation_marks,
35
+ preserve_punctuation=preserve_punctuation,
36
+ with_stress=with_stress,
37
+ tie=tie,
38
+ language_switch=language_switch,
39
+ words_mismatch=words_mismatch,
40
+ )
41
+
42
+ self.separator = separator
43
+
44
+ # convert chinese punctuation to english punctuation
45
+ def convert_chinese_punctuation(self, text: str) -> str:
46
+ text = text.replace(",", ",")
47
+ text = text.replace("。", ".")
48
+ text = text.replace("!", "!")
49
+ text = text.replace("?", "?")
50
+ text = text.replace(";", ";")
51
+ text = text.replace(":", ":")
52
+ text = text.replace("、", ",")
53
+ text = text.replace("‘", "'")
54
+ text = text.replace("’", "'")
55
+ text = text.replace("⋯", "…")
56
+ text = text.replace("···", "…")
57
+ text = text.replace("・・・", "…")
58
+ text = text.replace("...", "…")
59
+ return text
60
+
61
+ def __call__(self, text, strip=True) -> List[str]:
62
+
63
+ text_type = type(text)
64
+ normalized_text = []
65
+ for line in str2list(text):
66
+ line = self.convert_chinese_punctuation(line.strip())
67
+ line = re.sub(r"[^\w\s_,\.\?!;:\'…]", "", line)
68
+ line = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", line)
69
+ line = re.sub(r"\s+", " ", line)
70
+ normalized_text.append(line)
71
+ # print("Normalized test: ", normalized_text[0])
72
+ phonemized = self.backend.phonemize(
73
+ normalized_text, separator=self.separator, strip=strip, njobs=1
74
+ )
75
+ if text_type == str:
76
+ phonemized = re.sub(r"([,\.\?!;:\'…])", r"|\1|", list2str(phonemized))
77
+ phonemized = re.sub(r"\|+", "|", phonemized)
78
+ phonemized = phonemized.rstrip("|")
79
+ else:
80
+ for i in range(len(phonemized)):
81
+ phonemized[i] = re.sub(r"([,\.\?!;:\'…])", r"|\1|", phonemized[i])
82
+ phonemized[i] = re.sub(r"\|+", "|", phonemized[i])
83
+ phonemized[i] = phonemized[i].rstrip("|")
84
+ return phonemized
soundsation/g2p/g2p/vocab.json ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab": {
3
+ ",": 0,
4
+ ".": 1,
5
+ "?": 2,
6
+ "!": 3,
7
+ "_": 4,
8
+ "iː": 5,
9
+ "ɪ": 6,
10
+ "ɜː": 7,
11
+ "ɚ": 8,
12
+ "oːɹ": 9,
13
+ "ɔː": 10,
14
+ "ɔːɹ": 11,
15
+ "ɑː": 12,
16
+ "uː": 13,
17
+ "ʊ": 14,
18
+ "ɑːɹ": 15,
19
+ "ʌ": 16,
20
+ "ɛ": 17,
21
+ "æ": 18,
22
+ "eɪ": 19,
23
+ "aɪ": 20,
24
+ "ɔɪ": 21,
25
+ "aʊ": 22,
26
+ "oʊ": 23,
27
+ "ɪɹ": 24,
28
+ "ɛɹ": 25,
29
+ "ʊɹ": 26,
30
+ "p": 27,
31
+ "b": 28,
32
+ "t": 29,
33
+ "d": 30,
34
+ "k": 31,
35
+ "ɡ": 32,
36
+ "f": 33,
37
+ "v": 34,
38
+ "θ": 35,
39
+ "ð": 36,
40
+ "s": 37,
41
+ "z": 38,
42
+ "ʃ": 39,
43
+ "ʒ": 40,
44
+ "h": 41,
45
+ "tʃ": 42,
46
+ "dʒ": 43,
47
+ "m": 44,
48
+ "n": 45,
49
+ "ŋ": 46,
50
+ "j": 47,
51
+ "w": 48,
52
+ "ɹ": 49,
53
+ "l": 50,
54
+ "tɹ": 51,
55
+ "dɹ": 52,
56
+ "ts": 53,
57
+ "dz": 54,
58
+ "i": 55,
59
+ "ɔ": 56,
60
+ "ə": 57,
61
+ "ɾ": 58,
62
+ "iə": 59,
63
+ "r": 60,
64
+ "u": 61,
65
+ "oː": 62,
66
+ "ɛː": 63,
67
+ "ɪː": 64,
68
+ "aɪə": 65,
69
+ "aɪɚ": 66,
70
+ "ɑ̃": 67,
71
+ "ç": 68,
72
+ "ɔ̃": 69,
73
+ "ææ": 70,
74
+ "ɐɐ": 71,
75
+ "ɡʲ": 72,
76
+ "nʲ": 73,
77
+ "iːː": 74,
78
+
79
+ "p⁼": 75,
80
+ "pʰ": 76,
81
+ "t⁼": 77,
82
+ "tʰ": 78,
83
+ "k⁼": 79,
84
+ "kʰ": 80,
85
+ "x": 81,
86
+ "tʃ⁼": 82,
87
+ "tʃʰ": 83,
88
+ "ts`⁼": 84,
89
+ "ts`ʰ": 85,
90
+ "s`": 86,
91
+ "ɹ`": 87,
92
+ "ts⁼": 88,
93
+ "tsʰ": 89,
94
+ "p⁼wo": 90,
95
+ "p⁼wo→": 91,
96
+ "p⁼wo↑": 92,
97
+ "p⁼wo↓↑": 93,
98
+ "p⁼wo↓": 94,
99
+ "pʰwo": 95,
100
+ "pʰwo→": 96,
101
+ "pʰwo↑": 97,
102
+ "pʰwo↓↑": 98,
103
+ "pʰwo↓": 99,
104
+ "mwo": 100,
105
+ "mwo→": 101,
106
+ "mwo↑": 102,
107
+ "mwo↓↑": 103,
108
+ "mwo↓": 104,
109
+ "fwo": 105,
110
+ "fwo→": 106,
111
+ "fwo↑": 107,
112
+ "fwo↓↑": 108,
113
+ "fwo↓": 109,
114
+ "jɛn": 110,
115
+ "jɛn→": 111,
116
+ "jɛn↑": 112,
117
+ "jɛn↓↑": 113,
118
+ "jɛn↓": 114,
119
+ "ɥæn": 115,
120
+ "ɥæn→": 116,
121
+ "ɥæn↑": 117,
122
+ "ɥæn↓↑": 118,
123
+ "ɥæn↓": 119,
124
+ "in": 120,
125
+ "in→": 121,
126
+ "in↑": 122,
127
+ "in↓↑": 123,
128
+ "in↓": 124,
129
+ "ɥn": 125,
130
+ "ɥn→": 126,
131
+ "ɥn↑": 127,
132
+ "ɥn↓↑": 128,
133
+ "ɥn↓": 129,
134
+ "iŋ": 130,
135
+ "iŋ→": 131,
136
+ "iŋ↑": 132,
137
+ "iŋ↓↑": 133,
138
+ "iŋ↓": 134,
139
+ "ʊŋ": 135,
140
+ "ʊŋ→": 136,
141
+ "ʊŋ↑": 137,
142
+ "ʊŋ↓↑": 138,
143
+ "ʊŋ↓": 139,
144
+ "jʊŋ": 140,
145
+ "jʊŋ→": 141,
146
+ "jʊŋ↑": 142,
147
+ "jʊŋ↓↑": 143,
148
+ "jʊŋ↓": 144,
149
+ "ia": 145,
150
+ "ia→": 146,
151
+ "ia↑": 147,
152
+ "ia↓↑": 148,
153
+ "ia↓": 149,
154
+ "iɛ": 150,
155
+ "iɛ→": 151,
156
+ "iɛ↑": 152,
157
+ "iɛ↓↑": 153,
158
+ "iɛ↓": 154,
159
+ "iɑʊ": 155,
160
+ "iɑʊ→": 156,
161
+ "iɑʊ↑": 157,
162
+ "iɑʊ↓↑": 158,
163
+ "iɑʊ↓": 159,
164
+ "ioʊ": 160,
165
+ "ioʊ→": 161,
166
+ "ioʊ↑": 162,
167
+ "ioʊ↓↑": 163,
168
+ "ioʊ↓": 164,
169
+ "iɑŋ": 165,
170
+ "iɑŋ→": 166,
171
+ "iɑŋ↑": 167,
172
+ "iɑŋ↓↑": 168,
173
+ "iɑŋ↓": 169,
174
+ "ua": 170,
175
+ "ua→": 171,
176
+ "ua↑": 172,
177
+ "ua↓↑": 173,
178
+ "ua↓": 174,
179
+ "uo": 175,
180
+ "uo→": 176,
181
+ "uo↑": 177,
182
+ "uo↓↑": 178,
183
+ "uo↓": 179,
184
+ "uaɪ": 180,
185
+ "uaɪ→": 181,
186
+ "uaɪ↑": 182,
187
+ "uaɪ↓↑": 183,
188
+ "uaɪ↓": 184,
189
+ "ueɪ": 185,
190
+ "ueɪ→": 186,
191
+ "ueɪ↑": 187,
192
+ "ueɪ↓↑": 188,
193
+ "ueɪ↓": 189,
194
+ "uan": 190,
195
+ "uan→": 191,
196
+ "uan↑": 192,
197
+ "uan↓↑": 193,
198
+ "uan↓": 194,
199
+ "uən": 195,
200
+ "uən→": 196,
201
+ "uən↑": 197,
202
+ "uən↓↑": 198,
203
+ "uən↓": 199,
204
+ "uɑŋ": 200,
205
+ "uɑŋ→": 201,
206
+ "uɑŋ↑": 202,
207
+ "uɑŋ↓↑": 203,
208
+ "uɑŋ↓": 204,
209
+ "ɥɛ": 205,
210
+ "ɥɛ→": 206,
211
+ "ɥɛ↑": 207,
212
+ "ɥɛ↓↑": 208,
213
+ "ɥɛ↓": 209,
214
+ "a": 210,
215
+ "a→": 211,
216
+ "a↑": 212,
217
+ "a↓↑": 213,
218
+ "a↓": 214,
219
+ "o": 215,
220
+ "o→": 216,
221
+ "o↑": 217,
222
+ "o↓↑": 218,
223
+ "o↓": 219,
224
+ "ə→": 220,
225
+ "ə↑": 221,
226
+ "ə↓↑": 222,
227
+ "ə↓": 223,
228
+ "ɛ→": 224,
229
+ "ɛ↑": 225,
230
+ "ɛ↓↑": 226,
231
+ "ɛ↓": 227,
232
+ "aɪ→": 228,
233
+ "aɪ↑": 229,
234
+ "aɪ↓↑": 230,
235
+ "aɪ↓": 231,
236
+ "eɪ→": 232,
237
+ "eɪ↑": 233,
238
+ "eɪ↓↑": 234,
239
+ "eɪ↓": 235,
240
+ "ɑʊ": 236,
241
+ "ɑʊ→": 237,
242
+ "ɑʊ↑": 238,
243
+ "ɑʊ↓↑": 239,
244
+ "ɑʊ↓": 240,
245
+ "oʊ→": 241,
246
+ "oʊ↑": 242,
247
+ "oʊ↓↑": 243,
248
+ "oʊ↓": 244,
249
+ "an": 245,
250
+ "an→": 246,
251
+ "an↑": 247,
252
+ "an↓↑": 248,
253
+ "an↓": 249,
254
+ "ən": 250,
255
+ "ən→": 251,
256
+ "ən↑": 252,
257
+ "ən↓↑": 253,
258
+ "ən↓": 254,
259
+ "ɑŋ": 255,
260
+ "ɑŋ→": 256,
261
+ "ɑŋ↑": 257,
262
+ "ɑŋ↓↑": 258,
263
+ "ɑŋ↓": 259,
264
+ "əŋ": 260,
265
+ "əŋ→": 261,
266
+ "əŋ↑": 262,
267
+ "əŋ↓↑": 263,
268
+ "əŋ↓": 264,
269
+ "əɹ": 265,
270
+ "əɹ→": 266,
271
+ "əɹ↑": 267,
272
+ "əɹ↓↑": 268,
273
+ "əɹ↓": 269,
274
+ "i→": 270,
275
+ "i↑": 271,
276
+ "i↓↑": 272,
277
+ "i↓": 273,
278
+ "u→": 274,
279
+ "u↑": 275,
280
+ "u↓↑": 276,
281
+ "u↓": 277,
282
+ "ɥ": 278,
283
+ "ɥ→": 279,
284
+ "ɥ↑": 280,
285
+ "ɥ↓↑": 281,
286
+ "ɥ↓": 282,
287
+ "ts`⁼ɹ": 283,
288
+ "ts`⁼ɹ→": 284,
289
+ "ts`⁼ɹ↑": 285,
290
+ "ts`⁼ɹ↓↑": 286,
291
+ "ts`⁼ɹ↓": 287,
292
+ "ts`ʰɹ": 288,
293
+ "ts`ʰɹ→": 289,
294
+ "ts`ʰɹ↑": 290,
295
+ "ts`ʰɹ↓↑": 291,
296
+ "ts`ʰɹ↓": 292,
297
+ "s`ɹ": 293,
298
+ "s`ɹ→": 294,
299
+ "s`ɹ↑": 295,
300
+ "s`ɹ↓↑": 296,
301
+ "s`ɹ���": 297,
302
+ "ɹ`ɹ": 298,
303
+ "ɹ`ɹ→": 299,
304
+ "ɹ`ɹ↑": 300,
305
+ "ɹ`ɹ↓↑": 301,
306
+ "ɹ`ɹ↓": 302,
307
+ "ts⁼ɹ": 303,
308
+ "ts⁼ɹ→": 304,
309
+ "ts⁼ɹ↑": 305,
310
+ "ts⁼ɹ↓↑": 306,
311
+ "ts⁼ɹ↓": 307,
312
+ "tsʰɹ": 308,
313
+ "tsʰɹ→": 309,
314
+ "tsʰɹ↑": 310,
315
+ "tsʰɹ↓↑": 311,
316
+ "tsʰɹ↓": 312,
317
+ "sɹ": 313,
318
+ "sɹ→": 314,
319
+ "sɹ↑": 315,
320
+ "sɹ↓↑": 316,
321
+ "sɹ↓": 317,
322
+
323
+ "ɯ": 318,
324
+ "e": 319,
325
+ "aː": 320,
326
+ "ɯː": 321,
327
+ "eː": 322,
328
+ "ç": 323,
329
+ "ɸ": 324,
330
+ "ɰᵝ": 325,
331
+ "ɴ": 326,
332
+ "g": 327,
333
+ "dʑ": 328,
334
+ "q": 329,
335
+ "ː": 330,
336
+ "bj": 331,
337
+ "tɕ": 332,
338
+ "dej": 333,
339
+ "tej": 334,
340
+ "gj": 335,
341
+ "gɯ": 336,
342
+ "çj": 337,
343
+ "kj": 338,
344
+ "kɯ": 339,
345
+ "mj": 340,
346
+ "nj": 341,
347
+ "pj": 342,
348
+ "ɾj": 343,
349
+ "ɕ": 344,
350
+ "tsɯ": 345,
351
+
352
+ "ɐ": 346,
353
+ "ɑ": 347,
354
+ "ɒ": 348,
355
+ "ɜ": 349,
356
+ "ɫ": 350,
357
+ "ʑ": 351,
358
+ "ʲ": 352,
359
+
360
+ "y": 353,
361
+ "ø": 354,
362
+ "œ": 355,
363
+ "ʁ": 356,
364
+ "̃": 357,
365
+ "ɲ": 358,
366
+
367
+ ":": 359,
368
+ ";": 360,
369
+ "'": 361,
370
+ "…": 362
371
+ }
372
+ }
soundsation/g2p/g2p_generation.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import sys
8
+
9
+ from soundsation.g2p.g2p import PhonemeBpeTokenizer
10
+ from soundsation.g2p.utils.g2p import phonemizer_g2p
11
+ import tqdm
12
+ from typing import List
13
+ import json
14
+ import os
15
+ import re
16
+
17
+
18
+ def ph_g2p(text, language):
19
+
20
+ return phonemizer_g2p(text=text, language=language)
21
+
22
+
23
+ def g2p(text, sentence, language):
24
+
25
+ return text_tokenizer.tokenize(text=text, sentence=sentence, language=language)
26
+
27
+
28
+ def is_chinese(char):
29
+ if char >= "\u4e00" and char <= "\u9fa5":
30
+ return True
31
+ else:
32
+ return False
33
+
34
+
35
+ def is_alphabet(char):
36
+ if (char >= "\u0041" and char <= "\u005a") or (
37
+ char >= "\u0061" and char <= "\u007a"
38
+ ):
39
+ return True
40
+ else:
41
+ return False
42
+
43
+
44
+ def is_other(char):
45
+ if not (is_chinese(char) or is_alphabet(char)):
46
+ return True
47
+ else:
48
+ return False
49
+
50
+
51
+ def get_segment(text: str) -> List[str]:
52
+ # sentence --> [ch_part, en_part, ch_part, ...]
53
+ segments = []
54
+ types = []
55
+ flag = 0
56
+ temp_seg = ""
57
+ temp_lang = ""
58
+
59
+ # Determine the type of each character. type: blank, chinese, alphabet, number, unk and point.
60
+ for i, ch in enumerate(text):
61
+ if is_chinese(ch):
62
+ types.append("zh")
63
+ elif is_alphabet(ch):
64
+ types.append("en")
65
+ else:
66
+ types.append("other")
67
+
68
+ assert len(types) == len(text)
69
+
70
+ for i in range(len(types)):
71
+ # find the first char of the seg
72
+ if flag == 0:
73
+ temp_seg += text[i]
74
+ temp_lang = types[i]
75
+ flag = 1
76
+ else:
77
+ if temp_lang == "other":
78
+ if types[i] == temp_lang:
79
+ temp_seg += text[i]
80
+ else:
81
+ temp_seg += text[i]
82
+ temp_lang = types[i]
83
+ else:
84
+ if types[i] == temp_lang:
85
+ temp_seg += text[i]
86
+ elif types[i] == "other":
87
+ temp_seg += text[i]
88
+ else:
89
+ segments.append((temp_seg, temp_lang))
90
+ temp_seg = text[i]
91
+ temp_lang = types[i]
92
+ flag = 1
93
+
94
+ segments.append((temp_seg, temp_lang))
95
+ return segments
96
+
97
+
98
+ def chn_eng_g2p(text: str):
99
+ # now only en and ch
100
+ segments = get_segment(text)
101
+ all_phoneme = ""
102
+ all_tokens = []
103
+
104
+ for index in range(len(segments)):
105
+ seg = segments[index]
106
+ phoneme, token = g2p(seg[0], text, seg[1])
107
+ all_phoneme += phoneme + "|"
108
+ all_tokens += token
109
+
110
+ if seg[1] == "en" and index == len(segments) - 1 and all_phoneme[-2] == "_":
111
+ all_phoneme = all_phoneme[:-2]
112
+ all_tokens = all_tokens[:-1]
113
+ return all_phoneme, all_tokens
114
+
115
+
116
+ text_tokenizer = PhonemeBpeTokenizer()
117
+ with open("./soundsation/g2p/g2p/vocab.json", "r") as f:
118
+ json_data = f.read()
119
+ data = json.loads(json_data)
120
+ vocab = data["vocab"]
121
+
122
+ if __name__ == '__main__':
123
+ phone, token = chn_eng_g2p("你好,hello world")
124
+ phone, token = chn_eng_g2p("你好,hello world, Bonjour, 테스트 해 보겠습니다, 五月雨緑")
125
+ print(phone)
126
+ print(token)
127
+
128
+ #phone, token = text_tokenizer.tokenize("你好,hello world, Bonjour, 테스트 해 보겠습니다, 五月雨緑", "", "auto")
129
+ phone, token = text_tokenizer.tokenize("緑", "", "auto")
130
+ #phone, token = text_tokenizer.tokenize("आइए इसका परीक्षण करें", "", "auto")
131
+ #phone, token = text_tokenizer.tokenize("आइए इसका परीक्षण करें", "", "other")
132
+ print(phone)
133
+ print(token)
soundsation/g2p/sources/bpmf_2_pinyin.txt ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ b ㄅ
2
+ p ㄆ
3
+ m ㄇ
4
+ f ㄈ
5
+ d ㄉ
6
+ t ㄊ
7
+ n ㄋ
8
+ l ㄌ
9
+ g ㄍ
10
+ k ㄎ
11
+ h ㄏ
12
+ j ㄐ
13
+ q ㄑ
14
+ x ㄒ
15
+ zh ㄓ
16
+ ch ㄔ
17
+ sh ㄕ
18
+ r ㄖ
19
+ z ㄗ
20
+ c ㄘ
21
+ s ㄙ
22
+ i ㄧ
23
+ u ㄨ
24
+ v ㄩ
25
+ a ㄚ
26
+ o ㄛ
27
+ e ㄜ
28
+ e ㄝ
29
+ ai ㄞ
30
+ ei ㄟ
31
+ ao ㄠ
32
+ ou ㄡ
33
+ an ㄢ
34
+ en ㄣ
35
+ ang ㄤ
36
+ eng ㄥ
37
+ er ㄦ
38
+ 2 ˊ
39
+ 3 ˇ
40
+ 4 ˋ
41
+ 0 ˙
soundsation/g2p/sources/chinese_lexicon.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3a7685d1c3e68eb2fa304bfc63e90c90c3c1a1948839a5b1b507b2131b3e2fb
3
+ size 14779443
soundsation/g2p/sources/g2p_chinese_model/config.json ADDED
@@ -0,0 +1,819 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/BERT-POLY-v2/pretrained_models/mini_bert",
3
+ "architectures": [
4
+ "BertPoly"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "directionality": "bidi",
9
+ "gradient_checkpointing": false,
10
+ "hidden_act": "gelu",
11
+ "hidden_dropout_prob": 0.1,
12
+ "hidden_size": 384,
13
+ "id2label": {
14
+ "0": "LABEL_0",
15
+ "1": "LABEL_1",
16
+ "2": "LABEL_2",
17
+ "3": "LABEL_3",
18
+ "4": "LABEL_4",
19
+ "5": "LABEL_5",
20
+ "6": "LABEL_6",
21
+ "7": "LABEL_7",
22
+ "8": "LABEL_8",
23
+ "9": "LABEL_9",
24
+ "10": "LABEL_10",
25
+ "11": "LABEL_11",
26
+ "12": "LABEL_12",
27
+ "13": "LABEL_13",
28
+ "14": "LABEL_14",
29
+ "15": "LABEL_15",
30
+ "16": "LABEL_16",
31
+ "17": "LABEL_17",
32
+ "18": "LABEL_18",
33
+ "19": "LABEL_19",
34
+ "20": "LABEL_20",
35
+ "21": "LABEL_21",
36
+ "22": "LABEL_22",
37
+ "23": "LABEL_23",
38
+ "24": "LABEL_24",
39
+ "25": "LABEL_25",
40
+ "26": "LABEL_26",
41
+ "27": "LABEL_27",
42
+ "28": "LABEL_28",
43
+ "29": "LABEL_29",
44
+ "30": "LABEL_30",
45
+ "31": "LABEL_31",
46
+ "32": "LABEL_32",
47
+ "33": "LABEL_33",
48
+ "34": "LABEL_34",
49
+ "35": "LABEL_35",
50
+ "36": "LABEL_36",
51
+ "37": "LABEL_37",
52
+ "38": "LABEL_38",
53
+ "39": "LABEL_39",
54
+ "40": "LABEL_40",
55
+ "41": "LABEL_41",
56
+ "42": "LABEL_42",
57
+ "43": "LABEL_43",
58
+ "44": "LABEL_44",
59
+ "45": "LABEL_45",
60
+ "46": "LABEL_46",
61
+ "47": "LABEL_47",
62
+ "48": "LABEL_48",
63
+ "49": "LABEL_49",
64
+ "50": "LABEL_50",
65
+ "51": "LABEL_51",
66
+ "52": "LABEL_52",
67
+ "53": "LABEL_53",
68
+ "54": "LABEL_54",
69
+ "55": "LABEL_55",
70
+ "56": "LABEL_56",
71
+ "57": "LABEL_57",
72
+ "58": "LABEL_58",
73
+ "59": "LABEL_59",
74
+ "60": "LABEL_60",
75
+ "61": "LABEL_61",
76
+ "62": "LABEL_62",
77
+ "63": "LABEL_63",
78
+ "64": "LABEL_64",
79
+ "65": "LABEL_65",
80
+ "66": "LABEL_66",
81
+ "67": "LABEL_67",
82
+ "68": "LABEL_68",
83
+ "69": "LABEL_69",
84
+ "70": "LABEL_70",
85
+ "71": "LABEL_71",
86
+ "72": "LABEL_72",
87
+ "73": "LABEL_73",
88
+ "74": "LABEL_74",
89
+ "75": "LABEL_75",
90
+ "76": "LABEL_76",
91
+ "77": "LABEL_77",
92
+ "78": "LABEL_78",
93
+ "79": "LABEL_79",
94
+ "80": "LABEL_80",
95
+ "81": "LABEL_81",
96
+ "82": "LABEL_82",
97
+ "83": "LABEL_83",
98
+ "84": "LABEL_84",
99
+ "85": "LABEL_85",
100
+ "86": "LABEL_86",
101
+ "87": "LABEL_87",
102
+ "88": "LABEL_88",
103
+ "89": "LABEL_89",
104
+ "90": "LABEL_90",
105
+ "91": "LABEL_91",
106
+ "92": "LABEL_92",
107
+ "93": "LABEL_93",
108
+ "94": "LABEL_94",
109
+ "95": "LABEL_95",
110
+ "96": "LABEL_96",
111
+ "97": "LABEL_97",
112
+ "98": "LABEL_98",
113
+ "99": "LABEL_99",
114
+ "100": "LABEL_100",
115
+ "101": "LABEL_101",
116
+ "102": "LABEL_102",
117
+ "103": "LABEL_103",
118
+ "104": "LABEL_104",
119
+ "105": "LABEL_105",
120
+ "106": "LABEL_106",
121
+ "107": "LABEL_107",
122
+ "108": "LABEL_108",
123
+ "109": "LABEL_109",
124
+ "110": "LABEL_110",
125
+ "111": "LABEL_111",
126
+ "112": "LABEL_112",
127
+ "113": "LABEL_113",
128
+ "114": "LABEL_114",
129
+ "115": "LABEL_115",
130
+ "116": "LABEL_116",
131
+ "117": "LABEL_117",
132
+ "118": "LABEL_118",
133
+ "119": "LABEL_119",
134
+ "120": "LABEL_120",
135
+ "121": "LABEL_121",
136
+ "122": "LABEL_122",
137
+ "123": "LABEL_123",
138
+ "124": "LABEL_124",
139
+ "125": "LABEL_125",
140
+ "126": "LABEL_126",
141
+ "127": "LABEL_127",
142
+ "128": "LABEL_128",
143
+ "129": "LABEL_129",
144
+ "130": "LABEL_130",
145
+ "131": "LABEL_131",
146
+ "132": "LABEL_132",
147
+ "133": "LABEL_133",
148
+ "134": "LABEL_134",
149
+ "135": "LABEL_135",
150
+ "136": "LABEL_136",
151
+ "137": "LABEL_137",
152
+ "138": "LABEL_138",
153
+ "139": "LABEL_139",
154
+ "140": "LABEL_140",
155
+ "141": "LABEL_141",
156
+ "142": "LABEL_142",
157
+ "143": "LABEL_143",
158
+ "144": "LABEL_144",
159
+ "145": "LABEL_145",
160
+ "146": "LABEL_146",
161
+ "147": "LABEL_147",
162
+ "148": "LABEL_148",
163
+ "149": "LABEL_149",
164
+ "150": "LABEL_150",
165
+ "151": "LABEL_151",
166
+ "152": "LABEL_152",
167
+ "153": "LABEL_153",
168
+ "154": "LABEL_154",
169
+ "155": "LABEL_155",
170
+ "156": "LABEL_156",
171
+ "157": "LABEL_157",
172
+ "158": "LABEL_158",
173
+ "159": "LABEL_159",
174
+ "160": "LABEL_160",
175
+ "161": "LABEL_161",
176
+ "162": "LABEL_162",
177
+ "163": "LABEL_163",
178
+ "164": "LABEL_164",
179
+ "165": "LABEL_165",
180
+ "166": "LABEL_166",
181
+ "167": "LABEL_167",
182
+ "168": "LABEL_168",
183
+ "169": "LABEL_169",
184
+ "170": "LABEL_170",
185
+ "171": "LABEL_171",
186
+ "172": "LABEL_172",
187
+ "173": "LABEL_173",
188
+ "174": "LABEL_174",
189
+ "175": "LABEL_175",
190
+ "176": "LABEL_176",
191
+ "177": "LABEL_177",
192
+ "178": "LABEL_178",
193
+ "179": "LABEL_179",
194
+ "180": "LABEL_180",
195
+ "181": "LABEL_181",
196
+ "182": "LABEL_182",
197
+ "183": "LABEL_183",
198
+ "184": "LABEL_184",
199
+ "185": "LABEL_185",
200
+ "186": "LABEL_186",
201
+ "187": "LABEL_187",
202
+ "188": "LABEL_188",
203
+ "189": "LABEL_189",
204
+ "190": "LABEL_190",
205
+ "191": "LABEL_191",
206
+ "192": "LABEL_192",
207
+ "193": "LABEL_193",
208
+ "194": "LABEL_194",
209
+ "195": "LABEL_195",
210
+ "196": "LABEL_196",
211
+ "197": "LABEL_197",
212
+ "198": "LABEL_198",
213
+ "199": "LABEL_199",
214
+ "200": "LABEL_200",
215
+ "201": "LABEL_201",
216
+ "202": "LABEL_202",
217
+ "203": "LABEL_203",
218
+ "204": "LABEL_204",
219
+ "205": "LABEL_205",
220
+ "206": "LABEL_206",
221
+ "207": "LABEL_207",
222
+ "208": "LABEL_208",
223
+ "209": "LABEL_209",
224
+ "210": "LABEL_210",
225
+ "211": "LABEL_211",
226
+ "212": "LABEL_212",
227
+ "213": "LABEL_213",
228
+ "214": "LABEL_214",
229
+ "215": "LABEL_215",
230
+ "216": "LABEL_216",
231
+ "217": "LABEL_217",
232
+ "218": "LABEL_218",
233
+ "219": "LABEL_219",
234
+ "220": "LABEL_220",
235
+ "221": "LABEL_221",
236
+ "222": "LABEL_222",
237
+ "223": "LABEL_223",
238
+ "224": "LABEL_224",
239
+ "225": "LABEL_225",
240
+ "226": "LABEL_226",
241
+ "227": "LABEL_227",
242
+ "228": "LABEL_228",
243
+ "229": "LABEL_229",
244
+ "230": "LABEL_230",
245
+ "231": "LABEL_231",
246
+ "232": "LABEL_232",
247
+ "233": "LABEL_233",
248
+ "234": "LABEL_234",
249
+ "235": "LABEL_235",
250
+ "236": "LABEL_236",
251
+ "237": "LABEL_237",
252
+ "238": "LABEL_238",
253
+ "239": "LABEL_239",
254
+ "240": "LABEL_240",
255
+ "241": "LABEL_241",
256
+ "242": "LABEL_242",
257
+ "243": "LABEL_243",
258
+ "244": "LABEL_244",
259
+ "245": "LABEL_245",
260
+ "246": "LABEL_246",
261
+ "247": "LABEL_247",
262
+ "248": "LABEL_248",
263
+ "249": "LABEL_249",
264
+ "250": "LABEL_250",
265
+ "251": "LABEL_251",
266
+ "252": "LABEL_252",
267
+ "253": "LABEL_253",
268
+ "254": "LABEL_254",
269
+ "255": "LABEL_255",
270
+ "256": "LABEL_256",
271
+ "257": "LABEL_257",
272
+ "258": "LABEL_258",
273
+ "259": "LABEL_259",
274
+ "260": "LABEL_260",
275
+ "261": "LABEL_261",
276
+ "262": "LABEL_262",
277
+ "263": "LABEL_263",
278
+ "264": "LABEL_264",
279
+ "265": "LABEL_265",
280
+ "266": "LABEL_266",
281
+ "267": "LABEL_267",
282
+ "268": "LABEL_268",
283
+ "269": "LABEL_269",
284
+ "270": "LABEL_270",
285
+ "271": "LABEL_271",
286
+ "272": "LABEL_272",
287
+ "273": "LABEL_273",
288
+ "274": "LABEL_274",
289
+ "275": "LABEL_275",
290
+ "276": "LABEL_276",
291
+ "277": "LABEL_277",
292
+ "278": "LABEL_278",
293
+ "279": "LABEL_279",
294
+ "280": "LABEL_280",
295
+ "281": "LABEL_281",
296
+ "282": "LABEL_282",
297
+ "283": "LABEL_283",
298
+ "284": "LABEL_284",
299
+ "285": "LABEL_285",
300
+ "286": "LABEL_286",
301
+ "287": "LABEL_287",
302
+ "288": "LABEL_288",
303
+ "289": "LABEL_289",
304
+ "290": "LABEL_290",
305
+ "291": "LABEL_291",
306
+ "292": "LABEL_292",
307
+ "293": "LABEL_293",
308
+ "294": "LABEL_294",
309
+ "295": "LABEL_295",
310
+ "296": "LABEL_296",
311
+ "297": "LABEL_297",
312
+ "298": "LABEL_298",
313
+ "299": "LABEL_299",
314
+ "300": "LABEL_300",
315
+ "301": "LABEL_301",
316
+ "302": "LABEL_302",
317
+ "303": "LABEL_303",
318
+ "304": "LABEL_304",
319
+ "305": "LABEL_305",
320
+ "306": "LABEL_306",
321
+ "307": "LABEL_307",
322
+ "308": "LABEL_308",
323
+ "309": "LABEL_309",
324
+ "310": "LABEL_310",
325
+ "311": "LABEL_311",
326
+ "312": "LABEL_312",
327
+ "313": "LABEL_313",
328
+ "314": "LABEL_314",
329
+ "315": "LABEL_315",
330
+ "316": "LABEL_316",
331
+ "317": "LABEL_317",
332
+ "318": "LABEL_318",
333
+ "319": "LABEL_319",
334
+ "320": "LABEL_320",
335
+ "321": "LABEL_321",
336
+ "322": "LABEL_322",
337
+ "323": "LABEL_323",
338
+ "324": "LABEL_324",
339
+ "325": "LABEL_325",
340
+ "326": "LABEL_326",
341
+ "327": "LABEL_327",
342
+ "328": "LABEL_328",
343
+ "329": "LABEL_329",
344
+ "330": "LABEL_330",
345
+ "331": "LABEL_331",
346
+ "332": "LABEL_332",
347
+ "333": "LABEL_333",
348
+ "334": "LABEL_334",
349
+ "335": "LABEL_335",
350
+ "336": "LABEL_336",
351
+ "337": "LABEL_337",
352
+ "338": "LABEL_338",
353
+ "339": "LABEL_339",
354
+ "340": "LABEL_340",
355
+ "341": "LABEL_341",
356
+ "342": "LABEL_342",
357
+ "343": "LABEL_343",
358
+ "344": "LABEL_344",
359
+ "345": "LABEL_345",
360
+ "346": "LABEL_346",
361
+ "347": "LABEL_347",
362
+ "348": "LABEL_348",
363
+ "349": "LABEL_349",
364
+ "350": "LABEL_350",
365
+ "351": "LABEL_351",
366
+ "352": "LABEL_352",
367
+ "353": "LABEL_353",
368
+ "354": "LABEL_354",
369
+ "355": "LABEL_355",
370
+ "356": "LABEL_356",
371
+ "357": "LABEL_357",
372
+ "358": "LABEL_358",
373
+ "359": "LABEL_359",
374
+ "360": "LABEL_360",
375
+ "361": "LABEL_361",
376
+ "362": "LABEL_362",
377
+ "363": "LABEL_363",
378
+ "364": "LABEL_364",
379
+ "365": "LABEL_365",
380
+ "366": "LABEL_366",
381
+ "367": "LABEL_367",
382
+ "368": "LABEL_368",
383
+ "369": "LABEL_369",
384
+ "370": "LABEL_370",
385
+ "371": "LABEL_371",
386
+ "372": "LABEL_372",
387
+ "373": "LABEL_373",
388
+ "374": "LABEL_374",
389
+ "375": "LABEL_375",
390
+ "376": "LABEL_376",
391
+ "377": "LABEL_377",
392
+ "378": "LABEL_378",
393
+ "379": "LABEL_379",
394
+ "380": "LABEL_380",
395
+ "381": "LABEL_381",
396
+ "382": "LABEL_382",
397
+ "383": "LABEL_383",
398
+ "384": "LABEL_384",
399
+ "385": "LABEL_385",
400
+ "386": "LABEL_386",
401
+ "387": "LABEL_387",
402
+ "388": "LABEL_388",
403
+ "389": "LABEL_389",
404
+ "390": "LABEL_390"
405
+ },
406
+ "initializer_range": 0.02,
407
+ "intermediate_size": 1536,
408
+ "label2id": {
409
+ "LABEL_0": 0,
410
+ "LABEL_1": 1,
411
+ "LABEL_10": 10,
412
+ "LABEL_100": 100,
413
+ "LABEL_101": 101,
414
+ "LABEL_102": 102,
415
+ "LABEL_103": 103,
416
+ "LABEL_104": 104,
417
+ "LABEL_105": 105,
418
+ "LABEL_106": 106,
419
+ "LABEL_107": 107,
420
+ "LABEL_108": 108,
421
+ "LABEL_109": 109,
422
+ "LABEL_11": 11,
423
+ "LABEL_110": 110,
424
+ "LABEL_111": 111,
425
+ "LABEL_112": 112,
426
+ "LABEL_113": 113,
427
+ "LABEL_114": 114,
428
+ "LABEL_115": 115,
429
+ "LABEL_116": 116,
430
+ "LABEL_117": 117,
431
+ "LABEL_118": 118,
432
+ "LABEL_119": 119,
433
+ "LABEL_12": 12,
434
+ "LABEL_120": 120,
435
+ "LABEL_121": 121,
436
+ "LABEL_122": 122,
437
+ "LABEL_123": 123,
438
+ "LABEL_124": 124,
439
+ "LABEL_125": 125,
440
+ "LABEL_126": 126,
441
+ "LABEL_127": 127,
442
+ "LABEL_128": 128,
443
+ "LABEL_129": 129,
444
+ "LABEL_13": 13,
445
+ "LABEL_130": 130,
446
+ "LABEL_131": 131,
447
+ "LABEL_132": 132,
448
+ "LABEL_133": 133,
449
+ "LABEL_134": 134,
450
+ "LABEL_135": 135,
451
+ "LABEL_136": 136,
452
+ "LABEL_137": 137,
453
+ "LABEL_138": 138,
454
+ "LABEL_139": 139,
455
+ "LABEL_14": 14,
456
+ "LABEL_140": 140,
457
+ "LABEL_141": 141,
458
+ "LABEL_142": 142,
459
+ "LABEL_143": 143,
460
+ "LABEL_144": 144,
461
+ "LABEL_145": 145,
462
+ "LABEL_146": 146,
463
+ "LABEL_147": 147,
464
+ "LABEL_148": 148,
465
+ "LABEL_149": 149,
466
+ "LABEL_15": 15,
467
+ "LABEL_150": 150,
468
+ "LABEL_151": 151,
469
+ "LABEL_152": 152,
470
+ "LABEL_153": 153,
471
+ "LABEL_154": 154,
472
+ "LABEL_155": 155,
473
+ "LABEL_156": 156,
474
+ "LABEL_157": 157,
475
+ "LABEL_158": 158,
476
+ "LABEL_159": 159,
477
+ "LABEL_16": 16,
478
+ "LABEL_160": 160,
479
+ "LABEL_161": 161,
480
+ "LABEL_162": 162,
481
+ "LABEL_163": 163,
482
+ "LABEL_164": 164,
483
+ "LABEL_165": 165,
484
+ "LABEL_166": 166,
485
+ "LABEL_167": 167,
486
+ "LABEL_168": 168,
487
+ "LABEL_169": 169,
488
+ "LABEL_17": 17,
489
+ "LABEL_170": 170,
490
+ "LABEL_171": 171,
491
+ "LABEL_172": 172,
492
+ "LABEL_173": 173,
493
+ "LABEL_174": 174,
494
+ "LABEL_175": 175,
495
+ "LABEL_176": 176,
496
+ "LABEL_177": 177,
497
+ "LABEL_178": 178,
498
+ "LABEL_179": 179,
499
+ "LABEL_18": 18,
500
+ "LABEL_180": 180,
501
+ "LABEL_181": 181,
502
+ "LABEL_182": 182,
503
+ "LABEL_183": 183,
504
+ "LABEL_184": 184,
505
+ "LABEL_185": 185,
506
+ "LABEL_186": 186,
507
+ "LABEL_187": 187,
508
+ "LABEL_188": 188,
509
+ "LABEL_189": 189,
510
+ "LABEL_19": 19,
511
+ "LABEL_190": 190,
512
+ "LABEL_191": 191,
513
+ "LABEL_192": 192,
514
+ "LABEL_193": 193,
515
+ "LABEL_194": 194,
516
+ "LABEL_195": 195,
517
+ "LABEL_196": 196,
518
+ "LABEL_197": 197,
519
+ "LABEL_198": 198,
520
+ "LABEL_199": 199,
521
+ "LABEL_2": 2,
522
+ "LABEL_20": 20,
523
+ "LABEL_200": 200,
524
+ "LABEL_201": 201,
525
+ "LABEL_202": 202,
526
+ "LABEL_203": 203,
527
+ "LABEL_204": 204,
528
+ "LABEL_205": 205,
529
+ "LABEL_206": 206,
530
+ "LABEL_207": 207,
531
+ "LABEL_208": 208,
532
+ "LABEL_209": 209,
533
+ "LABEL_21": 21,
534
+ "LABEL_210": 210,
535
+ "LABEL_211": 211,
536
+ "LABEL_212": 212,
537
+ "LABEL_213": 213,
538
+ "LABEL_214": 214,
539
+ "LABEL_215": 215,
540
+ "LABEL_216": 216,
541
+ "LABEL_217": 217,
542
+ "LABEL_218": 218,
543
+ "LABEL_219": 219,
544
+ "LABEL_22": 22,
545
+ "LABEL_220": 220,
546
+ "LABEL_221": 221,
547
+ "LABEL_222": 222,
548
+ "LABEL_223": 223,
549
+ "LABEL_224": 224,
550
+ "LABEL_225": 225,
551
+ "LABEL_226": 226,
552
+ "LABEL_227": 227,
553
+ "LABEL_228": 228,
554
+ "LABEL_229": 229,
555
+ "LABEL_23": 23,
556
+ "LABEL_230": 230,
557
+ "LABEL_231": 231,
558
+ "LABEL_232": 232,
559
+ "LABEL_233": 233,
560
+ "LABEL_234": 234,
561
+ "LABEL_235": 235,
562
+ "LABEL_236": 236,
563
+ "LABEL_237": 237,
564
+ "LABEL_238": 238,
565
+ "LABEL_239": 239,
566
+ "LABEL_24": 24,
567
+ "LABEL_240": 240,
568
+ "LABEL_241": 241,
569
+ "LABEL_242": 242,
570
+ "LABEL_243": 243,
571
+ "LABEL_244": 244,
572
+ "LABEL_245": 245,
573
+ "LABEL_246": 246,
574
+ "LABEL_247": 247,
575
+ "LABEL_248": 248,
576
+ "LABEL_249": 249,
577
+ "LABEL_25": 25,
578
+ "LABEL_250": 250,
579
+ "LABEL_251": 251,
580
+ "LABEL_252": 252,
581
+ "LABEL_253": 253,
582
+ "LABEL_254": 254,
583
+ "LABEL_255": 255,
584
+ "LABEL_256": 256,
585
+ "LABEL_257": 257,
586
+ "LABEL_258": 258,
587
+ "LABEL_259": 259,
588
+ "LABEL_26": 26,
589
+ "LABEL_260": 260,
590
+ "LABEL_261": 261,
591
+ "LABEL_262": 262,
592
+ "LABEL_263": 263,
593
+ "LABEL_264": 264,
594
+ "LABEL_265": 265,
595
+ "LABEL_266": 266,
596
+ "LABEL_267": 267,
597
+ "LABEL_268": 268,
598
+ "LABEL_269": 269,
599
+ "LABEL_27": 27,
600
+ "LABEL_270": 270,
601
+ "LABEL_271": 271,
602
+ "LABEL_272": 272,
603
+ "LABEL_273": 273,
604
+ "LABEL_274": 274,
605
+ "LABEL_275": 275,
606
+ "LABEL_276": 276,
607
+ "LABEL_277": 277,
608
+ "LABEL_278": 278,
609
+ "LABEL_279": 279,
610
+ "LABEL_28": 28,
611
+ "LABEL_280": 280,
612
+ "LABEL_281": 281,
613
+ "LABEL_282": 282,
614
+ "LABEL_283": 283,
615
+ "LABEL_284": 284,
616
+ "LABEL_285": 285,
617
+ "LABEL_286": 286,
618
+ "LABEL_287": 287,
619
+ "LABEL_288": 288,
620
+ "LABEL_289": 289,
621
+ "LABEL_29": 29,
622
+ "LABEL_290": 290,
623
+ "LABEL_291": 291,
624
+ "LABEL_292": 292,
625
+ "LABEL_293": 293,
626
+ "LABEL_294": 294,
627
+ "LABEL_295": 295,
628
+ "LABEL_296": 296,
629
+ "LABEL_297": 297,
630
+ "LABEL_298": 298,
631
+ "LABEL_299": 299,
632
+ "LABEL_3": 3,
633
+ "LABEL_30": 30,
634
+ "LABEL_300": 300,
635
+ "LABEL_301": 301,
636
+ "LABEL_302": 302,
637
+ "LABEL_303": 303,
638
+ "LABEL_304": 304,
639
+ "LABEL_305": 305,
640
+ "LABEL_306": 306,
641
+ "LABEL_307": 307,
642
+ "LABEL_308": 308,
643
+ "LABEL_309": 309,
644
+ "LABEL_31": 31,
645
+ "LABEL_310": 310,
646
+ "LABEL_311": 311,
647
+ "LABEL_312": 312,
648
+ "LABEL_313": 313,
649
+ "LABEL_314": 314,
650
+ "LABEL_315": 315,
651
+ "LABEL_316": 316,
652
+ "LABEL_317": 317,
653
+ "LABEL_318": 318,
654
+ "LABEL_319": 319,
655
+ "LABEL_32": 32,
656
+ "LABEL_320": 320,
657
+ "LABEL_321": 321,
658
+ "LABEL_322": 322,
659
+ "LABEL_323": 323,
660
+ "LABEL_324": 324,
661
+ "LABEL_325": 325,
662
+ "LABEL_326": 326,
663
+ "LABEL_327": 327,
664
+ "LABEL_328": 328,
665
+ "LABEL_329": 329,
666
+ "LABEL_33": 33,
667
+ "LABEL_330": 330,
668
+ "LABEL_331": 331,
669
+ "LABEL_332": 332,
670
+ "LABEL_333": 333,
671
+ "LABEL_334": 334,
672
+ "LABEL_335": 335,
673
+ "LABEL_336": 336,
674
+ "LABEL_337": 337,
675
+ "LABEL_338": 338,
676
+ "LABEL_339": 339,
677
+ "LABEL_34": 34,
678
+ "LABEL_340": 340,
679
+ "LABEL_341": 341,
680
+ "LABEL_342": 342,
681
+ "LABEL_343": 343,
682
+ "LABEL_344": 344,
683
+ "LABEL_345": 345,
684
+ "LABEL_346": 346,
685
+ "LABEL_347": 347,
686
+ "LABEL_348": 348,
687
+ "LABEL_349": 349,
688
+ "LABEL_35": 35,
689
+ "LABEL_350": 350,
690
+ "LABEL_351": 351,
691
+ "LABEL_352": 352,
692
+ "LABEL_353": 353,
693
+ "LABEL_354": 354,
694
+ "LABEL_355": 355,
695
+ "LABEL_356": 356,
696
+ "LABEL_357": 357,
697
+ "LABEL_358": 358,
698
+ "LABEL_359": 359,
699
+ "LABEL_36": 36,
700
+ "LABEL_360": 360,
701
+ "LABEL_361": 361,
702
+ "LABEL_362": 362,
703
+ "LABEL_363": 363,
704
+ "LABEL_364": 364,
705
+ "LABEL_365": 365,
706
+ "LABEL_366": 366,
707
+ "LABEL_367": 367,
708
+ "LABEL_368": 368,
709
+ "LABEL_369": 369,
710
+ "LABEL_37": 37,
711
+ "LABEL_370": 370,
712
+ "LABEL_371": 371,
713
+ "LABEL_372": 372,
714
+ "LABEL_373": 373,
715
+ "LABEL_374": 374,
716
+ "LABEL_375": 375,
717
+ "LABEL_376": 376,
718
+ "LABEL_377": 377,
719
+ "LABEL_378": 378,
720
+ "LABEL_379": 379,
721
+ "LABEL_38": 38,
722
+ "LABEL_380": 380,
723
+ "LABEL_381": 381,
724
+ "LABEL_382": 382,
725
+ "LABEL_383": 383,
726
+ "LABEL_384": 384,
727
+ "LABEL_385": 385,
728
+ "LABEL_386": 386,
729
+ "LABEL_387": 387,
730
+ "LABEL_388": 388,
731
+ "LABEL_389": 389,
732
+ "LABEL_39": 39,
733
+ "LABEL_390": 390,
734
+ "LABEL_4": 4,
735
+ "LABEL_40": 40,
736
+ "LABEL_41": 41,
737
+ "LABEL_42": 42,
738
+ "LABEL_43": 43,
739
+ "LABEL_44": 44,
740
+ "LABEL_45": 45,
741
+ "LABEL_46": 46,
742
+ "LABEL_47": 47,
743
+ "LABEL_48": 48,
744
+ "LABEL_49": 49,
745
+ "LABEL_5": 5,
746
+ "LABEL_50": 50,
747
+ "LABEL_51": 51,
748
+ "LABEL_52": 52,
749
+ "LABEL_53": 53,
750
+ "LABEL_54": 54,
751
+ "LABEL_55": 55,
752
+ "LABEL_56": 56,
753
+ "LABEL_57": 57,
754
+ "LABEL_58": 58,
755
+ "LABEL_59": 59,
756
+ "LABEL_6": 6,
757
+ "LABEL_60": 60,
758
+ "LABEL_61": 61,
759
+ "LABEL_62": 62,
760
+ "LABEL_63": 63,
761
+ "LABEL_64": 64,
762
+ "LABEL_65": 65,
763
+ "LABEL_66": 66,
764
+ "LABEL_67": 67,
765
+ "LABEL_68": 68,
766
+ "LABEL_69": 69,
767
+ "LABEL_7": 7,
768
+ "LABEL_70": 70,
769
+ "LABEL_71": 71,
770
+ "LABEL_72": 72,
771
+ "LABEL_73": 73,
772
+ "LABEL_74": 74,
773
+ "LABEL_75": 75,
774
+ "LABEL_76": 76,
775
+ "LABEL_77": 77,
776
+ "LABEL_78": 78,
777
+ "LABEL_79": 79,
778
+ "LABEL_8": 8,
779
+ "LABEL_80": 80,
780
+ "LABEL_81": 81,
781
+ "LABEL_82": 82,
782
+ "LABEL_83": 83,
783
+ "LABEL_84": 84,
784
+ "LABEL_85": 85,
785
+ "LABEL_86": 86,
786
+ "LABEL_87": 87,
787
+ "LABEL_88": 88,
788
+ "LABEL_89": 89,
789
+ "LABEL_9": 9,
790
+ "LABEL_90": 90,
791
+ "LABEL_91": 91,
792
+ "LABEL_92": 92,
793
+ "LABEL_93": 93,
794
+ "LABEL_94": 94,
795
+ "LABEL_95": 95,
796
+ "LABEL_96": 96,
797
+ "LABEL_97": 97,
798
+ "LABEL_98": 98,
799
+ "LABEL_99": 99
800
+ },
801
+ "layer_norm_eps": 1e-12,
802
+ "max_position_embeddings": 512,
803
+ "model_type": "bert",
804
+ "num_attention_heads": 12,
805
+ "num_hidden_layers": 6,
806
+ "num_relation_heads": 32,
807
+ "pad_token_id": 0,
808
+ "pooler_fc_size": 768,
809
+ "pooler_num_attention_heads": 12,
810
+ "pooler_num_fc_layers": 3,
811
+ "pooler_size_per_head": 128,
812
+ "pooler_type": "first_token_transform",
813
+ "position_embedding_type": "absolute",
814
+ "torch_dtype": "float32",
815
+ "transformers_version": "4.44.1",
816
+ "type_vocab_size": 2,
817
+ "use_cache": true,
818
+ "vocab_size": 21128
819
+ }
soundsation/g2p/sources/g2p_chinese_model/poly_bert_model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8765d835ffdf9811c832d4dc7b6a552757aa8615c01d1184db716a50c20aebbc
3
+ size 76583333
soundsation/g2p/sources/g2p_chinese_model/polychar.txt ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+
5
+
6
+
7
+
8
+
9
+
10
+
11
+
12
+
13
+
14
+
15
+ 便
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+
24
+
25
+
26
+
27
+
28
+
29
+
30
+
31
+
32
+
33
+
34
+
35
+
36
+
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50
+
51
+
52
+
53
+ 宿
54
+
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+
72
+
73
+
74
+
75
+
76
+
77
+
78
+
79
+
80
+
81
+
82
+
83
+
84
+
85
+
86
+
87
+
88
+
89
+
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
100
+
101
+
102
+
103
+
104
+
105
+
106
+
107
+
108
+
109
+
110
+
111
+
112
+
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+
133
+
134
+
135
+
136
+
137
+
138
+
139
+
140
+
141
+
142
+
143
+
144
+
145
+
146
+
147
+
148
+
149
+
150
+
151
+
152
+
153
+
154
+
155
+
156
+
157
+
158
+
159
+
soundsation/g2p/sources/g2p_chinese_model/polydict.json ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "1": "丧{sang1}",
3
+ "2": "丧{sang4}",
4
+ "3": "中{zhong1}",
5
+ "4": "中{zhong4}",
6
+ "5": "为{wei2}",
7
+ "6": "为{wei4}",
8
+ "7": "乌{wu1}",
9
+ "8": "乌{wu4}",
10
+ "9": "乐{lao4}",
11
+ "10": "乐{le4}",
12
+ "11": "乐{le5}",
13
+ "12": "乐{yao4}",
14
+ "13": "乐{yve4}",
15
+ "14": "了{le5}",
16
+ "15": "了{liao3}",
17
+ "16": "了{liao5}",
18
+ "17": "什{shen2}",
19
+ "18": "什{shi2}",
20
+ "19": "仔{zai3}",
21
+ "20": "仔{zai5}",
22
+ "21": "仔{zi3}",
23
+ "22": "仔{zi5}",
24
+ "23": "令{ling2}",
25
+ "24": "令{ling4}",
26
+ "25": "任{ren2}",
27
+ "26": "任{ren4}",
28
+ "27": "会{hui4}",
29
+ "28": "会{hui5}",
30
+ "29": "会{kuai4}",
31
+ "30": "传{chuan2}",
32
+ "31": "传{zhuan4}",
33
+ "32": "佛{fo2}",
34
+ "33": "佛{fu2}",
35
+ "34": "供{gong1}",
36
+ "35": "供{gong4}",
37
+ "36": "便{bian4}",
38
+ "37": "便{pian2}",
39
+ "38": "倒{dao3}",
40
+ "39": "倒{dao4}",
41
+ "40": "假{jia3}",
42
+ "41": "假{jia4}",
43
+ "42": "兴{xing1}",
44
+ "43": "兴{xing4}",
45
+ "44": "冠{guan1}",
46
+ "45": "冠{guan4}",
47
+ "46": "冲{chong1}",
48
+ "47": "冲{chong4}",
49
+ "48": "几{ji1}",
50
+ "49": "几{ji2}",
51
+ "50": "几{ji3}",
52
+ "51": "分{fen1}",
53
+ "52": "分{fen4}",
54
+ "53": "分{fen5}",
55
+ "54": "切{qie1}",
56
+ "55": "切{qie4}",
57
+ "56": "划{hua2}",
58
+ "57": "划{hua4}",
59
+ "58": "划{hua5}",
60
+ "59": "创{chuang1}",
61
+ "60": "创{chuang4}",
62
+ "61": "剥{bao1}",
63
+ "62": "剥{bo1}",
64
+ "63": "勒{le4}",
65
+ "64": "勒{le5}",
66
+ "65": "勒{lei1}",
67
+ "66": "区{ou1}",
68
+ "67": "区{qu1}",
69
+ "68": "华{hua2}",
70
+ "69": "华{hua4}",
71
+ "70": "单{chan2}",
72
+ "71": "单{dan1}",
73
+ "72": "单{shan4}",
74
+ "73": "卜{bo5}",
75
+ "74": "卜{bu3}",
76
+ "75": "占{zhan1}",
77
+ "76": "占{zhan4}",
78
+ "77": "卡{ka2}",
79
+ "78": "卡{ka3}",
80
+ "79": "卡{qia3}",
81
+ "80": "卷{jvan3}",
82
+ "81": "卷{jvan4}",
83
+ "82": "厦{sha4}",
84
+ "83": "厦{xia4}",
85
+ "84": "参{can1}",
86
+ "85": "参{cen1}",
87
+ "86": "参{shen1}",
88
+ "87": "发{fa1}",
89
+ "88": "发{fa4}",
90
+ "89": "发{fa5}",
91
+ "90": "只{zhi1}",
92
+ "91": "只{zhi3}",
93
+ "92": "号{hao2}",
94
+ "93": "号{hao4}",
95
+ "94": "号{hao5}",
96
+ "95": "同{tong2}",
97
+ "96": "同{tong4}",
98
+ "97": "同{tong5}",
99
+ "98": "吐{tu2}",
100
+ "99": "吐{tu3}",
101
+ "100": "吐{tu4}",
102
+ "101": "和{he2}",
103
+ "102": "和{he4}",
104
+ "103": "和{he5}",
105
+ "104": "和{huo2}",
106
+ "105": "和{huo4}",
107
+ "106": "和{huo5}",
108
+ "107": "喝{he1}",
109
+ "108": "喝{he4}",
110
+ "109": "圈{jvan4}",
111
+ "110": "圈{qvan1}",
112
+ "111": "圈{qvan5}",
113
+ "112": "地{de5}",
114
+ "113": "地{di4}",
115
+ "114": "地{di5}",
116
+ "115": "塞{sai1}",
117
+ "116": "塞{sai2}",
118
+ "117": "塞{sai4}",
119
+ "118": "塞{se4}",
120
+ "119": "壳{ke2}",
121
+ "120": "壳{qiao4}",
122
+ "121": "处{chu3}",
123
+ "122": "处{chu4}",
124
+ "123": "奇{ji1}",
125
+ "124": "奇{qi2}",
126
+ "125": "奔{ben1}",
127
+ "126": "奔{ben4}",
128
+ "127": "好{hao3}",
129
+ "128": "好{hao4}",
130
+ "129": "好{hao5}",
131
+ "130": "宁{ning2}",
132
+ "131": "宁{ning4}",
133
+ "132": "宁{ning5}",
134
+ "133": "宿{su4}",
135
+ "134": "宿{xiu3}",
136
+ "135": "宿{xiu4}",
137
+ "136": "将{jiang1}",
138
+ "137": "将{jiang4}",
139
+ "138": "少{shao3}",
140
+ "139": "少{shao4}",
141
+ "140": "尽{jin3}",
142
+ "141": "尽{jin4}",
143
+ "142": "岗{gang1}",
144
+ "143": "岗{gang3}",
145
+ "144": "差{cha1}",
146
+ "145": "差{cha4}",
147
+ "146": "差{chai1}",
148
+ "147": "差{ci1}",
149
+ "148": "巷{hang4}",
150
+ "149": "巷{xiang4}",
151
+ "150": "帖{tie1}",
152
+ "151": "帖{tie3}",
153
+ "152": "帖{tie4}",
154
+ "153": "干{gan1}",
155
+ "154": "干{gan4}",
156
+ "155": "应{ying1}",
157
+ "156": "应{ying4}",
158
+ "157": "应{ying5}",
159
+ "158": "度{du4}",
160
+ "159": "度{du5}",
161
+ "160": "度{duo2}",
162
+ "161": "弹{dan4}",
163
+ "162": "弹{tan2}",
164
+ "163": "弹{tan5}",
165
+ "164": "强{jiang4}",
166
+ "165": "强{qiang2}",
167
+ "166": "强{qiang3}",
168
+ "167": "当{dang1}",
169
+ "168": "当{dang4}",
170
+ "169": "当{dang5}",
171
+ "170": "待{dai1}",
172
+ "171": "待{dai4}",
173
+ "172": "得{de2}",
174
+ "173": "得{de5}",
175
+ "174": "得{dei3}",
176
+ "175": "得{dei5}",
177
+ "176": "恶{e3}",
178
+ "177": "恶{e4}",
179
+ "178": "恶{wu4}",
180
+ "179": "扁{bian3}",
181
+ "180": "扁{pian1}",
182
+ "181": "扇{shan1}",
183
+ "182": "扇{shan4}",
184
+ "183": "扎{za1}",
185
+ "184": "扎{zha1}",
186
+ "185": "扎{zha2}",
187
+ "186": "扫{sao3}",
188
+ "187": "扫{sao4}",
189
+ "188": "担{dan1}",
190
+ "189": "担{dan4}",
191
+ "190": "担{dan5}",
192
+ "191": "挑{tiao1}",
193
+ "192": "挑{tiao3}",
194
+ "193": "据{jv1}",
195
+ "194": "据{jv4}",
196
+ "195": "撒{sa1}",
197
+ "196": "撒{sa3}",
198
+ "197": "撒{sa5}",
199
+ "198": "教{jiao1}",
200
+ "199": "教{jiao4}",
201
+ "200": "散{san3}",
202
+ "201": "散{san4}",
203
+ "202": "散{san5}",
204
+ "203": "数{shu3}",
205
+ "204": "数{shu4}",
206
+ "205": "数{shu5}",
207
+ "206": "斗{dou3}",
208
+ "207": "斗{dou4}",
209
+ "208": "晃{huang3}",
210
+ "209": "曝{bao4}",
211
+ "210": "曲{qu1}",
212
+ "211": "曲{qu3}",
213
+ "212": "更{geng1}",
214
+ "213": "更{geng4}",
215
+ "214": "曾{ceng1}",
216
+ "215": "曾{ceng2}",
217
+ "216": "曾{zeng1}",
218
+ "217": "朝{chao2}",
219
+ "218": "朝{zhao1}",
220
+ "219": "朴{piao2}",
221
+ "220": "朴{pu2}",
222
+ "221": "朴{pu3}",
223
+ "222": "杆{gan1}",
224
+ "223": "杆{gan3}",
225
+ "224": "查{cha2}",
226
+ "225": "查{zha1}",
227
+ "226": "校{jiao4}",
228
+ "227": "校{xiao4}",
229
+ "228": "模{mo2}",
230
+ "229": "模{mu2}",
231
+ "230": "横{heng2}",
232
+ "231": "横{heng4}",
233
+ "232": "没{mei2}",
234
+ "233": "没{mo4}",
235
+ "234": "泡{pao1}",
236
+ "235": "泡{pao4}",
237
+ "236": "泡{pao5}",
238
+ "237": "济{ji3}",
239
+ "238": "济{ji4}",
240
+ "239": "混{hun2}",
241
+ "240": "混{hun3}",
242
+ "241": "混{hun4}",
243
+ "242": "混{hun5}",
244
+ "243": "漂{piao1}",
245
+ "244": "漂{piao3}",
246
+ "245": "漂{piao4}",
247
+ "246": "炸{zha2}",
248
+ "247": "炸{zha4}",
249
+ "248": "熟{shou2}",
250
+ "249": "熟{shu2}",
251
+ "250": "燕{yan1}",
252
+ "251": "燕{yan4}",
253
+ "252": "片{pian1}",
254
+ "253": "片{pian4}",
255
+ "254": "率{lv4}",
256
+ "255": "率{shuai4}",
257
+ "256": "畜{chu4}",
258
+ "257": "畜{xu4}",
259
+ "258": "的{de5}",
260
+ "259": "的{di1}",
261
+ "260": "的{di2}",
262
+ "261": "的{di4}",
263
+ "262": "的{di5}",
264
+ "263": "盛{cheng2}",
265
+ "264": "盛{sheng4}",
266
+ "265": "相{xiang1}",
267
+ "266": "相{xiang4}",
268
+ "267": "相{xiang5}",
269
+ "268": "省{sheng3}",
270
+ "269": "省{xing3}",
271
+ "270": "看{kan1}",
272
+ "271": "看{kan4}",
273
+ "272": "看{kan5}",
274
+ "273": "着{zhao1}",
275
+ "274": "着{zhao2}",
276
+ "275": "着{zhao5}",
277
+ "276": "着{zhe5}",
278
+ "277": "着{zhuo2}",
279
+ "278": "着{zhuo5}",
280
+ "279": "矫{jiao3}",
281
+ "280": "禁{jin1}",
282
+ "281": "禁{jin4}",
283
+ "282": "种{zhong3}",
284
+ "283": "种{zhong4}",
285
+ "284": "称{chen4}",
286
+ "285": "称{cheng1}",
287
+ "286": "空{kong1}",
288
+ "287": "空{kong4}",
289
+ "288": "答{da1}",
290
+ "289": "答{da2}",
291
+ "290": "粘{nian2}",
292
+ "291": "粘{zhan1}",
293
+ "292": "糊{hu2}",
294
+ "293": "糊{hu5}",
295
+ "294": "系{ji4}",
296
+ "295": "系{xi4}",
297
+ "296": "系{xi5}",
298
+ "297": "累{lei2}",
299
+ "298": "累{lei3}",
300
+ "299": "累{lei4}",
301
+ "300": "累{lei5}",
302
+ "301": "纤{qian4}",
303
+ "302": "纤{xian1}",
304
+ "303": "结{jie1}",
305
+ "304": "结{jie2}",
306
+ "305": "结{jie5}",
307
+ "306": "给{gei3}",
308
+ "307": "给{gei5}",
309
+ "308": "给{ji3}",
310
+ "309": "缝{feng2}",
311
+ "310": "缝{feng4}",
312
+ "311": "缝{feng5}",
313
+ "312": "肖{xiao1}",
314
+ "313": "肖{xiao4}",
315
+ "314": "背{bei1}",
316
+ "315": "背{bei4}",
317
+ "316": "脏{zang1}",
318
+ "317": "脏{zang4}",
319
+ "318": "舍{she3}",
320
+ "319": "舍{she4}",
321
+ "320": "色{se4}",
322
+ "321": "色{shai3}",
323
+ "322": "落{lao4}",
324
+ "323": "落{luo4}",
325
+ "324": "蒙{meng1}",
326
+ "325": "蒙{meng2}",
327
+ "326": "蒙{meng3}",
328
+ "327": "薄{bao2}",
329
+ "328": "薄{bo2}",
330
+ "329": "薄{bo4}",
331
+ "330": "藏{cang2}",
332
+ "331": "藏{zang4}",
333
+ "332": "血{xie3}",
334
+ "333": "血{xue4}",
335
+ "334": "行{hang2}",
336
+ "335": "行{hang5}",
337
+ "336": "行{heng5}",
338
+ "337": "行{xing2}",
339
+ "338": "行{xing4}",
340
+ "339": "要{yao1}",
341
+ "340": "要{yao4}",
342
+ "341": "观{guan1}",
343
+ "342": "观{guan4}",
344
+ "343": "觉{jiao4}",
345
+ "344": "觉{jiao5}",
346
+ "345": "觉{jve2}",
347
+ "346": "角{jiao3}",
348
+ "347": "角{jve2}",
349
+ "348": "解{jie3}",
350
+ "349": "解{jie4}",
351
+ "350": "解{xie4}",
352
+ "351": "说{shui4}",
353
+ "352": "说{shuo1}",
354
+ "353": "调{diao4}",
355
+ "354": "调{tiao2}",
356
+ "355": "踏{ta1}",
357
+ "356": "踏{ta4}",
358
+ "357": "车{che1}",
359
+ "358": "车{jv1}",
360
+ "359": "转{zhuan3}",
361
+ "360": "转{zhuan4}",
362
+ "361": "载{zai3}",
363
+ "362": "载{zai4}",
364
+ "363": "还{hai2}",
365
+ "364": "还{huan2}",
366
+ "365": "遂{sui2}",
367
+ "366": "遂{sui4}",
368
+ "367": "都{dou1}",
369
+ "368": "都{du1}",
370
+ "369": "重{chong2}",
371
+ "370": "重{zhong4}",
372
+ "371": "量{liang2}",
373
+ "372": "量{liang4}",
374
+ "373": "量{liang5}",
375
+ "374": "钻{zuan1}",
376
+ "375": "钻{zuan4}",
377
+ "376": "铺{pu1}",
378
+ "377": "铺{pu4}",
379
+ "378": "长{chang2}",
380
+ "379": "长{chang3}",
381
+ "380": "长{zhang3}",
382
+ "381": "间{jian1}",
383
+ "382": "间{jian4}",
384
+ "383": "降{jiang4}",
385
+ "384": "降{xiang2}",
386
+ "385": "难{nan2}",
387
+ "386": "难{nan4}",
388
+ "387": "难{nan5}",
389
+ "388": "露{lou4}",
390
+ "389": "露{lu4}",
391
+ "390": "鲜{xian1}",
392
+ "391": "鲜{xian3}"
393
+ }
soundsation/g2p/sources/g2p_chinese_model/polydict_r.json ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "丧{sang1}": 1,
3
+ "丧{sang4}": 2,
4
+ "中{zhong1}": 3,
5
+ "中{zhong4}": 4,
6
+ "为{wei2}": 5,
7
+ "为{wei4}": 6,
8
+ "乌{wu1}": 7,
9
+ "乌{wu4}": 8,
10
+ "乐{lao4}": 9,
11
+ "乐{le4}": 10,
12
+ "乐{le5}": 11,
13
+ "乐{yao4}": 12,
14
+ "乐{yve4}": 13,
15
+ "了{le5}": 14,
16
+ "了{liao3}": 15,
17
+ "了{liao5}": 16,
18
+ "什{shen2}": 17,
19
+ "什{shi2}": 18,
20
+ "仔{zai3}": 19,
21
+ "仔{zai5}": 20,
22
+ "仔{zi3}": 21,
23
+ "仔{zi5}": 22,
24
+ "令{ling2}": 23,
25
+ "令{ling4}": 24,
26
+ "任{ren2}": 25,
27
+ "任{ren4}": 26,
28
+ "会{hui4}": 27,
29
+ "会{hui5}": 28,
30
+ "会{kuai4}": 29,
31
+ "传{chuan2}": 30,
32
+ "传{zhuan4}": 31,
33
+ "佛{fo2}": 32,
34
+ "佛{fu2}": 33,
35
+ "供{gong1}": 34,
36
+ "供{gong4}": 35,
37
+ "便{bian4}": 36,
38
+ "便{pian2}": 37,
39
+ "倒{dao3}": 38,
40
+ "倒{dao4}": 39,
41
+ "假{jia3}": 40,
42
+ "假{jia4}": 41,
43
+ "兴{xing1}": 42,
44
+ "兴{xing4}": 43,
45
+ "冠{guan1}": 44,
46
+ "冠{guan4}": 45,
47
+ "冲{chong1}": 46,
48
+ "冲{chong4}": 47,
49
+ "几{ji1}": 48,
50
+ "几{ji2}": 49,
51
+ "几{ji3}": 50,
52
+ "分{fen1}": 51,
53
+ "分{fen4}": 52,
54
+ "分{fen5}": 53,
55
+ "切{qie1}": 54,
56
+ "切{qie4}": 55,
57
+ "划{hua2}": 56,
58
+ "划{hua4}": 57,
59
+ "划{hua5}": 58,
60
+ "创{chuang1}": 59,
61
+ "创{chuang4}": 60,
62
+ "剥{bao1}": 61,
63
+ "剥{bo1}": 62,
64
+ "勒{le4}": 63,
65
+ "勒{le5}": 64,
66
+ "勒{lei1}": 65,
67
+ "区{ou1}": 66,
68
+ "区{qu1}": 67,
69
+ "华{hua2}": 68,
70
+ "华{hua4}": 69,
71
+ "单{chan2}": 70,
72
+ "单{dan1}": 71,
73
+ "单{shan4}": 72,
74
+ "卜{bo5}": 73,
75
+ "卜{bu3}": 74,
76
+ "占{zhan1}": 75,
77
+ "占{zhan4}": 76,
78
+ "卡{ka2}": 77,
79
+ "卡{ka3}": 78,
80
+ "卡{qia3}": 79,
81
+ "卷{jvan3}": 80,
82
+ "卷{jvan4}": 81,
83
+ "厦{sha4}": 82,
84
+ "厦{xia4}": 83,
85
+ "参{can1}": 84,
86
+ "参{cen1}": 85,
87
+ "参{shen1}": 86,
88
+ "发{fa1}": 87,
89
+ "发{fa4}": 88,
90
+ "发{fa5}": 89,
91
+ "只{zhi1}": 90,
92
+ "只{zhi3}": 91,
93
+ "号{hao2}": 92,
94
+ "号{hao4}": 93,
95
+ "号{hao5}": 94,
96
+ "同{tong2}": 95,
97
+ "同{tong4}": 96,
98
+ "同{tong5}": 97,
99
+ "吐{tu2}": 98,
100
+ "吐{tu3}": 99,
101
+ "吐{tu4}": 100,
102
+ "和{he2}": 101,
103
+ "和{he4}": 102,
104
+ "和{he5}": 103,
105
+ "和{huo2}": 104,
106
+ "和{huo4}": 105,
107
+ "和{huo5}": 106,
108
+ "喝{he1}": 107,
109
+ "喝{he4}": 108,
110
+ "圈{jvan4}": 109,
111
+ "圈{qvan1}": 110,
112
+ "圈{qvan5}": 111,
113
+ "地{de5}": 112,
114
+ "地{di4}": 113,
115
+ "地{di5}": 114,
116
+ "塞{sai1}": 115,
117
+ "塞{sai2}": 116,
118
+ "塞{sai4}": 117,
119
+ "塞{se4}": 118,
120
+ "壳{ke2}": 119,
121
+ "壳{qiao4}": 120,
122
+ "处{chu3}": 121,
123
+ "处{chu4}": 122,
124
+ "奇{ji1}": 123,
125
+ "奇{qi2}": 124,
126
+ "奔{ben1}": 125,
127
+ "奔{ben4}": 126,
128
+ "好{hao3}": 127,
129
+ "好{hao4}": 128,
130
+ "好{hao5}": 129,
131
+ "宁{ning2}": 130,
132
+ "宁{ning4}": 131,
133
+ "宁{ning5}": 132,
134
+ "宿{su4}": 133,
135
+ "宿{xiu3}": 134,
136
+ "宿{xiu4}": 135,
137
+ "将{jiang1}": 136,
138
+ "将{jiang4}": 137,
139
+ "少{shao3}": 138,
140
+ "少{shao4}": 139,
141
+ "尽{jin3}": 140,
142
+ "尽{jin4}": 141,
143
+ "岗{gang1}": 142,
144
+ "岗{gang3}": 143,
145
+ "差{cha1}": 144,
146
+ "差{cha4}": 145,
147
+ "差{chai1}": 146,
148
+ "差{ci1}": 147,
149
+ "巷{hang4}": 148,
150
+ "巷{xiang4}": 149,
151
+ "帖{tie1}": 150,
152
+ "帖{tie3}": 151,
153
+ "帖{tie4}": 152,
154
+ "干{gan1}": 153,
155
+ "干{gan4}": 154,
156
+ "应{ying1}": 155,
157
+ "应{ying4}": 156,
158
+ "应{ying5}": 157,
159
+ "度{du4}": 158,
160
+ "度{du5}": 159,
161
+ "度{duo2}": 160,
162
+ "弹{dan4}": 161,
163
+ "弹{tan2}": 162,
164
+ "弹{tan5}": 163,
165
+ "强{jiang4}": 164,
166
+ "强{qiang2}": 165,
167
+ "强{qiang3}": 166,
168
+ "当{dang1}": 167,
169
+ "当{dang4}": 168,
170
+ "当{dang5}": 169,
171
+ "待{dai1}": 170,
172
+ "待{dai4}": 171,
173
+ "得{de2}": 172,
174
+ "得{de5}": 173,
175
+ "得{dei3}": 174,
176
+ "得{dei5}": 175,
177
+ "恶{e3}": 176,
178
+ "恶{e4}": 177,
179
+ "恶{wu4}": 178,
180
+ "扁{bian3}": 179,
181
+ "扁{pian1}": 180,
182
+ "扇{shan1}": 181,
183
+ "扇{shan4}": 182,
184
+ "扎{za1}": 183,
185
+ "扎{zha1}": 184,
186
+ "扎{zha2}": 185,
187
+ "扫{sao3}": 186,
188
+ "扫{sao4}": 187,
189
+ "担{dan1}": 188,
190
+ "担{dan4}": 189,
191
+ "担{dan5}": 190,
192
+ "挑{tiao1}": 191,
193
+ "挑{tiao3}": 192,
194
+ "据{jv1}": 193,
195
+ "据{jv4}": 194,
196
+ "撒{sa1}": 195,
197
+ "撒{sa3}": 196,
198
+ "撒{sa5}": 197,
199
+ "教{jiao1}": 198,
200
+ "教{jiao4}": 199,
201
+ "散{san3}": 200,
202
+ "散{san4}": 201,
203
+ "散{san5}": 202,
204
+ "数{shu3}": 203,
205
+ "数{shu4}": 204,
206
+ "数{shu5}": 205,
207
+ "斗{dou3}": 206,
208
+ "斗{dou4}": 207,
209
+ "晃{huang3}": 208,
210
+ "曝{bao4}": 209,
211
+ "曲{qu1}": 210,
212
+ "曲{qu3}": 211,
213
+ "更{geng1}": 212,
214
+ "更{geng4}": 213,
215
+ "曾{ceng1}": 214,
216
+ "曾{ceng2}": 215,
217
+ "曾{zeng1}": 216,
218
+ "朝{chao2}": 217,
219
+ "朝{zhao1}": 218,
220
+ "朴{piao2}": 219,
221
+ "朴{pu2}": 220,
222
+ "朴{pu3}": 221,
223
+ "杆{gan1}": 222,
224
+ "杆{gan3}": 223,
225
+ "查{cha2}": 224,
226
+ "查{zha1}": 225,
227
+ "校{jiao4}": 226,
228
+ "校{xiao4}": 227,
229
+ "模{mo2}": 228,
230
+ "模{mu2}": 229,
231
+ "横{heng2}": 230,
232
+ "横{heng4}": 231,
233
+ "没{mei2}": 232,
234
+ "没{mo4}": 233,
235
+ "泡{pao1}": 234,
236
+ "泡{pao4}": 235,
237
+ "泡{pao5}": 236,
238
+ "济{ji3}": 237,
239
+ "济{ji4}": 238,
240
+ "混{hun2}": 239,
241
+ "混{hun3}": 240,
242
+ "混{hun4}": 241,
243
+ "混{hun5}": 242,
244
+ "漂{piao1}": 243,
245
+ "漂{piao3}": 244,
246
+ "漂{piao4}": 245,
247
+ "炸{zha2}": 246,
248
+ "炸{zha4}": 247,
249
+ "熟{shou2}": 248,
250
+ "熟{shu2}": 249,
251
+ "燕{yan1}": 250,
252
+ "燕{yan4}": 251,
253
+ "片{pian1}": 252,
254
+ "片{pian4}": 253,
255
+ "率{lv4}": 254,
256
+ "率{shuai4}": 255,
257
+ "畜{chu4}": 256,
258
+ "畜{xu4}": 257,
259
+ "的{de5}": 258,
260
+ "的{di1}": 259,
261
+ "的{di2}": 260,
262
+ "的{di4}": 261,
263
+ "的{di5}": 262,
264
+ "盛{cheng2}": 263,
265
+ "盛{sheng4}": 264,
266
+ "相{xiang1}": 265,
267
+ "相{xiang4}": 266,
268
+ "相{xiang5}": 267,
269
+ "省{sheng3}": 268,
270
+ "省{xing3}": 269,
271
+ "看{kan1}": 270,
272
+ "看{kan4}": 271,
273
+ "看{kan5}": 272,
274
+ "着{zhao1}": 273,
275
+ "着{zhao2}": 274,
276
+ "着{zhao5}": 275,
277
+ "着{zhe5}": 276,
278
+ "着{zhuo2}": 277,
279
+ "着{zhuo5}": 278,
280
+ "矫{jiao3}": 279,
281
+ "禁{jin1}": 280,
282
+ "禁{jin4}": 281,
283
+ "种{zhong3}": 282,
284
+ "种{zhong4}": 283,
285
+ "称{chen4}": 284,
286
+ "称{cheng1}": 285,
287
+ "空{kong1}": 286,
288
+ "空{kong4}": 287,
289
+ "答{da1}": 288,
290
+ "答{da2}": 289,
291
+ "粘{nian2}": 290,
292
+ "粘{zhan1}": 291,
293
+ "糊{hu2}": 292,
294
+ "糊{hu5}": 293,
295
+ "系{ji4}": 294,
296
+ "系{xi4}": 295,
297
+ "系{xi5}": 296,
298
+ "累{lei2}": 297,
299
+ "累{lei3}": 298,
300
+ "累{lei4}": 299,
301
+ "累{lei5}": 300,
302
+ "纤{qian4}": 301,
303
+ "纤{xian1}": 302,
304
+ "结{jie1}": 303,
305
+ "结{jie2}": 304,
306
+ "结{jie5}": 305,
307
+ "给{gei3}": 306,
308
+ "给{gei5}": 307,
309
+ "给{ji3}": 308,
310
+ "缝{feng2}": 309,
311
+ "缝{feng4}": 310,
312
+ "缝{feng5}": 311,
313
+ "肖{xiao1}": 312,
314
+ "肖{xiao4}": 313,
315
+ "背{bei1}": 314,
316
+ "背{bei4}": 315,
317
+ "脏{zang1}": 316,
318
+ "脏{zang4}": 317,
319
+ "舍{she3}": 318,
320
+ "舍{she4}": 319,
321
+ "色{se4}": 320,
322
+ "色{shai3}": 321,
323
+ "落{lao4}": 322,
324
+ "落{luo4}": 323,
325
+ "蒙{meng1}": 324,
326
+ "蒙{meng2}": 325,
327
+ "蒙{meng3}": 326,
328
+ "薄{bao2}": 327,
329
+ "薄{bo2}": 328,
330
+ "薄{bo4}": 329,
331
+ "藏{cang2}": 330,
332
+ "藏{zang4}": 331,
333
+ "血{xie3}": 332,
334
+ "血{xue4}": 333,
335
+ "行{hang2}": 334,
336
+ "行{hang5}": 335,
337
+ "行{heng5}": 336,
338
+ "行{xing2}": 337,
339
+ "行{xing4}": 338,
340
+ "要{yao1}": 339,
341
+ "要{yao4}": 340,
342
+ "观{guan1}": 341,
343
+ "观{guan4}": 342,
344
+ "觉{jiao4}": 343,
345
+ "觉{jiao5}": 344,
346
+ "觉{jve2}": 345,
347
+ "角{jiao3}": 346,
348
+ "角{jve2}": 347,
349
+ "解{jie3}": 348,
350
+ "解{jie4}": 349,
351
+ "解{xie4}": 350,
352
+ "说{shui4}": 351,
353
+ "说{shuo1}": 352,
354
+ "调{diao4}": 353,
355
+ "调{tiao2}": 354,
356
+ "踏{ta1}": 355,
357
+ "踏{ta4}": 356,
358
+ "车{che1}": 357,
359
+ "车{jv1}": 358,
360
+ "转{zhuan3}": 359,
361
+ "转{zhuan4}": 360,
362
+ "载{zai3}": 361,
363
+ "载{zai4}": 362,
364
+ "还{hai2}": 363,
365
+ "还{huan2}": 364,
366
+ "遂{sui2}": 365,
367
+ "遂{sui4}": 366,
368
+ "都{dou1}": 367,
369
+ "都{du1}": 368,
370
+ "重{chong2}": 369,
371
+ "重{zhong4}": 370,
372
+ "量{liang2}": 371,
373
+ "量{liang4}": 372,
374
+ "量{liang5}": 373,
375
+ "钻{zuan1}": 374,
376
+ "钻{zuan4}": 375,
377
+ "铺{pu1}": 376,
378
+ "铺{pu4}": 377,
379
+ "长{chang2}": 378,
380
+ "长{chang3}": 379,
381
+ "长{zhang3}": 380,
382
+ "间{jian1}": 381,
383
+ "间{jian4}": 382,
384
+ "降{jiang4}": 383,
385
+ "降{xiang2}": 384,
386
+ "难{nan2}": 385,
387
+ "难{nan4}": 386,
388
+ "难{nan5}": 387,
389
+ "露{lou4}": 388,
390
+ "露{lu4}": 389,
391
+ "鲜{xian1}": 390,
392
+ "鲜{xian3}": 391
393
+ }
soundsation/g2p/sources/g2p_chinese_model/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
soundsation/g2p/sources/pinyin_2_bpmf.txt ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ a ㄚ
2
+ ai ㄞ
3
+ an ㄢ
4
+ ang ㄤ
5
+ ao ㄠ
6
+ ba ㄅㄚ
7
+ bai ㄅㄞ
8
+ ban ㄅㄢ
9
+ bang ㄅㄤ
10
+ bao ㄅㄠ
11
+ bei ㄅㄟ
12
+ ben ㄅㄣ
13
+ beng ㄅㄥ
14
+ bi ㄅㄧ
15
+ bian ㄅㄧㄢ
16
+ biang ㄅㄧㄤ
17
+ biao ㄅㄧㄠ
18
+ bie ㄅㄧㄝ
19
+ bin ㄅㄧㄣ
20
+ bing ㄅㄧㄥ
21
+ bo ㄅㄛ
22
+ bu ㄅㄨ
23
+ ca ㄘㄚ
24
+ cai ㄘㄞ
25
+ can ㄘㄢ
26
+ cang ㄘㄤ
27
+ cao ㄘㄠ
28
+ ce ㄘㄜ
29
+ cen ㄘㄣ
30
+ ceng ㄘㄥ
31
+ cha ㄔㄚ
32
+ chai ㄔㄞ
33
+ chan ㄔㄢ
34
+ chang ㄔㄤ
35
+ chao ㄔㄠ
36
+ che ㄔㄜ
37
+ chen ㄔㄣ
38
+ cheng ㄔㄥ
39
+ chi ㄔ
40
+ chong ㄔㄨㄥ
41
+ chou ㄔㄡ
42
+ chu ㄔㄨ
43
+ chua ㄔㄨㄚ
44
+ chuai ㄔㄨㄞ
45
+ chuan ㄔㄨㄢ
46
+ chuang ㄔㄨㄤ
47
+ chui ㄔㄨㄟ
48
+ chun ㄔㄨㄣ
49
+ chuo ㄔㄨㄛ
50
+ ci ㄘ
51
+ cong ㄘㄨㄥ
52
+ cou ㄘㄡ
53
+ cu ㄘㄨ
54
+ cuan ㄘㄨㄢ
55
+ cui ㄘㄨㄟ
56
+ cun ㄘㄨㄣ
57
+ cuo ㄘㄨㄛ
58
+ da ㄉㄚ
59
+ dai ㄉㄞ
60
+ dan ㄉㄢ
61
+ dang ㄉㄤ
62
+ dao ㄉㄠ
63
+ de ㄉㄜ
64
+ dei ㄉㄟ
65
+ den ㄉㄣ
66
+ deng ㄉㄥ
67
+ di ㄉㄧ
68
+ dia ㄉㄧㄚ
69
+ dian ㄉㄧㄢ
70
+ diao ㄉㄧㄠ
71
+ die ㄉㄧㄝ
72
+ din ㄉㄧㄣ
73
+ ding ㄉㄧㄥ
74
+ diu ㄉㄧㄡ
75
+ dong ㄉㄨㄥ
76
+ dou ㄉㄡ
77
+ du ㄉㄨ
78
+ duan ㄉㄨㄢ
79
+ dui ㄉㄨㄟ
80
+ dun ㄉㄨㄣ
81
+ duo ㄉㄨㄛ
82
+ e ㄜ
83
+ ei ㄟ
84
+ en ㄣ
85
+ eng ㄥ
86
+ er ㄦ
87
+ fa ㄈㄚ
88
+ fan ㄈㄢ
89
+ fang ㄈㄤ
90
+ fei ㄈㄟ
91
+ fen ㄈㄣ
92
+ feng ㄈㄥ
93
+ fo ㄈㄛ
94
+ fou ㄈㄡ
95
+ fu ㄈㄨ
96
+ ga ㄍㄚ
97
+ gai ㄍㄞ
98
+ gan ㄍㄢ
99
+ gang ㄍㄤ
100
+ gao ㄍㄠ
101
+ ge ㄍㄜ
102
+ gei ㄍㄟ
103
+ gen ㄍㄣ
104
+ geng ㄍㄥ
105
+ gong ㄍㄨㄥ
106
+ gou ㄍㄡ
107
+ gu ㄍㄨ
108
+ gua ㄍㄨㄚ
109
+ guai ㄍㄨㄞ
110
+ guan ㄍㄨㄢ
111
+ guang ㄍㄨㄤ
112
+ gui ㄍㄨㄟ
113
+ gun ㄍㄨㄣ
114
+ guo ㄍㄨㄛ
115
+ ha ㄏㄚ
116
+ hai ㄏㄞ
117
+ han ㄏㄢ
118
+ hang ㄏㄤ
119
+ hao ㄏㄠ
120
+ he ㄏㄜ
121
+ hei ㄏㄟ
122
+ hen ㄏㄣ
123
+ heng ㄏㄥ
124
+ hm ㄏㄇ
125
+ hong ㄏㄨㄥ
126
+ hou ㄏㄡ
127
+ hu ㄏㄨ
128
+ hua ㄏㄨㄚ
129
+ huai ㄏㄨㄞ
130
+ huan ㄏㄨㄢ
131
+ huang ㄏㄨㄤ
132
+ hui ㄏㄨㄟ
133
+ hun ㄏㄨㄣ
134
+ huo ㄏㄨㄛ
135
+ ji ㄐㄧ
136
+ jia ㄐㄧㄚ
137
+ jian ㄐㄧㄢ
138
+ jiang ㄐㄧㄤ
139
+ jiao ㄐㄧㄠ
140
+ jie ㄐㄧㄝ
141
+ jin ㄐㄧㄣ
142
+ jing ㄐㄧㄥ
143
+ jiong ㄐㄩㄥ
144
+ jiu ㄐㄧㄡ
145
+ ju ㄐㄩ
146
+ jv ㄐㄩ
147
+ juan ㄐㄩㄢ
148
+ jvan ㄐㄩㄢ
149
+ jue ㄐㄩㄝ
150
+ jve ㄐㄩㄝ
151
+ jun ㄐㄩㄣ
152
+ ka ㄎㄚ
153
+ kai ㄎㄞ
154
+ kan ㄎㄢ
155
+ kang ㄎㄤ
156
+ kao ㄎㄠ
157
+ ke ㄎㄜ
158
+ kei ㄎㄟ
159
+ ken ㄎㄣ
160
+ keng ㄎㄥ
161
+ kong ㄎㄨㄥ
162
+ kou ㄎㄡ
163
+ ku ㄎㄨ
164
+ kua ㄎㄨㄚ
165
+ kuai ㄎㄨㄞ
166
+ kuan ㄎㄨㄢ
167
+ kuang ㄎㄨㄤ
168
+ kui ㄎㄨㄟ
169
+ kun ㄎㄨㄣ
170
+ kuo ㄎㄨㄛ
171
+ la ㄌㄚ
172
+ lai ㄌㄞ
173
+ lan ㄌㄢ
174
+ lang ㄌㄤ
175
+ lao ㄌㄠ
176
+ le ㄌㄜ
177
+ lei ㄌㄟ
178
+ leng ㄌㄥ
179
+ li ㄌㄧ
180
+ lia ㄌㄧㄚ
181
+ lian ㄌㄧㄢ
182
+ liang ㄌㄧㄤ
183
+ liao ㄌㄧㄠ
184
+ lie ㄌㄧㄝ
185
+ lin ㄌㄧㄣ
186
+ ling ㄌㄧㄥ
187
+ liu ㄌㄧㄡ
188
+ lo ㄌㄛ
189
+ long ㄌㄨㄥ
190
+ lou ㄌㄡ
191
+ lu ㄌㄨ
192
+ luan ㄌㄨㄢ
193
+ lue ㄌㄩㄝ
194
+ lun ㄌㄨㄣ
195
+ luo ㄌㄨㄛ
196
+ lv ㄌㄩ
197
+ lve ㄌㄩㄝ
198
+ m ㄇㄨ
199
+ ma ㄇㄚ
200
+ mai ㄇㄞ
201
+ man ㄇㄢ
202
+ mang ㄇㄤ
203
+ mao ㄇㄠ
204
+ me ㄇㄜ
205
+ mei ㄇㄟ
206
+ men ㄇㄣ
207
+ meng ㄇㄥ
208
+ mi ㄇㄧ
209
+ mian ㄇㄧㄢ
210
+ miao ㄇㄧㄠ
211
+ mie ㄇㄧㄝ
212
+ min ㄇㄧㄣ
213
+ ming ㄇㄧㄥ
214
+ miu ㄇㄧㄡ
215
+ mo ㄇㄛ
216
+ mou ㄇㄡ
217
+ mu ㄇㄨ
218
+ n ㄣ
219
+ na ㄋㄚ
220
+ nai ㄋㄞ
221
+ nan ㄋㄢ
222
+ nang ㄋㄤ
223
+ nao ㄋㄠ
224
+ ne ㄋㄜ
225
+ nei ㄋㄟ
226
+ nen ㄋㄣ
227
+ neng ㄋㄥ
228
+ ng ㄣ
229
+ ni ㄋㄧ
230
+ nian ㄋㄧㄢ
231
+ niang ㄋㄧㄤ
232
+ niao ㄋㄧㄠ
233
+ nie ㄋㄧㄝ
234
+ nin ㄋㄧㄣ
235
+ ning ㄋㄧㄥ
236
+ niu ㄋㄧㄡ
237
+ nong ㄋㄨㄥ
238
+ nou ㄋㄡ
239
+ nu ㄋㄨ
240
+ nuan ㄋㄨㄢ
241
+ nue ㄋㄩㄝ
242
+ nun ㄋㄨㄣ
243
+ nuo ㄋㄨㄛ
244
+ nv ㄋㄩ
245
+ nve ㄋㄩㄝ
246
+ o ㄛ
247
+ ou ㄡ
248
+ pa ㄆㄚ
249
+ pai ㄆㄞ
250
+ pan ㄆㄢ
251
+ pang ㄆㄤ
252
+ pao ㄆㄠ
253
+ pei ㄆㄟ
254
+ pen ㄆㄣ
255
+ peng ㄆㄥ
256
+ pi ㄆㄧ
257
+ pian ㄆㄧㄢ
258
+ piao ㄆㄧㄠ
259
+ pie ㄆㄧㄝ
260
+ pin ㄆㄧㄣ
261
+ ping ㄆㄧㄥ
262
+ po ㄆㄛ
263
+ pou ㄆㄡ
264
+ pu ㄆㄨ
265
+ qi ㄑㄧ
266
+ qia ㄑㄧㄚ
267
+ qian ㄑㄧㄢ
268
+ qiang ㄑㄧㄤ
269
+ qiao ㄑㄧㄠ
270
+ qie ㄑㄧㄝ
271
+ qin ㄑㄧㄣ
272
+ qing ㄑㄧㄥ
273
+ qiong ㄑㄩㄥ
274
+ qiu ㄑㄧㄡ
275
+ qu ㄑㄩ
276
+ quan ㄑㄩㄢ
277
+ qvan ㄑㄩㄢ
278
+ que ㄑㄩㄝ
279
+ qun ㄑㄩㄣ
280
+ ran ㄖㄢ
281
+ rang ㄖㄤ
282
+ rao ㄖㄠ
283
+ re ㄖㄜ
284
+ ren ㄖㄣ
285
+ reng ㄖㄥ
286
+ ri ㄖ
287
+ rong ㄖㄨㄥ
288
+ rou ㄖㄡ
289
+ ru ㄖㄨ
290
+ rua ㄖㄨㄚ
291
+ ruan ㄖㄨㄢ
292
+ rui ㄖㄨㄟ
293
+ run ㄖㄨㄣ
294
+ ruo ㄖㄨㄛ
295
+ sa ㄙㄚ
296
+ sai ㄙㄞ
297
+ san ㄙㄢ
298
+ sang ㄙㄤ
299
+ sao ㄙㄠ
300
+ se ㄙㄜ
301
+ sen ㄙㄣ
302
+ seng ㄙㄥ
303
+ sha ㄕㄚ
304
+ shai ㄕㄞ
305
+ shan ㄕㄢ
306
+ shang ㄕㄤ
307
+ shao ㄕㄠ
308
+ she ㄕㄜ
309
+ shei ㄕㄟ
310
+ shen ㄕㄣ
311
+ sheng ㄕㄥ
312
+ shi ㄕ
313
+ shou ㄕㄡ
314
+ shu ㄕㄨ
315
+ shua ㄕㄨㄚ
316
+ shuai ㄕㄨㄞ
317
+ shuan ㄕㄨㄢ
318
+ shuang ㄕㄨㄤ
319
+ shui ㄕㄨㄟ
320
+ shun ㄕㄨㄣ
321
+ shuo ㄕㄨㄛ
322
+ si ㄙ
323
+ song ㄙㄨㄥ
324
+ sou ㄙㄡ
325
+ su ㄙㄨ
326
+ suan ㄙㄨㄢ
327
+ sui ㄙㄨㄟ
328
+ sun ㄙㄨㄣ
329
+ suo ㄙㄨㄛ
330
+ ta ㄊㄚ
331
+ tai ㄊㄞ
332
+ tan ㄊㄢ
333
+ tang ㄊㄤ
334
+ tao ㄊㄠ
335
+ te ㄊㄜ
336
+ tei ㄊㄟ
337
+ teng ㄊㄥ
338
+ ti ㄊㄧ
339
+ tian ㄊㄧㄢ
340
+ tiao ㄊㄧㄠ
341
+ tie ㄊㄧㄝ
342
+ ting ㄊㄧㄥ
343
+ tong ㄊㄨㄥ
344
+ tou ㄊㄡ
345
+ tsuo ㄘㄨㄛ
346
+ tu ㄊㄨ
347
+ tuan ㄊㄨㄢ
348
+ tui ㄊㄨㄟ
349
+ tun ㄊㄨㄣ
350
+ tuo ㄊㄨㄛ
351
+ tzan ㄗㄢ
352
+ wa ㄨㄚ
353
+ wai ㄨㄞ
354
+ wan ㄨㄢ
355
+ wang ㄨㄤ
356
+ wei ㄨㄟ
357
+ wen ㄨㄣ
358
+ weng ㄨㄥ
359
+ wo ㄨㄛ
360
+ wong ㄨㄥ
361
+ wu ㄨ
362
+ xi ㄒㄧ
363
+ xia ㄒㄧㄚ
364
+ xian ㄒㄧㄢ
365
+ xiang ㄒㄧㄤ
366
+ xiao ㄒㄧㄠ
367
+ xie ㄒㄧㄝ
368
+ xin ㄒㄧㄣ
369
+ xing ㄒㄧㄥ
370
+ xiong ㄒㄩㄥ
371
+ xiu ㄒㄧㄡ
372
+ xu ㄒㄩ
373
+ xuan ㄒㄩㄢ
374
+ xue ㄒㄩㄝ
375
+ xun ㄒㄩㄣ
376
+ ya ㄧㄚ
377
+ yai ㄧㄞ
378
+ yan ㄧㄢ
379
+ yang ㄧㄤ
380
+ yao ㄧㄠ
381
+ ye ㄧㄝ
382
+ yi ㄧ
383
+ yin ㄧㄣ
384
+ ying ㄧㄥ
385
+ yo ㄧㄛ
386
+ yong ㄩㄥ
387
+ you ㄧㄡ
388
+ yu ㄩ
389
+ yuan ㄩㄢ
390
+ yue ㄩㄝ
391
+ yve ㄩㄝ
392
+ yun ㄩㄣ
393
+ za ㄗㄚ
394
+ zai ㄗㄞ
395
+ zan ㄗㄢ
396
+ zang ㄗㄤ
397
+ zao ㄗㄠ
398
+ ze ㄗㄜ
399
+ zei ㄗㄟ
400
+ zen ㄗㄣ
401
+ zeng ㄗㄥ
402
+ zha ㄓㄚ
403
+ zhai ㄓㄞ
404
+ zhan ㄓㄢ
405
+ zhang ㄓㄤ
406
+ zhao ㄓㄠ
407
+ zhe ㄓㄜ
408
+ zhei ㄓㄟ
409
+ zhen ㄓㄣ
410
+ zheng ㄓㄥ
411
+ zhi ㄓ
412
+ zhong ㄓㄨㄥ
413
+ zhou ㄓㄡ
414
+ zhu ㄓㄨ
415
+ zhua ㄓㄨㄚ
416
+ zhuai ㄓㄨㄞ
417
+ zhuan ㄓㄨㄢ
418
+ zhuang ㄓㄨㄤ
419
+ zhui ㄓㄨㄟ
420
+ zhun ㄓㄨㄣ
421
+ zhuo ㄓㄨㄛ
422
+ zi ㄗ
423
+ zong ㄗㄨㄥ
424
+ zou ㄗㄡ
425
+ zu ㄗㄨ
426
+ zuan ㄗㄨㄢ
427
+ zui ㄗㄨㄟ
428
+ zun ㄗㄨㄣ
429
+ zuo ㄗㄨㄛ
soundsation/g2p/utils/front_utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+
8
+
9
+ def generate_poly_lexicon(file_path: str):
10
+ """Generate poly char lexicon for Mandarin Chinese."""
11
+ poly_dict = {}
12
+
13
+ with open(file_path, "r", encoding="utf-8") as readf:
14
+ txt_list = readf.readlines()
15
+ for txt in txt_list:
16
+ word = txt.strip("\n")
17
+ if word not in poly_dict:
18
+ poly_dict[word] = 1
19
+ readf.close()
20
+ return poly_dict
soundsation/g2p/utils/g2p.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from phonemizer.backend import EspeakBackend
7
+ from phonemizer.separator import Separator
8
+ from phonemizer.utils import list2str, str2list
9
+ from typing import List, Union
10
+ import os
11
+ import json
12
+ import sys
13
+
14
+ # separator=Separator(phone=' ', word=' _ ', syllable='|'),
15
+ separator = Separator(word=" _ ", syllable="|", phone=" ")
16
+
17
+ phonemizer_zh = EspeakBackend(
18
+ "cmn", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
19
+ )
20
+ # phonemizer_zh.separator = separator
21
+
22
+ phonemizer_en = EspeakBackend(
23
+ "en-us",
24
+ preserve_punctuation=False,
25
+ with_stress=False,
26
+ language_switch="remove-flags",
27
+ )
28
+ # phonemizer_en.separator = separator
29
+
30
+ phonemizer_ja = EspeakBackend(
31
+ "ja", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
32
+ )
33
+ # phonemizer_ja.separator = separator
34
+
35
+ phonemizer_ko = EspeakBackend(
36
+ "ko", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
37
+ )
38
+ # phonemizer_ko.separator = separator
39
+
40
+ phonemizer_fr = EspeakBackend(
41
+ "fr-fr",
42
+ preserve_punctuation=False,
43
+ with_stress=False,
44
+ language_switch="remove-flags",
45
+ )
46
+ # phonemizer_fr.separator = separator
47
+
48
+ phonemizer_de = EspeakBackend(
49
+ "de", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
50
+ )
51
+ # phonemizer_de.separator = separator
52
+
53
+
54
+ lang2backend = {
55
+ "zh": phonemizer_zh,
56
+ "ja": phonemizer_ja,
57
+ "en": phonemizer_en,
58
+ "fr": phonemizer_fr,
59
+ "ko": phonemizer_ko,
60
+ "de": phonemizer_de,
61
+ }
62
+
63
+ with open("./soundsation/g2p/utils/mls_en.json", "r") as f:
64
+ json_data = f.read()
65
+ token = json.loads(json_data)
66
+
67
+
68
+ def phonemizer_g2p(text, language):
69
+ langbackend = lang2backend[language]
70
+ phonemes = _phonemize(
71
+ langbackend,
72
+ text,
73
+ separator,
74
+ strip=True,
75
+ njobs=1,
76
+ prepend_text=False,
77
+ preserve_empty_lines=False,
78
+ )
79
+ token_id = []
80
+ if isinstance(phonemes, list):
81
+ for phone in phonemes:
82
+ phonemes_split = phone.split(" ")
83
+ token_id.append([token[p] for p in phonemes_split if p in token])
84
+ else:
85
+ phonemes_split = phonemes.split(" ")
86
+ token_id = [token[p] for p in phonemes_split if p in token]
87
+ return phonemes, token_id
88
+
89
+
90
+ def _phonemize( # pylint: disable=too-many-arguments
91
+ backend,
92
+ text: Union[str, List[str]],
93
+ separator: Separator,
94
+ strip: bool,
95
+ njobs: int,
96
+ prepend_text: bool,
97
+ preserve_empty_lines: bool,
98
+ ):
99
+ """Auxiliary function to phonemize()
100
+
101
+ Does the phonemization and returns the phonemized text. Raises a
102
+ RuntimeError on error.
103
+
104
+ """
105
+ # remember the text type for output (either list or string)
106
+ text_type = type(text)
107
+
108
+ # force the text as a list
109
+ text = [line.strip(os.linesep) for line in str2list(text)]
110
+
111
+ # if preserving empty lines, note the index of each empty line
112
+ if preserve_empty_lines:
113
+ empty_lines = [n for n, line in enumerate(text) if not line.strip()]
114
+
115
+ # ignore empty lines
116
+ text = [line for line in text if line.strip()]
117
+
118
+ if text:
119
+ # phonemize the text
120
+ phonemized = backend.phonemize(
121
+ text, separator=separator, strip=strip, njobs=njobs
122
+ )
123
+ else:
124
+ phonemized = []
125
+
126
+ # if preserving empty lines, reinsert them into text and phonemized lists
127
+ if preserve_empty_lines:
128
+ for i in empty_lines: # noqa
129
+ if prepend_text:
130
+ text.insert(i, "")
131
+ phonemized.insert(i, "")
132
+
133
+ # at that point, the phonemized text is a list of str. Format it as
134
+ # expected by the parameters
135
+ if prepend_text:
136
+ return list(zip(text, phonemized))
137
+ if text_type == str:
138
+ return list2str(phonemized)
139
+ return phonemized
soundsation/g2p/utils/log.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import functools
8
+ import logging
9
+
10
+ __all__ = [
11
+ "logger",
12
+ ]
13
+
14
+
15
+ class Logger(object):
16
+ def __init__(self, name: str = None):
17
+ name = "PaddleSpeech" if not name else name
18
+ self.logger = logging.getLogger(name)
19
+
20
+ log_config = {
21
+ "DEBUG": 10,
22
+ "INFO": 20,
23
+ "TRAIN": 21,
24
+ "EVAL": 22,
25
+ "WARNING": 30,
26
+ "ERROR": 40,
27
+ "CRITICAL": 50,
28
+ "EXCEPTION": 100,
29
+ }
30
+ for key, level in log_config.items():
31
+ logging.addLevelName(level, key)
32
+ if key == "EXCEPTION":
33
+ self.__dict__[key.lower()] = self.logger.exception
34
+ else:
35
+ self.__dict__[key.lower()] = functools.partial(self.__call__, level)
36
+
37
+ self.format = logging.Formatter(
38
+ fmt="[%(asctime)-15s] [%(levelname)8s] - %(message)s"
39
+ )
40
+
41
+ self.handler = logging.StreamHandler()
42
+ self.handler.setFormatter(self.format)
43
+
44
+ self.logger.addHandler(self.handler)
45
+ self.logger.setLevel(logging.INFO)
46
+ self.logger.propagate = False
47
+
48
+ def __call__(self, log_level: str, msg: str):
49
+ self.logger.log(log_level, msg)
50
+
51
+
52
+ logger = Logger()
soundsation/g2p/utils/mls_en.json ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[UNK]": 0,
3
+ "_": 1,
4
+ "b": 2,
5
+ "d": 3,
6
+ "f": 4,
7
+ "h": 5,
8
+ "i": 6,
9
+ "j": 7,
10
+ "k": 8,
11
+ "l": 9,
12
+ "m": 10,
13
+ "n": 11,
14
+ "p": 12,
15
+ "r": 13,
16
+ "s": 14,
17
+ "t": 15,
18
+ "v": 16,
19
+ "w": 17,
20
+ "x": 18,
21
+ "z": 19,
22
+ "æ": 20,
23
+ "ç": 21,
24
+ "ð": 22,
25
+ "ŋ": 23,
26
+ "ɐ": 24,
27
+ "ɔ": 25,
28
+ "ə": 26,
29
+ "ɚ": 27,
30
+ "ɛ": 28,
31
+ "ɡ": 29,
32
+ "ɪ": 30,
33
+ "ɬ": 31,
34
+ "ɹ": 32,
35
+ "ɾ": 33,
36
+ "ʃ": 34,
37
+ "ʊ": 35,
38
+ "ʌ": 36,
39
+ "ʒ": 37,
40
+ "ʔ": 38,
41
+ "θ": 39,
42
+ "ᵻ": 40,
43
+ "aɪ": 41,
44
+ "aʊ": 42,
45
+ "dʒ": 43,
46
+ "eɪ": 44,
47
+ "iə": 45,
48
+ "iː": 46,
49
+ "n̩": 47,
50
+ "oʊ": 48,
51
+ "oː": 49,
52
+ "tʃ": 50,
53
+ "uː": 51,
54
+ "ææ": 52,
55
+ "ɐɐ": 53,
56
+ "ɑː": 54,
57
+ "ɑ̃": 55,
58
+ "ɔɪ": 56,
59
+ "ɔː": 57,
60
+ "ɔ̃": 58,
61
+ "əl": 59,
62
+ "ɛɹ": 60,
63
+ "ɜː": 61,
64
+ "ɡʲ": 62,
65
+ "ɪɹ": 63,
66
+ "ʊɹ": 64,
67
+ "aɪə": 65,
68
+ "aɪɚ": 66,
69
+ "iːː": 67,
70
+ "oːɹ": 68,
71
+ "ɑːɹ": 69,
72
+ "ɔːɹ": 70,
73
+
74
+ "1": 71,
75
+ "a": 72,
76
+ "e": 73,
77
+ "o": 74,
78
+ "q": 75,
79
+ "u": 76,
80
+ "y": 77,
81
+ "ɑ": 78,
82
+ "ɒ": 79,
83
+ "ɕ": 80,
84
+ "ɣ": 81,
85
+ "ɫ": 82,
86
+ "ɯ": 83,
87
+ "ʐ": 84,
88
+ "ʲ": 85,
89
+ "a1": 86,
90
+ "a2": 87,
91
+ "a5": 88,
92
+ "ai": 89,
93
+ "aɜ": 90,
94
+ "aː": 91,
95
+ "ei": 92,
96
+ "eə": 93,
97
+ "i.": 94,
98
+ "i1": 95,
99
+ "i2": 96,
100
+ "i5": 97,
101
+ "io": 98,
102
+ "iɑ": 99,
103
+ "iɛ": 100,
104
+ "iɜ": 101,
105
+ "i̪": 102,
106
+ "kh": 103,
107
+ "nʲ": 104,
108
+ "o1": 105,
109
+ "o2": 106,
110
+ "o5": 107,
111
+ "ou": 108,
112
+ "oɜ": 109,
113
+ "ph": 110,
114
+ "s.": 111,
115
+ "th": 112,
116
+ "ts": 113,
117
+ "tɕ": 114,
118
+ "u1": 115,
119
+ "u2": 116,
120
+ "u5": 117,
121
+ "ua": 118,
122
+ "uo": 119,
123
+ "uə": 120,
124
+ "uɜ": 121,
125
+ "y1": 122,
126
+ "y2": 123,
127
+ "y5": 124,
128
+ "yu": 125,
129
+ "yæ": 126,
130
+ "yə": 127,
131
+ "yɛ": 128,
132
+ "yɜ": 129,
133
+ "ŋɜ": 130,
134
+ "ŋʲ": 131,
135
+ "ɑ1": 132,
136
+ "ɑ2": 133,
137
+ "ɑ5": 134,
138
+ "ɑu": 135,
139
+ "ɑɜ": 136,
140
+ "ɑʲ": 137,
141
+ "ə1": 138,
142
+ "ə2": 139,
143
+ "ə5": 140,
144
+ "ər": 141,
145
+ "əɜ": 142,
146
+ "əʊ": 143,
147
+ "ʊə": 144,
148
+ "ai1": 145,
149
+ "ai2": 146,
150
+ "ai5": 147,
151
+ "aiɜ": 148,
152
+ "ei1": 149,
153
+ "ei2": 150,
154
+ "ei5": 151,
155
+ "eiɜ": 152,
156
+ "i.1": 153,
157
+ "i.2": 154,
158
+ "i.5": 155,
159
+ "i.ɜ": 156,
160
+ "io5": 157,
161
+ "iou": 158,
162
+ "iɑ1": 159,
163
+ "iɑ2": 160,
164
+ "iɑ5": 161,
165
+ "iɑɜ": 162,
166
+ "iɛ1": 163,
167
+ "iɛ2": 164,
168
+ "iɛ5": 165,
169
+ "iɛɜ": 166,
170
+ "i̪1": 167,
171
+ "i̪2": 168,
172
+ "i̪5": 169,
173
+ "i̪ɜ": 170,
174
+ "onɡ": 171,
175
+ "ou1": 172,
176
+ "ou2": 173,
177
+ "ou5": 174,
178
+ "ouɜ": 175,
179
+ "ts.": 176,
180
+ "tsh": 177,
181
+ "tɕh": 178,
182
+ "u5ʲ": 179,
183
+ "ua1": 180,
184
+ "ua2": 181,
185
+ "ua5": 182,
186
+ "uai": 183,
187
+ "uaɜ": 184,
188
+ "uei": 185,
189
+ "uo1": 186,
190
+ "uo2": 187,
191
+ "uo5": 188,
192
+ "uoɜ": 189,
193
+ "uə1": 190,
194
+ "uə2": 191,
195
+ "uə5": 192,
196
+ "uəɜ": 193,
197
+ "yiɜ": 194,
198
+ "yu2": 195,
199
+ "yu5": 196,
200
+ "yæ2": 197,
201
+ "yæ5": 198,
202
+ "yæɜ": 199,
203
+ "yə2": 200,
204
+ "yə5": 201,
205
+ "yəɜ": 202,
206
+ "yɛ1": 203,
207
+ "yɛ2": 204,
208
+ "yɛ5": 205,
209
+ "yɛɜ": 206,
210
+ "ɑu1": 207,
211
+ "ɑu2": 208,
212
+ "ɑu5": 209,
213
+ "ɑuɜ": 210,
214
+ "ər1": 211,
215
+ "ər2": 212,
216
+ "ər5": 213,
217
+ "ərɜ": 214,
218
+ "əː1": 215,
219
+ "iou1": 216,
220
+ "iou2": 217,
221
+ "iou5": 218,
222
+ "iouɜ": 219,
223
+ "onɡ1": 220,
224
+ "onɡ2": 221,
225
+ "onɡ5": 222,
226
+ "onɡɜ": 223,
227
+ "ts.h": 224,
228
+ "uai2": 225,
229
+ "uai5": 226,
230
+ "uaiɜ": 227,
231
+ "uei1": 228,
232
+ "uei2": 229,
233
+ "uei5": 230,
234
+ "ueiɜ": 231,
235
+ "uoɜʲ": 232,
236
+ "yɛ5ʲ": 233,
237
+ "ɑu2ʲ": 234,
238
+
239
+ "2": 235,
240
+ "5": 236,
241
+ "ɜ": 237,
242
+ "ʂ": 238,
243
+ "dʑ": 239,
244
+ "iɪ": 240,
245
+ "uɪ": 241,
246
+ "xʲ": 242,
247
+ "ɑt": 243,
248
+ "ɛɜ": 244,
249
+ "ɛː": 245,
250
+ "ɪː": 246,
251
+ "phʲ": 247,
252
+ "ɑ5ʲ": 248,
253
+ "ɑuʲ": 249,
254
+ "ərə": 250,
255
+ "uozʰ": 251,
256
+ "ər1ʲ": 252,
257
+ "tɕhtɕh": 253,
258
+
259
+ "c": 254,
260
+ "ʋ": 255,
261
+ "ʍ": 256,
262
+ "ʑ": 257,
263
+ "ː": 258,
264
+ "aə": 259,
265
+ "eː": 260,
266
+ "hʲ": 261,
267
+ "iʊ": 262,
268
+ "kʲ": 263,
269
+ "lʲ": 264,
270
+ "oə": 265,
271
+ "oɪ": 266,
272
+ "oʲ": 267,
273
+ "pʲ": 268,
274
+ "sʲ": 269,
275
+ "u4": 270,
276
+ "uʲ": 271,
277
+ "yi": 272,
278
+ "yʲ": 273,
279
+ "ŋ2": 274,
280
+ "ŋ5": 275,
281
+ "ŋ̩": 276,
282
+ "ɑɪ": 277,
283
+ "ɑʊ": 278,
284
+ "ɕʲ": 279,
285
+ "ət": 280,
286
+ "əə": 281,
287
+ "əɪ": 282,
288
+ "əʲ": 283,
289
+ "ɛ1": 284,
290
+ "ɛ5": 285,
291
+ "aiə": 286,
292
+ "aiɪ": 287,
293
+ "azʰ": 288,
294
+ "eiə": 289,
295
+ "eiɪ": 290,
296
+ "eiʊ": 291,
297
+ "i.ə": 292,
298
+ "i.ɪ": 293,
299
+ "i.ʊ": 294,
300
+ "ioɜ": 295,
301
+ "izʰ": 296,
302
+ "iɑə": 297,
303
+ "iɑʊ": 298,
304
+ "iɑʲ": 299,
305
+ "iɛə": 300,
306
+ "iɛɪ": 301,
307
+ "iɛʊ": 302,
308
+ "i̪ə": 303,
309
+ "i̪ʊ": 304,
310
+ "khʲ": 305,
311
+ "ouʲ": 306,
312
+ "tsʲ": 307,
313
+ "u2ʲ": 308,
314
+ "uoɪ": 309,
315
+ "uzʰ": 310,
316
+ "uɜʲ": 311,
317
+ "yæɪ": 312,
318
+ "yəʊ": 313,
319
+ "ərt": 314,
320
+ "ərɪ": 315,
321
+ "ərʲ": 316,
322
+ "əːt": 317,
323
+ "iouə": 318,
324
+ "iouʊ": 319,
325
+ "iouʲ": 320,
326
+ "iɛzʰ": 321,
327
+ "onɡə": 322,
328
+ "onɡɪ": 323,
329
+ "onɡʊ": 324,
330
+ "ouzʰ": 325,
331
+ "uai1": 326,
332
+ "ueiɪ": 327,
333
+ "ɑuzʰ": 328,
334
+ "iouzʰ": 329
335
+ }
soundsation/infer/infer.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ from einops import rearrange
4
+ import argparse
5
+ import os
6
+ import time
7
+ import random
8
+
9
+ import torch
10
+ import torchaudio
11
+ import numpy as np
12
+ from einops import rearrange
13
+ import io
14
+ import pydub
15
+
16
+ from soundsation.infer.infer_utils import (
17
+ decode_audio,
18
+ get_lrc_token,
19
+ get_negative_style_prompt,
20
+ get_reference_latent,
21
+ get_style_prompt,
22
+ prepare_model,
23
+ eval_song,
24
+ )
25
+
26
+
27
+ def inference(
28
+ cfm_model,
29
+ vae_model,
30
+ eval_model,
31
+ eval_muq,
32
+ cond,
33
+ text,
34
+ duration,
35
+ style_prompt,
36
+ negative_style_prompt,
37
+ steps,
38
+ cfg_strength,
39
+ sway_sampling_coef,
40
+ start_time,
41
+ file_type,
42
+ vocal_flag,
43
+ odeint_method,
44
+ pred_frames,
45
+ batch_infer_num,
46
+ chunked=True,
47
+ ):
48
+ with torch.inference_mode():
49
+ latents, _ = cfm_model.sample(
50
+ cond=cond,
51
+ text=text,
52
+ duration=duration,
53
+ style_prompt=style_prompt,
54
+ negative_style_prompt=negative_style_prompt,
55
+ steps=steps,
56
+ cfg_strength=cfg_strength,
57
+ sway_sampling_coef=sway_sampling_coef,
58
+ start_time=start_time,
59
+ vocal_flag=vocal_flag,
60
+ odeint_method=odeint_method,
61
+ latent_pred_segments=pred_frames,
62
+ batch_infer_num=batch_infer_num
63
+ )
64
+
65
+ outputs = []
66
+ for latent in latents:
67
+ latent = latent.to(torch.float32)
68
+ latent = latent.transpose(1, 2) # [b d t]
69
+
70
+ output = decode_audio(latent, vae_model, chunked=chunked)
71
+
72
+ # Rearrange audio batch to a single sequence
73
+ output = rearrange(output, "b d n -> d (b n)")
74
+
75
+ outputs.append(output)
76
+ if batch_infer_num > 1:
77
+ generated_song = eval_song(eval_model, eval_muq, outputs)
78
+ else:
79
+ generated_song = outputs[0]
80
+ output_tensor = generated_song.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).cpu()
81
+ output_np = output_tensor.numpy().T.astype(np.float32)
82
+ if file_type == 'wav':
83
+ return (44100, output_np)
84
+ else:
85
+ buffer = io.BytesIO()
86
+ output_np = np.int16(output_np * 2**15)
87
+ song = pydub.AudioSegment(output_np.tobytes(), frame_rate=44100, sample_width=2, channels=2)
88
+ if file_type == 'mp3':
89
+ song.export(buffer, format="mp3", bitrate="320k")
90
+ else:
91
+ song.export(buffer, format="ogg", bitrate="320k")
92
+ return buffer.getvalue()
93
+
94
+
95
+
96
+ if __name__ == "__main__":
97
+ parser = argparse.ArgumentParser()
98
+ parser.add_argument(
99
+ "--lrc-path",
100
+ type=str,
101
+ help="lyrics of target song",
102
+ ) # lyrics of target song
103
+ parser.add_argument(
104
+ "--ref-prompt",
105
+ type=str,
106
+ help="reference prompt as style prompt for target song",
107
+ required=False,
108
+ ) # reference prompt as style prompt for target song
109
+ parser.add_argument(
110
+ "--ref-audio-path",
111
+ type=str,
112
+ help="reference audio as style prompt for target song",
113
+ required=False,
114
+ ) # reference audio as style prompt for target song
115
+ parser.add_argument(
116
+ "--chunked",
117
+ action="store_true",
118
+ help="whether to use chunked decoding",
119
+ ) # whether to use chunked decoding
120
+ parser.add_argument(
121
+ "--audio-length",
122
+ type=int,
123
+ default=95,
124
+ choices=[95, 285],
125
+ help="length of generated song",
126
+ ) # length of target song
127
+ parser.add_argument(
128
+ "--repo-id", type=str, default="josephchay/Soundsation-base", help="target model"
129
+ )
130
+ parser.add_argument(
131
+ "--output-dir",
132
+ type=str,
133
+ default="infer/example/output",
134
+ help="output directory fo generated song",
135
+ ) # output directory of target song
136
+ parser.add_argument(
137
+ "--edit",
138
+ action="store_true",
139
+ help="whether to open edit mode",
140
+ ) # edit flag
141
+ parser.add_argument(
142
+ "--ref-song",
143
+ type=str,
144
+ required=False,
145
+ help="reference prompt as latent prompt for editing",
146
+ ) # reference prompt as latent prompt for editing
147
+ parser.add_argument(
148
+ "--edit-segments",
149
+ type=str,
150
+ required=False,
151
+ help="edit segments o target song",
152
+ ) # edit segments o target song
153
+ args = parser.parse_args()
154
+
155
+ assert (
156
+ args.ref_prompt or args.ref_audio_path
157
+ ), "either ref_prompt or ref_audio_path should be provided"
158
+ assert not (
159
+ args.ref_prompt and args.ref_audio_path
160
+ ), "only one of them should be provided"
161
+ if args.edit:
162
+ assert (
163
+ args.ref_song and args.edit_segments
164
+ ), "reference song and edit segments should be provided for editing"
165
+
166
+ device = "cpu"
167
+ if torch.cuda.is_available():
168
+ device = "cuda"
169
+ elif torch.mps.is_available():
170
+ device = "mps"
171
+
172
+ audio_length = args.audio_length
173
+ if audio_length == 95:
174
+ max_frames = 2048
175
+ elif audio_length == 285:
176
+ max_frames = 6144
177
+
178
+ cfm, tokenizer, muq, vae, eval_model, eval_muq = prepare_model(max_frames, device, repo_id=args.repo_id)
179
+
180
+ if args.lrc_path:
181
+ with open(args.lrc_path, "r", encoding='utf-8') as f:
182
+ lrc = f.read()
183
+ else:
184
+ lrc = ""
185
+ lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
186
+
187
+ if args.ref_audio_path:
188
+ style_prompt = get_style_prompt(muq, args.ref_audio_path)
189
+ else:
190
+ style_prompt = get_style_prompt(muq, prompt=args.ref_prompt)
191
+
192
+ negative_style_prompt = get_negative_style_prompt(device)
193
+
194
+ latent_prompt, pred_frames = get_reference_latent(device, max_frames, args.edit, args.edit_segments, args.ref_song, vae)
195
+
196
+ s_t = time.time()
197
+ generated_songs = inference(
198
+ cfm_model=cfm,
199
+ vae_model=vae,
200
+ cond=latent_prompt,
201
+ text=lrc_prompt,
202
+ duration=max_frames,
203
+ style_prompt=style_prompt,
204
+ negative_style_prompt=negative_style_prompt,
205
+ start_time=start_time,
206
+ pred_frames=pred_frames,
207
+ chunked=args.chunked,
208
+ )
209
+
210
+
211
+
212
+ generated_song = eval_song(eval_model, eval_muq, generated_songs)
213
+
214
+ # Peak normalize, clip, convert to int16, and save to file
215
+ generated_song = (
216
+ generated_song.to(torch.float32)
217
+ .div(torch.max(torch.abs(generated_song)))
218
+ .clamp(-1, 1)
219
+ .mul(32767)
220
+ .to(torch.int16)
221
+ .cpu()
222
+ )
223
+ e_t = time.time() - s_t
224
+ print(f"inference cost {e_t:.2f} seconds")
225
+ output_dir = args.output_dir
226
+ os.makedirs(output_dir, exist_ok=True)
227
+
228
+ output_path = os.path.join(output_dir, "output.wav")
229
+ torchaudio.save(output_path, generated_song, sample_rate=44100)
soundsation/infer/infer_utils.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import librosa
3
+ import torchaudio
4
+ import random
5
+ import json
6
+ from muq import MuQMuLan, MuQ
7
+ from mutagen.mp3 import MP3
8
+ import os
9
+ import numpy as np
10
+ from huggingface_hub import hf_hub_download
11
+ from hydra.utils import instantiate
12
+ from omegaconf import OmegaConf
13
+ from safetensors.torch import load_file
14
+
15
+ from soundsation.model import DiT, CFM
16
+
17
+ def vae_sample(mean, scale):
18
+ stdev = torch.nn.functional.softplus(scale) + 1e-4
19
+ var = stdev * stdev
20
+ logvar = torch.log(var)
21
+ latents = torch.randn_like(mean) * stdev + mean
22
+
23
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
24
+
25
+ return latents, kl
26
+
27
+ def normalize_audio(y, target_dbfs=0):
28
+ max_amplitude = torch.max(torch.abs(y))
29
+
30
+ target_amplitude = 10.0**(target_dbfs / 20.0)
31
+ scale_factor = target_amplitude / max_amplitude
32
+
33
+ normalized_audio = y * scale_factor
34
+
35
+ return normalized_audio
36
+
37
+ def set_audio_channels(audio, target_channels):
38
+ if target_channels == 1:
39
+ # Convert to mono
40
+ audio = audio.mean(1, keepdim=True)
41
+ elif target_channels == 2:
42
+ # Convert to stereo
43
+ if audio.shape[1] == 1:
44
+ audio = audio.repeat(1, 2, 1)
45
+ elif audio.shape[1] > 2:
46
+ audio = audio[:, :2, :]
47
+ return audio
48
+
49
+ class PadCrop(torch.nn.Module):
50
+ def __init__(self, n_samples, randomize=True):
51
+ super().__init__()
52
+ self.n_samples = n_samples
53
+ self.randomize = randomize
54
+
55
+ def __call__(self, signal):
56
+ n, s = signal.shape
57
+ start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
58
+ end = start + self.n_samples
59
+ output = signal.new_zeros([n, self.n_samples])
60
+ output[:, :min(s, self.n_samples)] = signal[:, start:end]
61
+ return output
62
+
63
+ def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
64
+
65
+ audio = audio.to(device)
66
+
67
+ if in_sr != target_sr:
68
+ resample_tf = torchaudio.transforms.Resample(in_sr, target_sr).to(device)
69
+ audio = resample_tf(audio)
70
+ if target_length is None:
71
+ target_length = audio.shape[-1]
72
+ audio = PadCrop(target_length, randomize=False)(audio)
73
+
74
+ # Add batch dimension
75
+ if audio.dim() == 1:
76
+ audio = audio.unsqueeze(0).unsqueeze(0)
77
+ elif audio.dim() == 2:
78
+ audio = audio.unsqueeze(0)
79
+
80
+ audio = set_audio_channels(audio, target_channels)
81
+
82
+ return audio
83
+
84
+ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
85
+ downsampling_ratio = 2048
86
+ io_channels = 2
87
+ if not chunked:
88
+ return vae_model.decode_export(latents)
89
+ else:
90
+ # chunked decoding
91
+ hop_size = chunk_size - overlap
92
+ total_size = latents.shape[2]
93
+ batch_size = latents.shape[0]
94
+ chunks = []
95
+ i = 0
96
+ for i in range(0, total_size - chunk_size + 1, hop_size):
97
+ chunk = latents[:, :, i : i + chunk_size]
98
+ chunks.append(chunk)
99
+ if i + chunk_size != total_size:
100
+ # Final chunk
101
+ chunk = latents[:, :, -chunk_size:]
102
+ chunks.append(chunk)
103
+ chunks = torch.stack(chunks)
104
+ num_chunks = chunks.shape[0]
105
+ # samples_per_latent is just the downsampling ratio
106
+ samples_per_latent = downsampling_ratio
107
+ # Create an empty waveform, we will populate it with chunks as decode them
108
+ y_size = total_size * samples_per_latent
109
+ y_final = torch.zeros((batch_size, io_channels, y_size)).to(latents.device)
110
+ for i in range(num_chunks):
111
+ x_chunk = chunks[i, :]
112
+ # decode the chunk
113
+ y_chunk = vae_model.decode_export(x_chunk)
114
+ # figure out where to put the audio along the time domain
115
+ if i == num_chunks - 1:
116
+ # final chunk always goes at the end
117
+ t_end = y_size
118
+ t_start = t_end - y_chunk.shape[2]
119
+ else:
120
+ t_start = i * hop_size * samples_per_latent
121
+ t_end = t_start + chunk_size * samples_per_latent
122
+ # remove the edges of the overlaps
123
+ ol = (overlap // 2) * samples_per_latent
124
+ chunk_start = 0
125
+ chunk_end = y_chunk.shape[2]
126
+ if i > 0:
127
+ # no overlap for the start of the first chunk
128
+ t_start += ol
129
+ chunk_start += ol
130
+ if i < num_chunks - 1:
131
+ # no overlap for the end of the last chunk
132
+ t_end -= ol
133
+ chunk_end -= ol
134
+ # paste the chunked audio into our y_final output audio
135
+ y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
136
+ return y_final
137
+
138
+ def encode_audio(audio, vae_model, chunked=False, overlap=32, chunk_size=128):
139
+ downsampling_ratio = 2048
140
+ latent_dim = 128
141
+ if not chunked:
142
+ # default behavior. Encode the entire audio in parallel
143
+ return vae_model.encode_export(audio)
144
+ else:
145
+ # CHUNKED ENCODING
146
+ # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
147
+ samples_per_latent = downsampling_ratio
148
+ total_size = audio.shape[2] # in samples
149
+ batch_size = audio.shape[0]
150
+ chunk_size *= samples_per_latent # converting metric in latents to samples
151
+ overlap *= samples_per_latent # converting metric in latents to samples
152
+ hop_size = chunk_size - overlap
153
+ chunks = []
154
+ for i in range(0, total_size - chunk_size + 1, hop_size):
155
+ chunk = audio[:,:,i:i+chunk_size]
156
+ chunks.append(chunk)
157
+ if i+chunk_size != total_size:
158
+ # Final chunk
159
+ chunk = audio[:,:,-chunk_size:]
160
+ chunks.append(chunk)
161
+ chunks = torch.stack(chunks)
162
+ num_chunks = chunks.shape[0]
163
+ # Note: y_size might be a different value from the latent length used in diffusion training
164
+ # because we can encode audio of varying lengths
165
+ # However, the audio should've been padded to a multiple of samples_per_latent by now.
166
+ y_size = total_size // samples_per_latent
167
+ # Create an empty latent, we will populate it with chunks as we encode them
168
+ y_final = torch.zeros((batch_size,latent_dim,y_size)).to(audio.device)
169
+ for i in range(num_chunks):
170
+ x_chunk = chunks[i,:]
171
+ # encode the chunk
172
+ y_chunk = vae_model.encode_export(x_chunk)
173
+ # figure out where to put the audio along the time domain
174
+ if i == num_chunks-1:
175
+ # final chunk always goes at the end
176
+ t_end = y_size
177
+ t_start = t_end - y_chunk.shape[2]
178
+ else:
179
+ t_start = i * hop_size // samples_per_latent
180
+ t_end = t_start + chunk_size // samples_per_latent
181
+ # remove the edges of the overlaps
182
+ ol = overlap//samples_per_latent//2
183
+ chunk_start = 0
184
+ chunk_end = y_chunk.shape[2]
185
+ if i > 0:
186
+ # no overlap for the start of the first chunk
187
+ t_start += ol
188
+ chunk_start += ol
189
+ if i < num_chunks-1:
190
+ # no overlap for the end of the last chunk
191
+ t_end -= ol
192
+ chunk_end -= ol
193
+ # paste the chunked audio into our y_final output audio
194
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
195
+ return y_final
196
+
197
+ def prepare_model(device):
198
+ # prepare cfm model
199
+
200
+ dit_ckpt_path = hf_hub_download(repo_id="josephchay/Soundsation", filename="cfm_model.pt")
201
+ dit_config_path = "./soundsation/config/config.json"
202
+ with open(dit_config_path) as f:
203
+ model_config = json.load(f)
204
+ dit_model_cls = DiT
205
+ cfm = CFM(
206
+ transformer=dit_model_cls(**model_config["model"], max_frames=2048),
207
+ num_channels=model_config["model"]['mel_dim'],
208
+ )
209
+ cfm = cfm.to(device)
210
+ cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False)
211
+
212
+ # prepare tokenizer
213
+ tokenizer = CNENTokenizer()
214
+
215
+ # prepare muq
216
+ muq = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./pretrained")
217
+ muq = muq.to(device).eval()
218
+
219
+ # prepare vae
220
+ vae_ckpt_path = hf_hub_download(repo_id="josephchay/Soundsation-vae", filename="vae_model.pt")
221
+ vae = torch.jit.load(vae_ckpt_path, map_location="cpu").to(device)
222
+
223
+
224
+ # prepare eval model
225
+ train_config = OmegaConf.load("./pretrained/eval.yaml")
226
+ checkpoint_path = "./pretrained/eval.safetensors"
227
+
228
+ eval_model = instantiate(train_config.generator).to(device).eval()
229
+ state_dict = load_file(checkpoint_path, device="cpu")
230
+ eval_model.load_state_dict(state_dict)
231
+
232
+ eval_muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
233
+ eval_muq = eval_muq.to(device).eval()
234
+
235
+ return cfm, tokenizer, muq, vae, eval_model, eval_muq
236
+
237
+
238
+ # for song edit, will be added in the future
239
+ def get_reference_latent(device, max_frames, edit, pred_segments, ref_song, vae_model):
240
+ sampling_rate = 44100
241
+ downsample_rate = 2048
242
+ io_channels = 2
243
+ if edit:
244
+ input_audio, in_sr = torchaudio.load(ref_song)
245
+ input_audio = prepare_audio(input_audio, in_sr=in_sr, target_sr=sampling_rate, target_length=None, target_channels=io_channels, device=device)
246
+ input_audio = normalize_audio(input_audio, -6)
247
+
248
+ with torch.no_grad():
249
+ latent = encode_audio(input_audio, vae_model, chunked=True) # [b d t]
250
+ mean, scale = latent.chunk(2, dim=1)
251
+ prompt, _ = vae_sample(mean, scale)
252
+ prompt = prompt.transpose(1, 2) # [b t d]
253
+ prompt = prompt[:,:max_frames,:] if prompt.shape[1] >= max_frames else torch.nn.functional.pad(prompt, (0, 0, 0, max_frames - prompt.shape[1]), mode="constant", value=0)
254
+
255
+ pred_segments = json.loads(pred_segments)
256
+ # import pdb; pdb.set_trace()
257
+ pred_frames = []
258
+ for st, et in pred_segments:
259
+ sf = 0 if st == -1 else int(st * sampling_rate / downsample_rate)
260
+ # if st == -1:
261
+ # sf = 0
262
+ # else:
263
+ # sf = int(st * sampling_rate / downsample_rate )
264
+
265
+ ef = max_frames if et == -1 else int(et * sampling_rate / downsample_rate)
266
+ # if et == -1:
267
+ # ef = max_frames
268
+ # else:
269
+ # ef = int(et * sampling_rate / downsample_rate )
270
+ pred_frames.append((sf, ef))
271
+ # import pdb; pdb.set_trace()
272
+ return prompt, pred_frames
273
+ else:
274
+ prompt = torch.zeros(1, max_frames, 64).to(device)
275
+ pred_frames = [(0, max_frames)]
276
+ return prompt, pred_frames
277
+
278
+
279
+ def get_negative_style_prompt(device):
280
+ file_path = "./src/negative_prompt.npy"
281
+ vocal_stlye = np.load(file_path)
282
+
283
+ vocal_stlye = torch.from_numpy(vocal_stlye).to(device) # [1, 512]
284
+ vocal_stlye = vocal_stlye.half()
285
+
286
+ return vocal_stlye
287
+
288
+ @torch.no_grad()
289
+ def eval_song(eval_model, eval_muq, songs):
290
+
291
+ resampled_songs = [torchaudio.functional.resample(song.mean(dim=0, keepdim=True), 44100, 24000) for song in songs]
292
+ ssl_list = []
293
+ for i in range(len(resampled_songs)):
294
+ output = eval_muq(resampled_songs[i], output_hidden_states=True)
295
+ muq_ssl = output["hidden_states"][6]
296
+ ssl_list.append(muq_ssl.squeeze(0))
297
+
298
+ ssl = torch.stack(ssl_list)
299
+ scores_g = eval_model(ssl)
300
+ score = torch.mean(scores_g, dim=1)
301
+ idx = score.argmax(dim=0)
302
+
303
+ return songs[idx]
304
+
305
+
306
+ @torch.no_grad()
307
+ def get_audio_style_prompt(model, wav_path):
308
+ vocal_flag = False
309
+ mulan = model
310
+ audio, _ = librosa.load(wav_path, sr=24000)
311
+ audio_len = librosa.get_duration(y=audio, sr=24000)
312
+
313
+ if audio_len <= 1:
314
+ vocal_flag = True
315
+
316
+ if audio_len > 10:
317
+ start_time = int(audio_len // 2 - 5)
318
+ wav = audio[start_time*24000:(start_time+10)*24000]
319
+
320
+ else:
321
+ wav = audio
322
+ wav = torch.tensor(wav).unsqueeze(0).to(model.device)
323
+
324
+ with torch.no_grad():
325
+ audio_emb = mulan(wavs = wav) # [1, 512]
326
+
327
+ audio_emb = audio_emb.half()
328
+
329
+ return audio_emb, vocal_flag
330
+
331
+
332
+ @torch.no_grad()
333
+ def get_text_style_prompt(model, text_prompt):
334
+ mulan = model
335
+
336
+ with torch.no_grad():
337
+ text_emb = mulan(texts = text_prompt) # [1, 512]
338
+ text_emb = text_emb.half()
339
+
340
+ return text_emb
341
+
342
+
343
+ @torch.no_grad()
344
+ def get_style_prompt(model, wav_path=None, prompt=None):
345
+ mulan = model
346
+
347
+ if prompt is not None:
348
+ return mulan(texts=prompt).half()
349
+
350
+ ext = os.path.splitext(wav_path)[-1].lower()
351
+ if ext == ".mp3":
352
+ meta = MP3(wav_path)
353
+ audio_len = meta.info.length
354
+ elif ext in [".wav", ".flac"]:
355
+ audio_len = librosa.get_duration(path=wav_path)
356
+ else:
357
+ raise ValueError("Unsupported file format: {}".format(ext))
358
+
359
+ if audio_len < 10:
360
+ print(
361
+ f"Warning: The audio file {wav_path} is too short ({audio_len:.2f} seconds). Expected at least 10 seconds."
362
+ )
363
+
364
+ assert audio_len >= 10
365
+
366
+ mid_time = audio_len // 2
367
+ start_time = mid_time - 5
368
+ wav, _ = librosa.load(wav_path, sr=24000, offset=start_time, duration=10)
369
+
370
+ wav = torch.tensor(wav).unsqueeze(0).to(model.device)
371
+
372
+ with torch.no_grad():
373
+ audio_emb = mulan(wavs=wav) # [1, 512]
374
+
375
+ audio_emb = audio_emb
376
+ audio_emb = audio_emb.half()
377
+
378
+ return audio_emb
379
+
380
+ def parse_lyrics(lyrics: str):
381
+ lyrics_with_time = []
382
+ lyrics = lyrics.strip()
383
+ for line in lyrics.split("\n"):
384
+ try:
385
+ time, lyric = line[1:9], line[10:]
386
+ lyric = lyric.strip()
387
+ mins, secs = time.split(":")
388
+ secs = int(mins) * 60 + float(secs)
389
+ lyrics_with_time.append((secs, lyric))
390
+ except:
391
+ continue
392
+ return lyrics_with_time
393
+
394
+
395
+ class CNENTokenizer:
396
+ def __init__(self):
397
+ with open("./soundsation/g2p/g2p/vocab.json", "r", encoding='utf-8') as file:
398
+ self.phone2id: dict = json.load(file)["vocab"]
399
+ self.id2phone = {v: k for (k, v) in self.phone2id.items()}
400
+ from soundsation.g2p.g2p_generation import chn_eng_g2p
401
+
402
+ self.tokenizer = chn_eng_g2p
403
+
404
+ def encode(self, text):
405
+ phone, token = self.tokenizer(text)
406
+ token = [x + 1 for x in token]
407
+ return token
408
+
409
+ def decode(self, token):
410
+ return "|".join([self.id2phone[x - 1] for x in token])
411
+
412
+
413
+ def get_lrc_token(max_frames, text, tokenizer, device):
414
+
415
+ lyrics_shift = 0
416
+ sampling_rate = 44100
417
+ downsample_rate = 2048
418
+ max_secs = max_frames / (sampling_rate / downsample_rate)
419
+
420
+ comma_token_id = 1
421
+ period_token_id = 2
422
+
423
+ lrc_with_time = parse_lyrics(text)
424
+
425
+ modified_lrc_with_time = []
426
+ for i in range(len(lrc_with_time)):
427
+ time, line = lrc_with_time[i]
428
+ line_token = tokenizer.encode(line)
429
+ modified_lrc_with_time.append((time, line_token))
430
+ lrc_with_time = modified_lrc_with_time
431
+
432
+ lrc_with_time = [
433
+ (time_start, line)
434
+ for (time_start, line) in lrc_with_time
435
+ if time_start < max_secs
436
+ ]
437
+ if max_frames == 2048:
438
+ lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time
439
+
440
+ normalized_start_time = 0.0
441
+
442
+ lrc = torch.zeros((max_frames,), dtype=torch.long)
443
+
444
+ tokens_count = 0
445
+ last_end_pos = 0
446
+ for time_start, line in lrc_with_time:
447
+ tokens = [
448
+ token if token != period_token_id else comma_token_id for token in line
449
+ ] + [period_token_id]
450
+ tokens = torch.tensor(tokens, dtype=torch.long)
451
+ num_tokens = tokens.shape[0]
452
+
453
+ gt_frame_start = int(time_start * sampling_rate / downsample_rate)
454
+
455
+ frame_shift = random.randint(int(-lyrics_shift), int(lyrics_shift))
456
+
457
+ frame_start = max(gt_frame_start - frame_shift, last_end_pos)
458
+ frame_len = min(num_tokens, max_frames - frame_start)
459
+
460
+ lrc[frame_start : frame_start + frame_len] = tokens[:frame_len]
461
+
462
+ tokens_count += num_tokens
463
+ last_end_pos = frame_start + frame_len
464
+
465
+ lrc_emb = lrc.unsqueeze(0).to(device)
466
+
467
+ normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device)
468
+ normalized_start_time = normalized_start_time.half()
469
+
470
+ return lrc_emb, normalized_start_time
471
+
472
+
473
+ def load_checkpoint(model, ckpt_path, device, use_ema=True):
474
+ model = model.half()
475
+
476
+ ckpt_type = ckpt_path.split(".")[-1]
477
+ if ckpt_type == "safetensors":
478
+ from safetensors.torch import load_file
479
+
480
+ checkpoint = load_file(ckpt_path)
481
+ else:
482
+ checkpoint = torch.load(ckpt_path, weights_only=True)
483
+
484
+ if use_ema:
485
+ if ckpt_type == "safetensors":
486
+ checkpoint = {"ema_model_state_dict": checkpoint}
487
+ checkpoint["model_state_dict"] = {
488
+ k.replace("ema_model.", ""): v
489
+ for k, v in checkpoint["ema_model_state_dict"].items()
490
+ if k not in ["initted", "step"]
491
+ }
492
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
493
+ else:
494
+ if ckpt_type == "safetensors":
495
+ checkpoint = {"model_state_dict": checkpoint}
496
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
497
+
498
+ return model.to(device)
soundsation/model/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from soundsation.model.cfm import CFM
2
+
3
+ from soundsation.model.dit import DiT
4
+
5
+
6
+ __all__ = ["CFM"]
soundsation/model/cfm.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ASLP-LAB
2
+ # 2025 Ziqian Ning ([email protected])
3
+ # 2025 Huakang Chen ([email protected])
4
+ # 2025 Guobin Ma ([email protected])
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """ This implementation is adapted from github repo:
19
+ https://github.com/SWivid/F5-TTS.
20
+ """
21
+
22
+ from __future__ import annotations
23
+ from typing import Callable
24
+ from random import random
25
+
26
+ import torch
27
+ from torch import nn
28
+ import torch
29
+ import torch.nn.functional as F
30
+ from torch.nn.utils.rnn import pad_sequence
31
+
32
+ from torchdiffeq import odeint
33
+
34
+ from soundsation.model.utils import (
35
+ exists,
36
+ list_str_to_idx,
37
+ list_str_to_tensor,
38
+ lens_to_mask,
39
+ mask_from_frac_lengths,
40
+ )
41
+
42
+ def custom_mask_from_start_end_indices(
43
+ seq_len: int["b"], # noqa: F821
44
+ latent_pred_segments,
45
+ device,
46
+ max_seq_len
47
+ ):
48
+ max_seq_len = max_seq_len
49
+ seq = torch.arange(max_seq_len, device=device).long()
50
+
51
+ res_mask = torch.zeros(max_seq_len, device=device, dtype=torch.bool)
52
+
53
+ for start, end in latent_pred_segments:
54
+ start = start.unsqueeze(0)
55
+ end = end.unsqueeze(0)
56
+ start_mask = seq[None, :] >= start[:, None]
57
+ end_mask = seq[None, :] < end[:, None]
58
+ res_mask = res_mask | (start_mask & end_mask)
59
+
60
+ return res_mask
61
+
62
+ class CFM(nn.Module):
63
+ def __init__(
64
+ self,
65
+ transformer: nn.Module,
66
+ sigma=0.0,
67
+ odeint_kwargs: dict = dict(
68
+ method="euler"
69
+ ),
70
+ odeint_options: dict = dict(
71
+ min_step=0.05
72
+ ),
73
+ audio_drop_prob=0.3,
74
+ cond_drop_prob=0.2,
75
+ style_drop_prob=0.1,
76
+ lrc_drop_prob=0.1,
77
+ num_channels=None,
78
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
79
+ vocab_char_map: dict[str:int] | None = None,
80
+ max_frames=2048
81
+ ):
82
+ super().__init__()
83
+
84
+ self.frac_lengths_mask = frac_lengths_mask
85
+
86
+ self.num_channels = num_channels
87
+
88
+ # classifier-free guidance
89
+ self.audio_drop_prob = audio_drop_prob
90
+ self.cond_drop_prob = cond_drop_prob
91
+ self.style_drop_prob = style_drop_prob
92
+ self.lrc_drop_prob = lrc_drop_prob
93
+
94
+ # transformer
95
+ self.transformer = transformer
96
+ dim = transformer.dim
97
+ self.dim = dim
98
+
99
+ # conditional flow related
100
+ self.sigma = sigma
101
+
102
+ # sampling related
103
+ self.odeint_kwargs = odeint_kwargs
104
+
105
+ self.odeint_options = odeint_options
106
+
107
+ # vocab map for tokenization
108
+ self.vocab_char_map = vocab_char_map
109
+
110
+ self.max_frames = max_frames
111
+
112
+ @property
113
+ def device(self):
114
+ return next(self.parameters()).device
115
+
116
+ @torch.no_grad()
117
+ def sample(
118
+ self,
119
+ cond: float["b n d"] | float["b nw"], # noqa: F722
120
+ text: int["b nt"] | list[str], # noqa: F722
121
+ duration: int | int["b"], # noqa: F821
122
+ *,
123
+ style_prompt = None,
124
+ style_prompt_lens = None,
125
+ negative_style_prompt = None,
126
+ lens: int["b"] | None = None, # noqa: F821
127
+ steps=32,
128
+ cfg_strength=4.0,
129
+ sway_sampling_coef=None,
130
+ seed: int | None = None,
131
+ max_duration=6144,
132
+ vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
133
+ no_ref_audio=False,
134
+ duplicate_test=False,
135
+ t_inter=0.1,
136
+ edit_mask=None,
137
+ start_time=None,
138
+ latent_pred_segments=None,
139
+ vocal_flag=False,
140
+ odeint_method="euler",
141
+ batch_infer_num=5
142
+ ):
143
+ self.eval()
144
+
145
+ self.odeint_kwargs = dict(method=odeint_method)
146
+
147
+ if next(self.parameters()).dtype == torch.float16:
148
+ cond = cond.half()
149
+
150
+ # raw wave
151
+ if cond.shape[1] > duration:
152
+ cond = cond[:, :duration, :]
153
+
154
+ if cond.ndim == 2:
155
+ cond = self.mel_spec(cond)
156
+ cond = cond.permute(0, 2, 1)
157
+ assert cond.shape[-1] == self.num_channels
158
+
159
+ batch, cond_seq_len, device = *cond.shape[:2], cond.device
160
+ if not exists(lens):
161
+ lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
162
+
163
+ # text
164
+ if isinstance(text, list):
165
+ if exists(self.vocab_char_map):
166
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
167
+ else:
168
+ text = list_str_to_tensor(text).to(device)
169
+ assert text.shape[0] == batch
170
+
171
+ # duration
172
+ cond_mask = lens_to_mask(lens)
173
+ if edit_mask is not None:
174
+ cond_mask = cond_mask & edit_mask
175
+
176
+ latent_pred_segments = torch.tensor(latent_pred_segments).to(cond.device)
177
+ fixed_span_mask = custom_mask_from_start_end_indices(cond_seq_len, latent_pred_segments, device=cond.device, max_seq_len=duration)
178
+ fixed_span_mask = fixed_span_mask.unsqueeze(-1)
179
+ step_cond = torch.where(fixed_span_mask, torch.zeros_like(cond), cond)
180
+
181
+ if isinstance(duration, int):
182
+ duration = torch.full((batch_infer_num,), duration, device=device, dtype=torch.long)
183
+
184
+ duration = duration.clamp(max=max_duration)
185
+ max_duration = duration.amax()
186
+
187
+ # duplicate test corner for inner time step oberservation
188
+ if duplicate_test:
189
+ test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
190
+
191
+ if batch > 1:
192
+ mask = lens_to_mask(duration)
193
+ else: # save memory and speed up, as single inference need no mask currently
194
+ mask = None
195
+
196
+ # test for no ref audio
197
+ if no_ref_audio:
198
+ cond = torch.zeros_like(cond)
199
+
200
+ if vocal_flag:
201
+ style_prompt = negative_style_prompt
202
+ negative_style_prompt = torch.zeros_like(style_prompt)
203
+
204
+ cond = cond.repeat(batch_infer_num, 1, 1)
205
+ step_cond = step_cond.repeat(batch_infer_num, 1, 1)
206
+ text = text.repeat(batch_infer_num, 1)
207
+ style_prompt = style_prompt.repeat(batch_infer_num, 1)
208
+ negative_style_prompt = negative_style_prompt.repeat(batch_infer_num, 1)
209
+ start_time = start_time.repeat(batch_infer_num)
210
+ fixed_span_mask = fixed_span_mask.repeat(batch_infer_num, 1, 1)
211
+
212
+ def fn(t, x):
213
+ # predict flow
214
+ pred = self.transformer(
215
+ x=x, cond=step_cond, text=text, time=t, drop_audio_cond=False, drop_text=False, drop_prompt=False,
216
+ style_prompt=style_prompt, start_time=start_time
217
+ )
218
+ if cfg_strength < 1e-5:
219
+ return pred
220
+
221
+ null_pred = self.transformer(
222
+ x=x, cond=step_cond, text=text, time=t, drop_audio_cond=True, drop_text=True, drop_prompt=False,
223
+ style_prompt=negative_style_prompt, start_time=start_time
224
+ )
225
+ return pred + (pred - null_pred) * cfg_strength
226
+
227
+ # noise input
228
+ # to make sure batch inference result is same with different batch size, and for sure single inference
229
+ # still some difference maybe due to convolutional layers
230
+ y0 = []
231
+ for dur in duration:
232
+ if exists(seed):
233
+ torch.manual_seed(seed)
234
+ y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
235
+ y0 = pad_sequence(y0, padding_value=0, batch_first=True)
236
+
237
+ t_start = 0
238
+
239
+ # duplicate test corner for inner time step oberservation
240
+ if duplicate_test:
241
+ t_start = t_inter
242
+ y0 = (1 - t_start) * y0 + t_start * test_cond
243
+ steps = int(steps * (1 - t_start))
244
+
245
+ t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
246
+ if sway_sampling_coef is not None:
247
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
248
+
249
+ trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
250
+
251
+ sampled = trajectory[-1]
252
+ out = sampled
253
+ out = torch.where(fixed_span_mask, out, cond)
254
+
255
+ if exists(vocoder):
256
+ out = out.permute(0, 2, 1)
257
+ out = vocoder(out)
258
+
259
+ out = torch.chunk(out, batch_infer_num, dim=0)
260
+ return out, trajectory
261
+
262
+ def forward(
263
+ self,
264
+ inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
265
+ text: int["b nt"] | list[str], # noqa: F722
266
+ style_prompt = None,
267
+ style_prompt_lens = None,
268
+ lens: int["b"] | None = None, # noqa: F821
269
+ noise_scheduler: str | None = None,
270
+ grad_ckpt = False,
271
+ start_time = None,
272
+ ):
273
+
274
+ batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
275
+
276
+ # lens and mask
277
+ if not exists(lens):
278
+ lens = torch.full((batch,), seq_len, device=device)
279
+
280
+ mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
281
+
282
+ # get a random span to mask out for training conditionally
283
+ frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
284
+ rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, self.max_frames)
285
+
286
+ if exists(mask):
287
+ rand_span_mask = mask
288
+
289
+ # mel is x1
290
+ x1 = inp
291
+
292
+ # x0 is gaussian noise
293
+ x0 = torch.randn_like(x1)
294
+
295
+ # time step
296
+ time = torch.normal(mean=0, std=1, size=(batch,), device=self.device)
297
+ time = torch.nn.functional.sigmoid(time)
298
+ # TODO. noise_scheduler
299
+
300
+ # sample xt (φ_t(x) in the paper)
301
+ t = time.unsqueeze(-1).unsqueeze(-1)
302
+ φ = (1 - t) * x0 + t * x1
303
+ flow = x1 - x0
304
+
305
+ # only predict what is within the random mask span for infilling
306
+ cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
307
+
308
+ # transformer and cfg training with a drop rate
309
+ drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
310
+ drop_text = random() < self.lrc_drop_prob
311
+ drop_prompt = random() < self.style_drop_prob
312
+
313
+ # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
314
+ # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
315
+ pred = self.transformer(
316
+ x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, drop_prompt=drop_prompt,
317
+ style_prompt=style_prompt, start_time=start_time
318
+ )
319
+
320
+ # flow matching loss
321
+ loss = F.mse_loss(pred, flow, reduction="none")
322
+ loss = loss[rand_span_mask]
323
+
324
+ return loss.mean(), cond, pred
soundsation/model/dit.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ASLP-LAB
2
+ # 2025 Ziqian Ning ([email protected])
3
+ # 2025 Huakang Chen ([email protected])
4
+ # 2025 Yuepeng Jiang ([email protected])
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """ This implementation is adapted from github repo:
19
+ https://github.com/SWivid/F5-TTS.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import torch
25
+ from torch import nn
26
+ import torch
27
+
28
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRotaryEmbedding
29
+ from transformers.models.llama import LlamaConfig
30
+
31
+ from soundsation.model.modules import (
32
+ TimestepEmbedding,
33
+ ConvNeXtV2Block,
34
+ ConvPositionEmbedding,
35
+ AdaLayerNormZero_Final,
36
+ precompute_freqs_cis,
37
+ get_pos_embed_indices,
38
+ _prepare_decoder_attention_mask,
39
+ )
40
+
41
+ # Text embedding
42
+ class TextEmbedding(nn.Module):
43
+ def __init__(self, text_num_embeds, text_dim, max_pos, conv_layers=0, conv_mult=2):
44
+ super().__init__()
45
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
46
+
47
+ if conv_layers > 0:
48
+ self.extra_modeling = True
49
+ self.precompute_max_pos = max_pos # ~44s of 24khz audio
50
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
51
+ self.text_blocks = nn.Sequential(
52
+ *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
53
+ )
54
+ else:
55
+ self.extra_modeling = False
56
+
57
+ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
58
+ batch, text_len = text.shape[0], text.shape[1]
59
+
60
+ if drop_text: # cfg for text
61
+ text = torch.zeros_like(text)
62
+
63
+ text = self.text_embed(text) # b n -> b n d
64
+
65
+ # possible extra modeling
66
+ if self.extra_modeling:
67
+ # sinus pos emb
68
+ batch_start = torch.zeros((batch,), dtype=torch.long)
69
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
70
+ text_pos_embed = self.freqs_cis[pos_idx]
71
+ text = text + text_pos_embed
72
+
73
+ # convnextv2 blocks
74
+ text = self.text_blocks(text)
75
+
76
+ return text
77
+
78
+
79
+ # noised input audio and context mixing embedding
80
+ class InputEmbedding(nn.Module):
81
+ def __init__(self, mel_dim, text_dim, out_dim, cond_dim):
82
+ super().__init__()
83
+ self.proj = nn.Linear(mel_dim * 2 + text_dim + cond_dim * 2, out_dim)
84
+ self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
85
+
86
+ def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], style_emb, time_emb, drop_audio_cond=False): # noqa: F722
87
+ if drop_audio_cond: # cfg for cond audio
88
+ cond = torch.zeros_like(cond)
89
+ style_emb = style_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
90
+ time_emb = time_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
91
+ x = self.proj(torch.cat((x, cond, text_embed, style_emb, time_emb), dim=-1))
92
+ x = self.conv_pos_embed(x) + x
93
+ return x
94
+
95
+
96
+ # Transformer backbone using Llama blocks
97
+ class DiT(nn.Module):
98
+ def __init__(
99
+ self,
100
+ *,
101
+ dim,
102
+ depth=8,
103
+ heads=8,
104
+ dim_head=64,
105
+ dropout=0.1,
106
+ ff_mult=4,
107
+ mel_dim=100,
108
+ text_num_embeds=256,
109
+ text_dim=None,
110
+ conv_layers=0,
111
+ long_skip_connection=False,
112
+ max_frames=2048
113
+ ):
114
+ super().__init__()
115
+
116
+ self.max_frames = max_frames
117
+
118
+ cond_dim = 512
119
+ self.time_embed = TimestepEmbedding(cond_dim)
120
+ self.start_time_embed = TimestepEmbedding(cond_dim)
121
+ if text_dim is None:
122
+ text_dim = mel_dim
123
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers, max_pos=self.max_frames)
124
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim, cond_dim=cond_dim)
125
+
126
+ self.dim = dim
127
+ self.depth = depth
128
+
129
+ llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu', max_position_embeddings=self.max_frames)
130
+ llama_config._attn_implementation = 'sdpa'
131
+ self.transformer_blocks = nn.ModuleList(
132
+ [LlamaDecoderLayer(llama_config, layer_idx=i) for i in range(depth)]
133
+ )
134
+ self.rotary_emb = LlamaRotaryEmbedding(config=llama_config)
135
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
136
+
137
+ self.text_fusion_linears = nn.ModuleList(
138
+ [
139
+ nn.Sequential(
140
+ nn.Linear(cond_dim, dim),
141
+ nn.SiLU()
142
+ ) for i in range(depth // 2)
143
+ ]
144
+ )
145
+ for layer in self.text_fusion_linears:
146
+ for p in layer.parameters():
147
+ p.detach().zero_()
148
+
149
+ self.norm_out = AdaLayerNormZero_Final(dim, cond_dim) # final modulation
150
+ self.proj_out = nn.Linear(dim, mel_dim)
151
+
152
+ def forward_timestep_invariant(self, text, seq_len, drop_text, start_time):
153
+ s_t = self.start_time_embed(start_time)
154
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
155
+ text_residuals = []
156
+ for layer in self.text_fusion_linears:
157
+ text_residual = layer(text_embed)
158
+ text_residuals.append(text_residual)
159
+ return s_t, text_embed, text_residuals
160
+
161
+
162
+ def forward(
163
+ self,
164
+ x: float["b n d"], # nosied input audio # noqa: F722
165
+ cond: float["b n d"], # masked cond audio # noqa: F722
166
+ text: int["b nt"], # text # noqa: F722
167
+ time: float["b"] | float[""], # time step # noqa: F821 F722
168
+ drop_audio_cond, # cfg for cond audio
169
+ drop_text, # cfg for text
170
+ drop_prompt=False,
171
+ style_prompt=None, # [b d t]
172
+ start_time=None,
173
+ ):
174
+
175
+ batch, seq_len = x.shape[0], x.shape[1]
176
+ if time.ndim == 0:
177
+ time = time.repeat(batch)
178
+
179
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
180
+ t = self.time_embed(time)
181
+ s_t = self.start_time_embed(start_time)
182
+ c = t + s_t
183
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
184
+
185
+ if drop_prompt:
186
+ style_prompt = torch.zeros_like(style_prompt)
187
+
188
+ style_embed = style_prompt # [b, 512]
189
+
190
+ x = self.input_embed(x, cond, text_embed, style_embed, c, drop_audio_cond=drop_audio_cond)
191
+
192
+ if self.long_skip_connection is not None:
193
+ residual = x
194
+
195
+ pos_ids = torch.arange(x.shape[1], device=x.device)
196
+ pos_ids = pos_ids.unsqueeze(0).repeat(x.shape[0], 1)
197
+ rotary_embed = self.rotary_emb(x, pos_ids)
198
+
199
+ attention_mask = torch.ones(
200
+ (batch, seq_len),
201
+ dtype=torch.bool,
202
+ device=x.device,
203
+ )
204
+ attention_mask = _prepare_decoder_attention_mask(
205
+ attention_mask,
206
+ (batch, seq_len),
207
+ x,
208
+ )
209
+
210
+ for i, block in enumerate(self.transformer_blocks):
211
+ x, *_ = block(x, attention_mask=attention_mask, position_embeddings=rotary_embed)
212
+ if i < self.depth // 2:
213
+ x = x + self.text_fusion_linears[i](text_embed)
214
+
215
+ if self.long_skip_connection is not None:
216
+ x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
217
+
218
+ x = self.norm_out(x, c)
219
+ output = self.proj_out(x)
220
+
221
+ return output
soundsation/model/modules.py ADDED
@@ -0,0 +1,652 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Optional
12
+ import math
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import torchaudio
19
+
20
+ from x_transformers.x_transformers import apply_rotary_pos_emb
21
+
22
+
23
+
24
+ class FiLMLayer(nn.Module):
25
+ """
26
+ Feature-wise Linear Modulation (FiLM) layer
27
+ Reference: https://arxiv.org/abs/1709.07871
28
+ """
29
+ def __init__(self, in_channels, cond_channels):
30
+
31
+ super(FiLMLayer, self).__init__()
32
+ self.in_channels = in_channels
33
+ self.film = nn.Conv1d(cond_channels, in_channels * 2, 1)
34
+
35
+ def forward(self, x, c):
36
+ gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1)
37
+ gamma = gamma.transpose(1, 2)
38
+ beta = beta.transpose(1, 2)
39
+ # print(gamma.shape, beta.shape)
40
+ return gamma * x + beta
41
+
42
+ # raw wav to mel spec
43
+
44
+
45
+ class MelSpec(nn.Module):
46
+ def __init__(
47
+ self,
48
+ filter_length=1024,
49
+ hop_length=256,
50
+ win_length=1024,
51
+ n_mel_channels=100,
52
+ target_sample_rate=24_000,
53
+ normalize=False,
54
+ power=1,
55
+ norm=None,
56
+ center=True,
57
+ ):
58
+ super().__init__()
59
+ self.n_mel_channels = n_mel_channels
60
+
61
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
62
+ sample_rate=target_sample_rate,
63
+ n_fft=filter_length,
64
+ win_length=win_length,
65
+ hop_length=hop_length,
66
+ n_mels=n_mel_channels,
67
+ power=power,
68
+ center=center,
69
+ normalized=normalize,
70
+ norm=norm,
71
+ )
72
+
73
+ self.register_buffer("dummy", torch.tensor(0), persistent=False)
74
+
75
+ def forward(self, inp):
76
+ if len(inp.shape) == 3:
77
+ inp = inp.squeeze(1) # 'b 1 nw -> b nw'
78
+
79
+ assert len(inp.shape) == 2
80
+
81
+ if self.dummy.device != inp.device:
82
+ self.to(inp.device)
83
+
84
+ mel = self.mel_stft(inp)
85
+ mel = mel.clamp(min=1e-5).log()
86
+ return mel
87
+
88
+
89
+ # sinusoidal position embedding
90
+
91
+
92
+ class SinusPositionEmbedding(nn.Module):
93
+ def __init__(self, dim):
94
+ super().__init__()
95
+ self.dim = dim
96
+
97
+ def forward(self, x, scale=1000):
98
+ device = x.device
99
+ half_dim = self.dim // 2
100
+ emb = math.log(10000) / (half_dim - 1)
101
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
102
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
103
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
104
+ return emb
105
+
106
+
107
+ # convolutional position embedding
108
+
109
+
110
+ class ConvPositionEmbedding(nn.Module):
111
+ def __init__(self, dim, kernel_size=31, groups=16):
112
+ super().__init__()
113
+ assert kernel_size % 2 != 0
114
+ self.conv1d = nn.Sequential(
115
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
116
+ nn.Mish(),
117
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
118
+ nn.Mish(),
119
+ )
120
+
121
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
122
+ if mask is not None:
123
+ mask = mask[..., None]
124
+ x = x.masked_fill(~mask, 0.0)
125
+
126
+ x = x.permute(0, 2, 1)
127
+ x = self.conv1d(x)
128
+ out = x.permute(0, 2, 1)
129
+
130
+ if mask is not None:
131
+ out = out.masked_fill(~mask, 0.0)
132
+
133
+ return out
134
+
135
+
136
+ # rotary positional embedding related
137
+
138
+
139
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
140
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
141
+ # has some connection to NTK literature
142
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
143
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
144
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
145
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
146
+ t = torch.arange(end, device=freqs.device) # type: ignore
147
+ freqs = torch.outer(t, freqs).float() # type: ignore
148
+ freqs_cos = torch.cos(freqs) # real part
149
+ freqs_sin = torch.sin(freqs) # imaginary part
150
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
151
+
152
+
153
+ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
154
+ # length = length if isinstance(length, int) else length.max()
155
+ scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
156
+ pos = (
157
+ start.unsqueeze(1)
158
+ + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
159
+ )
160
+ # avoid extra long error.
161
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
162
+ return pos
163
+
164
+
165
+ # Global Response Normalization layer (Instance Normalization ?)
166
+
167
+
168
+ class GRN(nn.Module):
169
+ def __init__(self, dim):
170
+ super().__init__()
171
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
172
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
173
+
174
+ def forward(self, x):
175
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
176
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
177
+ return self.gamma * (x * Nx) + self.beta + x
178
+
179
+
180
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
181
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
182
+
183
+
184
+ class ConvNeXtV2Block(nn.Module):
185
+ def __init__(
186
+ self,
187
+ dim: int,
188
+ intermediate_dim: int,
189
+ dilation: int = 1,
190
+ ):
191
+ super().__init__()
192
+ padding = (dilation * (7 - 1)) // 2
193
+ self.dwconv = nn.Conv1d(
194
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
195
+ ) # depthwise conv
196
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
197
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
198
+ self.act = nn.GELU()
199
+ self.grn = GRN(intermediate_dim)
200
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
201
+
202
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
203
+ residual = x
204
+ x = x.transpose(1, 2) # b n d -> b d n
205
+ x = self.dwconv(x)
206
+ x = x.transpose(1, 2) # b d n -> b n d
207
+ x = self.norm(x)
208
+ x = self.pwconv1(x)
209
+ x = self.act(x)
210
+ x = self.grn(x)
211
+ x = self.pwconv2(x)
212
+ return residual + x
213
+
214
+
215
+ # AdaLayerNormZero
216
+ # return with modulated x for attn input, and params for later mlp modulation
217
+
218
+
219
+ class AdaLayerNormZero(nn.Module):
220
+ def __init__(self, dim):
221
+ super().__init__()
222
+
223
+ self.silu = nn.SiLU()
224
+ self.linear = nn.Linear(dim, dim * 6)
225
+
226
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
227
+
228
+ def forward(self, x, emb=None):
229
+ emb = self.linear(self.silu(emb))
230
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
231
+
232
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
233
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
234
+
235
+
236
+ # AdaLayerNormZero for final layer
237
+ # return only with modulated x for attn input, cuz no more mlp modulation
238
+
239
+
240
+ class AdaLayerNormZero_Final(nn.Module):
241
+ def __init__(self, dim, cond_dim):
242
+ super().__init__()
243
+
244
+ self.silu = nn.SiLU()
245
+ self.linear = nn.Linear(cond_dim, dim * 2)
246
+
247
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
248
+
249
+ def forward(self, x, emb):
250
+ emb = self.linear(self.silu(emb))
251
+ scale, shift = torch.chunk(emb, 2, dim=1)
252
+
253
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
254
+ return x
255
+
256
+
257
+ # FeedForward
258
+
259
+
260
+ class FeedForward(nn.Module):
261
+ def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
262
+ super().__init__()
263
+ inner_dim = int(dim * mult)
264
+ dim_out = dim_out if dim_out is not None else dim
265
+
266
+ activation = nn.GELU(approximate=approximate)
267
+ #activation = nn.SiLU()
268
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
269
+ self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
270
+
271
+ def forward(self, x):
272
+ return self.ff(x)
273
+
274
+
275
+ # Attention with possible joint part
276
+ # modified from diffusers/src/diffusers/models/attention_processor.py
277
+
278
+
279
+ class Attention(nn.Module):
280
+ def __init__(
281
+ self,
282
+ processor: JointAttnProcessor | AttnProcessor,
283
+ dim: int,
284
+ heads: int = 8,
285
+ dim_head: int = 64,
286
+ dropout: float = 0.0,
287
+ context_dim: Optional[int] = None, # if not None -> joint attention
288
+ context_pre_only=None,
289
+ ):
290
+ super().__init__()
291
+
292
+ if not hasattr(F, "scaled_dot_product_attention"):
293
+ raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
294
+
295
+ self.processor = processor
296
+
297
+ self.dim = dim
298
+ self.heads = heads
299
+ self.inner_dim = dim_head * heads
300
+ self.dropout = dropout
301
+
302
+ self.context_dim = context_dim
303
+ self.context_pre_only = context_pre_only
304
+
305
+ self.to_q = nn.Linear(dim, self.inner_dim)
306
+ self.to_k = nn.Linear(dim, self.inner_dim)
307
+ self.to_v = nn.Linear(dim, self.inner_dim)
308
+
309
+ if self.context_dim is not None:
310
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
311
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
312
+ if self.context_pre_only is not None:
313
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
314
+
315
+ self.to_out = nn.ModuleList([])
316
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
317
+ self.to_out.append(nn.Dropout(dropout))
318
+
319
+ if self.context_pre_only is not None and not self.context_pre_only:
320
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
321
+
322
+ def forward(
323
+ self,
324
+ x: float["b n d"], # noised input x # noqa: F722
325
+ c: float["b n d"] = None, # context c # noqa: F722
326
+ mask: bool["b n"] | None = None, # noqa: F722
327
+ rope=None, # rotary position embedding for x
328
+ c_rope=None, # rotary position embedding for c
329
+ ) -> torch.Tensor:
330
+ if c is not None:
331
+ return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
332
+ else:
333
+ return self.processor(self, x, mask=mask, rope=rope)
334
+
335
+
336
+ # Attention processor
337
+
338
+
339
+ class AttnProcessor:
340
+ def __init__(self):
341
+ pass
342
+
343
+ def __call__(
344
+ self,
345
+ attn: Attention,
346
+ x: float["b n d"], # noised input x # noqa: F722
347
+ mask: bool["b n"] | None = None, # noqa: F722
348
+ rope=None, # rotary position embedding
349
+ ) -> torch.FloatTensor:
350
+ batch_size = x.shape[0]
351
+
352
+ # `sample` projections.
353
+ query = attn.to_q(x)
354
+ key = attn.to_k(x)
355
+ value = attn.to_v(x)
356
+
357
+ # apply rotary position embedding
358
+ if rope is not None:
359
+ freqs, xpos_scale = rope
360
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
361
+
362
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
363
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
364
+
365
+ # attention
366
+ inner_dim = key.shape[-1]
367
+ head_dim = inner_dim // attn.heads
368
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
369
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
370
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
371
+
372
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
373
+ if mask is not None:
374
+ attn_mask = mask
375
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
376
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
377
+ else:
378
+ attn_mask = None
379
+
380
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
381
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
382
+ x = x.to(query.dtype)
383
+
384
+ # linear proj
385
+ x = attn.to_out[0](x)
386
+ # dropout
387
+ x = attn.to_out[1](x)
388
+
389
+ if mask is not None:
390
+ mask = mask.unsqueeze(-1)
391
+ x = x.masked_fill(~mask, 0.0)
392
+
393
+ return x
394
+
395
+
396
+ # Joint Attention processor for MM-DiT
397
+ # modified from diffusers/src/diffusers/models/attention_processor.py
398
+
399
+
400
+ class JointAttnProcessor:
401
+ def __init__(self):
402
+ pass
403
+
404
+ def __call__(
405
+ self,
406
+ attn: Attention,
407
+ x: float["b n d"], # noised input x # noqa: F722
408
+ c: float["b nt d"] = None, # context c, here text # noqa: F722
409
+ mask: bool["b n"] | None = None, # noqa: F722
410
+ rope=None, # rotary position embedding for x
411
+ c_rope=None, # rotary position embedding for c
412
+ ) -> torch.FloatTensor:
413
+ residual = x
414
+
415
+ batch_size = c.shape[0]
416
+
417
+ # `sample` projections.
418
+ query = attn.to_q(x)
419
+ key = attn.to_k(x)
420
+ value = attn.to_v(x)
421
+
422
+ # `context` projections.
423
+ c_query = attn.to_q_c(c)
424
+ c_key = attn.to_k_c(c)
425
+ c_value = attn.to_v_c(c)
426
+
427
+ # apply rope for context and noised input independently
428
+ if rope is not None:
429
+ freqs, xpos_scale = rope
430
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
431
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
432
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
433
+ if c_rope is not None:
434
+ freqs, xpos_scale = c_rope
435
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
436
+ c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
437
+ c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
438
+
439
+ # attention
440
+ query = torch.cat([query, c_query], dim=1)
441
+ key = torch.cat([key, c_key], dim=1)
442
+ value = torch.cat([value, c_value], dim=1)
443
+
444
+ inner_dim = key.shape[-1]
445
+ head_dim = inner_dim // attn.heads
446
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
447
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
448
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
449
+
450
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
451
+ if mask is not None:
452
+ attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
453
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
454
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
455
+ else:
456
+ attn_mask = None
457
+
458
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
459
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
460
+ x = x.to(query.dtype)
461
+
462
+ # Split the attention outputs.
463
+ x, c = (
464
+ x[:, : residual.shape[1]],
465
+ x[:, residual.shape[1] :],
466
+ )
467
+
468
+ # linear proj
469
+ x = attn.to_out[0](x)
470
+ # dropout
471
+ x = attn.to_out[1](x)
472
+ if not attn.context_pre_only:
473
+ c = attn.to_out_c(c)
474
+
475
+ if mask is not None:
476
+ mask = mask.unsqueeze(-1)
477
+ x = x.masked_fill(~mask, 0.0)
478
+ # c = c.masked_fill(~mask, 0.) # no mask for c (text)
479
+
480
+ return x, c
481
+
482
+
483
+ # DiT Block
484
+
485
+
486
+ class DiTBlock(nn.Module):
487
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, use_style_prompt=False):
488
+ super().__init__()
489
+
490
+ self.attn_norm = AdaLayerNormZero(dim)
491
+ self.attn = Attention(
492
+ processor=AttnProcessor(),
493
+ dim=dim,
494
+ heads=heads,
495
+ dim_head=dim_head,
496
+ dropout=dropout,
497
+ )
498
+
499
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
500
+ self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
501
+
502
+ self.use_style_prompt = use_style_prompt
503
+ if use_style_prompt:
504
+ #self.film = FiLMLayer(dim, dim)
505
+ self.prompt_norm = AdaLayerNormZero_Final(dim)
506
+
507
+ def forward(self, x, t, c=None, mask=None, rope=None): # x: noised input, t: time embedding
508
+ if c is not None and self.use_style_prompt:
509
+ #x = self.film(x, c)
510
+ x = self.prompt_norm(x, c)
511
+
512
+ # pre-norm & modulation for attention input
513
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
514
+
515
+ # attention
516
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
517
+
518
+ # process attention output for input x
519
+ x = x + gate_msa.unsqueeze(1) * attn_output
520
+
521
+ norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
522
+ ff_output = self.ff(norm)
523
+ x = x + gate_mlp.unsqueeze(1) * ff_output
524
+
525
+ return x
526
+
527
+
528
+ # MMDiT Block https://arxiv.org/abs/2403.03206
529
+
530
+
531
+ class MMDiTBlock(nn.Module):
532
+ r"""
533
+ modified from diffusers/src/diffusers/models/attention.py
534
+
535
+ notes.
536
+ _c: context related. text, cond, etc. (left part in sd3 fig2.b)
537
+ _x: noised input related. (right part)
538
+ context_pre_only: last layer only do prenorm + modulation cuz no more ffn
539
+ """
540
+
541
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
542
+ super().__init__()
543
+
544
+ self.context_pre_only = context_pre_only
545
+
546
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
547
+ self.attn_norm_x = AdaLayerNormZero(dim)
548
+ self.attn = Attention(
549
+ processor=JointAttnProcessor(),
550
+ dim=dim,
551
+ heads=heads,
552
+ dim_head=dim_head,
553
+ dropout=dropout,
554
+ context_dim=dim,
555
+ context_pre_only=context_pre_only,
556
+ )
557
+
558
+ if not context_pre_only:
559
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
560
+ self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
561
+ else:
562
+ self.ff_norm_c = None
563
+ self.ff_c = None
564
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
565
+ self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
566
+
567
+ def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
568
+ # pre-norm & modulation for attention input
569
+ if self.context_pre_only:
570
+ norm_c = self.attn_norm_c(c, t)
571
+ else:
572
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
573
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
574
+
575
+ # attention
576
+ x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
577
+
578
+ # process attention output for context c
579
+ if self.context_pre_only:
580
+ c = None
581
+ else: # if not last layer
582
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
583
+
584
+ norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
585
+ c_ff_output = self.ff_c(norm_c)
586
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
587
+
588
+ # process attention output for input x
589
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
590
+
591
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
592
+ x_ff_output = self.ff_x(norm_x)
593
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
594
+
595
+ return c, x
596
+
597
+
598
+ # time step conditioning embedding
599
+
600
+
601
+ class TimestepEmbedding(nn.Module):
602
+ def __init__(self, dim, freq_embed_dim=256):
603
+ super().__init__()
604
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
605
+ self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
606
+
607
+ def forward(self, timestep: float["b"]): # noqa: F821
608
+ time_hidden = self.time_embed(timestep)
609
+ time_hidden = time_hidden.to(timestep.dtype)
610
+ time = self.time_mlp(time_hidden) # b d
611
+ return time
612
+
613
+
614
+ # attention mask realated
615
+
616
+
617
+ def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds):
618
+ # create noncausal mask
619
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
620
+ combined_attention_mask = None
621
+
622
+ def _expand_mask(
623
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None
624
+ ):
625
+ """
626
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
627
+ """
628
+ bsz, src_len = mask.size()
629
+ tgt_len = tgt_len if tgt_len is not None else src_len
630
+
631
+ expanded_mask = (
632
+ mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
633
+ )
634
+
635
+ inverted_mask = 1.0 - expanded_mask
636
+
637
+ return inverted_mask.masked_fill(
638
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
639
+ )
640
+
641
+ if attention_mask is not None:
642
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
643
+ expanded_attn_mask = _expand_mask(
644
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
645
+ ).to(inputs_embeds.device)
646
+ combined_attention_mask = (
647
+ expanded_attn_mask
648
+ if combined_attention_mask is None
649
+ else expanded_attn_mask + combined_attention_mask
650
+ )
651
+
652
+ return combined_attention_mask
soundsation/model/trainer.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import gc
5
+ from tqdm import tqdm
6
+ import wandb
7
+
8
+ import torch
9
+ from torch.optim import AdamW
10
+ from torch.optim.lr_scheduler import LinearLR, SequentialLR, ConstantLR
11
+
12
+ from accelerate import Accelerator
13
+ from accelerate.utils import DistributedDataParallelKwargs
14
+ from soundsation.dataset.custom_dataset_align2f5 import LanceDiffusionDataset
15
+
16
+ from torch.utils.data import DataLoader, DistributedSampler
17
+
18
+ from ema_pytorch import EMA
19
+
20
+ from soundsation.model import CFM
21
+ from soundsation.model.utils import exists, default
22
+
23
+ import time
24
+
25
+ # from apex.optimizers.fused_adam import FusedAdam
26
+
27
+ # trainer
28
+
29
+
30
+ class Trainer:
31
+ def __init__(
32
+ self,
33
+ model: CFM,
34
+ args,
35
+ epochs,
36
+ learning_rate,
37
+ #dataloader,
38
+ num_warmup_updates=20000,
39
+ save_per_updates=1000,
40
+ checkpoint_path=None,
41
+ batch_size=32,
42
+ batch_size_type: str = "sample",
43
+ max_samples=32,
44
+ grad_accumulation_steps=1,
45
+ max_grad_norm=1.0,
46
+ noise_scheduler: str | None = None,
47
+ duration_predictor: torch.nn.Module | None = None,
48
+ wandb_project="test_e2-tts",
49
+ wandb_run_name="test_run",
50
+ wandb_resume_id: str = None,
51
+ last_per_steps=None,
52
+ accelerate_kwargs: dict = dict(),
53
+ ema_kwargs: dict = dict(),
54
+ bnb_optimizer: bool = False,
55
+ reset_lr: bool = False,
56
+ use_style_prompt: bool = False,
57
+ grad_ckpt: bool = False
58
+ ):
59
+ self.args = args
60
+
61
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False, )
62
+
63
+ logger = "wandb" if wandb.api.api_key else None
64
+ #logger = None
65
+ print(f"Using logger: {logger}")
66
+ # print("-----------1-------------")
67
+ import tbe.common
68
+ # print("-----------2-------------")
69
+ self.accelerator = Accelerator(
70
+ log_with=logger,
71
+ kwargs_handlers=[ddp_kwargs],
72
+ gradient_accumulation_steps=grad_accumulation_steps,
73
+ **accelerate_kwargs,
74
+ )
75
+ # print("-----------3-------------")
76
+
77
+ if logger == "wandb":
78
+ if exists(wandb_resume_id):
79
+ init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
80
+ else:
81
+ init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
82
+ self.accelerator.init_trackers(
83
+ project_name=wandb_project,
84
+ init_kwargs=init_kwargs,
85
+ config={
86
+ "epochs": epochs,
87
+ "learning_rate": learning_rate,
88
+ "num_warmup_updates": num_warmup_updates,
89
+ "batch_size": batch_size,
90
+ "batch_size_type": batch_size_type,
91
+ "max_samples": max_samples,
92
+ "grad_accumulation_steps": grad_accumulation_steps,
93
+ "max_grad_norm": max_grad_norm,
94
+ "gpus": self.accelerator.num_processes,
95
+ "noise_scheduler": noise_scheduler,
96
+ },
97
+ )
98
+
99
+ self.precision = self.accelerator.state.mixed_precision
100
+ self.precision = self.precision.replace("no", "fp32")
101
+ print("!!!!!!!!!!!!!!!!!", self.precision)
102
+
103
+ self.model = model
104
+ #self.model = torch.compile(model)
105
+
106
+ #self.dataloader = dataloader
107
+
108
+ if self.is_main:
109
+ self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
110
+
111
+ self.ema_model.to(self.accelerator.device)
112
+ if self.accelerator.state.distributed_type in ["DEEPSPEED", "FSDP"]:
113
+ self.ema_model.half()
114
+
115
+ self.epochs = epochs
116
+ self.num_warmup_updates = num_warmup_updates
117
+ self.save_per_updates = save_per_updates
118
+ self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
119
+ self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
120
+
121
+ self.max_samples = max_samples
122
+ self.grad_accumulation_steps = grad_accumulation_steps
123
+ self.max_grad_norm = max_grad_norm
124
+
125
+ self.noise_scheduler = noise_scheduler
126
+
127
+ self.duration_predictor = duration_predictor
128
+
129
+ self.reset_lr = reset_lr
130
+
131
+ self.use_style_prompt = use_style_prompt
132
+
133
+ self.grad_ckpt = grad_ckpt
134
+
135
+ if bnb_optimizer:
136
+ import bitsandbytes as bnb
137
+
138
+ self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
139
+ else:
140
+ self.optimizer = AdamW(model.parameters(), lr=learning_rate)
141
+ #self.optimizer = FusedAdam(model.parameters(), lr=learning_rate)
142
+
143
+ #self.model = torch.compile(self.model)
144
+ if self.accelerator.state.distributed_type == "DEEPSPEED":
145
+ self.accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = batch_size
146
+
147
+ self.get_dataloader()
148
+ self.get_scheduler()
149
+ # self.get_constant_scheduler()
150
+
151
+ self.model, self.optimizer, self.scheduler, self.train_dataloader = self.accelerator.prepare(self.model, self.optimizer, self.scheduler, self.train_dataloader)
152
+
153
+ def get_scheduler(self):
154
+ warmup_steps = (
155
+ self.num_warmup_updates * self.accelerator.num_processes
156
+ ) # consider a fixed warmup steps while using accelerate multi-gpu ddp
157
+ total_steps = len(self.train_dataloader) * self.epochs / self.grad_accumulation_steps
158
+ decay_steps = total_steps - warmup_steps
159
+ warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
160
+ decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
161
+ # constant_scheduler = ConstantLR(self.optimizer, factor=1, total_iters=decay_steps)
162
+ self.scheduler = SequentialLR(
163
+ self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
164
+ )
165
+
166
+ def get_constant_scheduler(self):
167
+ total_steps = len(self.train_dataloader) * self.epochs / self.grad_accumulation_steps
168
+ self.scheduler = ConstantLR(self.optimizer, factor=1, total_iters=total_steps)
169
+
170
+ def get_dataloader(self):
171
+ prompt_path = self.args.prompt_path.split('|')
172
+ lrc_path = self.args.lrc_path.split('|')
173
+ latent_path = self.args.latent_path.split('|')
174
+ ldd = LanceDiffusionDataset(*LanceDiffusionDataset.init_data(self.args.dataset_path), \
175
+ max_frames=self.args.max_frames, min_frames=self.args.min_frames, \
176
+ align_lyrics=self.args.align_lyrics, lyrics_slice=self.args.lyrics_slice, \
177
+ use_style_prompt=self.args.use_style_prompt, parse_lyrics=self.args.parse_lyrics,
178
+ lyrics_shift=self.args.lyrics_shift, downsample_rate=self.args.downsample_rate, \
179
+ skip_empty_lyrics=self.args.skip_empty_lyrics, tokenizer_type=self.args.tokenizer_type, precision=self.precision, \
180
+ start_time=time.time(), pure_prob=self.args.pure_prob)
181
+
182
+ # start_time = time.time()
183
+ self.train_dataloader = DataLoader(
184
+ dataset=ldd,
185
+ batch_size=self.args.batch_size, # 每个批次的样本数
186
+ shuffle=True, # 是否随机打乱数据
187
+ num_workers=4, # 用于加载数据的子进程数
188
+ pin_memory=True, # 加速GPU训练
189
+ collate_fn=ldd.custom_collate_fn,
190
+ persistent_workers=True
191
+ )
192
+
193
+
194
+ @property
195
+ def is_main(self):
196
+ return self.accelerator.is_main_process
197
+
198
+ def save_checkpoint(self, step, last=False):
199
+ self.accelerator.wait_for_everyone()
200
+ if self.is_main:
201
+ checkpoint = dict(
202
+ model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
203
+ optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
204
+ ema_model_state_dict=self.ema_model.state_dict(),
205
+ scheduler_state_dict=self.scheduler.state_dict(),
206
+ step=step,
207
+ )
208
+ if not os.path.exists(self.checkpoint_path):
209
+ os.makedirs(self.checkpoint_path)
210
+ if last:
211
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
212
+ print(f"Saved last checkpoint at step {step}")
213
+ else:
214
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
215
+
216
+ def load_checkpoint(self):
217
+ if (
218
+ not exists(self.checkpoint_path)
219
+ or not os.path.exists(self.checkpoint_path)
220
+ or not os.listdir(self.checkpoint_path)
221
+ ):
222
+ return 0
223
+
224
+ self.accelerator.wait_for_everyone()
225
+ if "model_last.pt" in os.listdir(self.checkpoint_path):
226
+ latest_checkpoint = "model_last.pt"
227
+ else:
228
+ latest_checkpoint = sorted(
229
+ [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
230
+ key=lambda x: int("".join(filter(str.isdigit, x))),
231
+ )[-1]
232
+
233
+ checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
234
+
235
+ ### **1. 过滤 `ema_model` 的不匹配参数**
236
+ if self.is_main:
237
+ ema_dict = self.ema_model.state_dict()
238
+ ema_checkpoint_dict = checkpoint["ema_model_state_dict"]
239
+
240
+ filtered_ema_dict = {
241
+ k: v for k, v in ema_checkpoint_dict.items()
242
+ if k in ema_dict and ema_dict[k].shape == v.shape # 仅加载 shape 匹配的参数
243
+ }
244
+
245
+ print(f"Loading {len(filtered_ema_dict)} / {len(ema_checkpoint_dict)} ema_model params")
246
+ self.ema_model.load_state_dict(filtered_ema_dict, strict=False)
247
+
248
+ ### **2. 过滤 `model` 的不匹配参数**
249
+ model_dict = self.accelerator.unwrap_model(self.model).state_dict()
250
+ checkpoint_model_dict = checkpoint["model_state_dict"]
251
+
252
+ filtered_model_dict = {
253
+ k: v for k, v in checkpoint_model_dict.items()
254
+ if k in model_dict and model_dict[k].shape == v.shape # 仅加载 shape 匹配的参数
255
+ }
256
+
257
+ print(f"Loading {len(filtered_model_dict)} / {len(checkpoint_model_dict)} model params")
258
+ self.accelerator.unwrap_model(self.model).load_state_dict(filtered_model_dict, strict=False)
259
+
260
+ ### **3. 加载优化器、调度器和步数**
261
+ if "step" in checkpoint:
262
+ if self.scheduler and not self.reset_lr:
263
+ self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
264
+ step = checkpoint["step"]
265
+ else:
266
+ step = 0
267
+
268
+ del checkpoint
269
+ gc.collect()
270
+ print("Checkpoint loaded at step", step)
271
+ return step
272
+
273
+ def train(self, resumable_with_seed: int = None):
274
+ train_dataloader = self.train_dataloader
275
+
276
+ start_step = self.load_checkpoint()
277
+ global_step = start_step
278
+
279
+ if resumable_with_seed > 0:
280
+ orig_epoch_step = len(train_dataloader)
281
+ skipped_epoch = int(start_step // orig_epoch_step)
282
+ skipped_batch = start_step % orig_epoch_step
283
+ skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
284
+ else:
285
+ skipped_epoch = 0
286
+
287
+ for epoch in range(skipped_epoch, self.epochs):
288
+ self.model.train()
289
+ if resumable_with_seed > 0 and epoch == skipped_epoch:
290
+ progress_bar = tqdm(
291
+ skipped_dataloader,
292
+ desc=f"Epoch {epoch+1}/{self.epochs}",
293
+ unit="step",
294
+ disable=not self.accelerator.is_local_main_process,
295
+ initial=skipped_batch,
296
+ total=orig_epoch_step,
297
+ smoothing=0.15
298
+ )
299
+ else:
300
+ progress_bar = tqdm(
301
+ train_dataloader,
302
+ desc=f"Epoch {epoch+1}/{self.epochs}",
303
+ unit="step",
304
+ disable=not self.accelerator.is_local_main_process,
305
+ smoothing=0.15
306
+ )
307
+
308
+ for batch in progress_bar:
309
+ with self.accelerator.accumulate(self.model):
310
+ text_inputs = batch["lrc"]
311
+ mel_spec = batch["latent"].permute(0, 2, 1)
312
+ mel_lengths = batch["latent_lengths"]
313
+ style_prompt = batch["prompt"]
314
+ style_prompt_lens = batch["prompt_lengths"]
315
+ start_time = batch["start_time"]
316
+
317
+ loss, cond, pred = self.model(
318
+ mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler,
319
+ style_prompt=style_prompt if self.use_style_prompt else None,
320
+ style_prompt_lens=style_prompt_lens if self.use_style_prompt else None,
321
+ grad_ckpt=self.grad_ckpt, start_time=start_time
322
+ )
323
+ self.accelerator.backward(loss)
324
+
325
+ if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
326
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
327
+
328
+ self.optimizer.step()
329
+ self.scheduler.step()
330
+ self.optimizer.zero_grad()
331
+
332
+ if self.is_main:
333
+ self.ema_model.update()
334
+
335
+ global_step += 1
336
+
337
+ if self.accelerator.is_local_main_process:
338
+ self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
339
+
340
+ progress_bar.set_postfix(step=str(global_step), loss=loss.item())
341
+
342
+ if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
343
+ self.save_checkpoint(global_step)
344
+
345
+ if global_step % self.last_per_steps == 0:
346
+ self.save_checkpoint(global_step, last=True)
347
+
348
+ self.save_checkpoint(global_step, last=True)
349
+
350
+ self.accelerator.end_training()
soundsation/model/utils.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import random
5
+ from collections import defaultdict
6
+ from importlib.resources import files
7
+
8
+ import torch
9
+ from torch.nn.utils.rnn import pad_sequence
10
+
11
+
12
+ # seed everything
13
+
14
+
15
+ def seed_everything(seed=0):
16
+ random.seed(seed)
17
+ os.environ["PYTHONHASHSEED"] = str(seed)
18
+ torch.manual_seed(seed)
19
+ torch.cuda.manual_seed(seed)
20
+ torch.cuda.manual_seed_all(seed)
21
+ torch.backends.cudnn.deterministic = True
22
+ torch.backends.cudnn.benchmark = False
23
+
24
+
25
+ # helpers
26
+
27
+
28
+ def exists(v):
29
+ return v is not None
30
+
31
+
32
+ def default(v, d):
33
+ return v if exists(v) else d
34
+
35
+
36
+ # tensor helpers
37
+
38
+
39
+ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
40
+ if not exists(length):
41
+ length = t.amax()
42
+
43
+ seq = torch.arange(length, device=t.device)
44
+ return seq[None, :] < t[:, None]
45
+
46
+
47
+ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"], max_frames): # noqa: F722 F821
48
+ max_seq_len = max_frames
49
+ seq = torch.arange(max_seq_len, device=start.device).long()
50
+ start_mask = seq[None, :] >= start[:, None]
51
+ end_mask = seq[None, :] < end[:, None]
52
+ return start_mask & end_mask
53
+
54
+
55
+ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"], max_frames): # noqa: F722 F821
56
+ lengths = (frac_lengths * seq_len).long()
57
+ max_start = seq_len - lengths
58
+
59
+ rand = torch.rand_like(frac_lengths)
60
+ start = (max_start * rand).long().clamp(min=0)
61
+ end = start + lengths
62
+
63
+ return mask_from_start_end_indices(seq_len, start, end, max_frames)
64
+
65
+
66
+ def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
67
+ if not exists(mask):
68
+ return t.mean(dim=1)
69
+
70
+ t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
71
+ num = t.sum(dim=1)
72
+ den = mask.float().sum(dim=1)
73
+
74
+ return num / den.clamp(min=1.0)
75
+
76
+
77
+ # simple utf-8 tokenizer, since paper went character based
78
+ def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
79
+ list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
80
+ text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
81
+ return text
82
+
83
+
84
+ # char tokenizer, based on custom dataset's extracted .txt file
85
+ def list_str_to_idx(
86
+ text: list[str] | list[list[str]],
87
+ vocab_char_map: dict[str, int], # {char: idx}
88
+ padding_value=-1,
89
+ ) -> int["b nt"]: # noqa: F722
90
+ list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
91
+ text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
92
+ return text
93
+
94
+
95
+ # Get tokenizer
96
+
97
+
98
+ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
99
+ """
100
+ tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
101
+ - "char" for char-wise tokenizer, need .txt vocab_file
102
+ - "byte" for utf-8 tokenizer
103
+ - "custom" if you're directly passing in a path to the vocab.txt you want to use
104
+ vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
105
+ - if use "char", derived from unfiltered character & symbol counts of custom dataset
106
+ - if use "byte", set to 256 (unicode byte range)
107
+ """
108
+ if tokenizer in ["pinyin", "char"]:
109
+ tokenizer_path = os.path.join(files("soundsation").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
110
+ with open(tokenizer_path, "r", encoding="utf-8") as f:
111
+ vocab_char_map = {}
112
+ for i, char in enumerate(f):
113
+ vocab_char_map[char[:-1]] = i
114
+ vocab_size = len(vocab_char_map)
115
+ assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
116
+
117
+ elif tokenizer == "byte":
118
+ vocab_char_map = None
119
+ vocab_size = 256
120
+
121
+ elif tokenizer == "custom":
122
+ with open(dataset_name, "r", encoding="utf-8") as f:
123
+ vocab_char_map = {}
124
+ for i, char in enumerate(f):
125
+ vocab_char_map[char[:-1]] = i
126
+ vocab_size = len(vocab_char_map)
127
+
128
+ return vocab_char_map, vocab_size
129
+
130
+
131
+ # convert char to pinyin
132
+
133
+
134
+ def convert_char_to_pinyin(text_list, polyphone=True):
135
+ final_text_list = []
136
+ god_knows_why_en_testset_contains_zh_quote = str.maketrans(
137
+ {"“": '"', "”": '"', "‘": "'", "’": "'"}
138
+ ) # in case librispeech (orig no-pc) test-clean
139
+ custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov
140
+ for text in text_list:
141
+ char_list = []
142
+ text = text.translate(god_knows_why_en_testset_contains_zh_quote)
143
+ text = text.translate(custom_trans)
144
+ for seg in jieba.cut(text):
145
+ seg_byte_len = len(bytes(seg, "UTF-8"))
146
+ if seg_byte_len == len(seg): # if pure alphabets and symbols
147
+ if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
148
+ char_list.append(" ")
149
+ char_list.extend(seg)
150
+ elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
151
+ seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
152
+ for c in seg:
153
+ if c not in "。,、;:?!《》【】—…":
154
+ char_list.append(" ")
155
+ char_list.append(c)
156
+ else: # if mixed chinese characters, alphabets and symbols
157
+ for c in seg:
158
+ if ord(c) < 256:
159
+ char_list.extend(c)
160
+ else:
161
+ if c not in "。,、;:?!《》【】—…":
162
+ char_list.append(" ")
163
+ char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
164
+ else: # if is zh punc
165
+ char_list.append(c)
166
+ final_text_list.append(char_list)
167
+
168
+ return final_text_list
169
+
170
+
171
+ # filter func for dirty data with many repetitions
172
+
173
+
174
+ def repetition_found(text, length=2, tolerance=10):
175
+ pattern_count = defaultdict(int)
176
+ for i in range(len(text) - length + 1):
177
+ pattern = text[i : i + length]
178
+ pattern_count[pattern] += 1
179
+ for pattern, count in pattern_count.items():
180
+ if count > tolerance:
181
+ return True
182
+ return False
src/negative_prompt.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cb7d74eb7a8eda12acb8247b21d373928301db8a8cb0db480d341799fed3ce5
3
+ size 2176