Spaces:
Configuration error
Configuration error
Disable the mask for single-sample inference to save memory; add a custom translation table for the vocab to address OOV (out-of-vocabulary) characters
Browse files- model/cfm.py +4 -1
- model/utils.py +2 -0
model/cfm.py
CHANGED
@@ -142,7 +142,10 @@ class CFM(nn.Module):
|
|
142 |
cond_mask = rearrange(cond_mask, '... -> ... 1')
|
143 |
step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
|
144 |
|
145 |
-
|
|
|
|
|
|
|
146 |
|
147 |
# test for no ref audio
|
148 |
if no_ref_audio:
|
|
|
142 |
cond_mask = rearrange(cond_mask, '... -> ... 1')
|
143 |
step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
|
144 |
|
145 |
+
if batch > 1:
|
146 |
+
mask = lens_to_mask(duration)
|
147 |
+
else: # save memory and speed up, as single inference need no mask currently
|
148 |
+
mask = None
|
149 |
|
150 |
# test for no ref audio
|
151 |
if no_ref_audio:
|
model/utils.py
CHANGED
@@ -153,9 +153,11 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
|
153 |
def convert_char_to_pinyin(text_list, polyphone = True):
|
154 |
final_text_list = []
|
155 |
god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
|
|
|
156 |
for text in text_list:
|
157 |
char_list = []
|
158 |
text = text.translate(god_knows_why_en_testset_contains_zh_quote)
|
|
|
159 |
for seg in jieba.cut(text):
|
160 |
seg_byte_len = len(bytes(seg, 'UTF-8'))
|
161 |
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
|
|
153 |
def convert_char_to_pinyin(text_list, polyphone = True):
|
154 |
final_text_list = []
|
155 |
god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
|
156 |
+
custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
|
157 |
for text in text_list:
|
158 |
char_list = []
|
159 |
text = text.translate(god_knows_why_en_testset_contains_zh_quote)
|
160 |
+
text = text.translate(custom_trans)
|
161 |
for seg in jieba.cut(text):
|
162 |
seg_byte_len = len(bytes(seg, 'UTF-8'))
|
163 |
if seg_byte_len == len(seg): # if pure alphabets and symbols
|