SWivid commited on
Commit
39ce201
Β·
1 Parent(s): f6e3b78

disable mask for single infer to save mem; add custom trans for vocab to address oov

Browse files
Files changed (2) hide show
  1. model/cfm.py +4 -1
  2. model/utils.py +2 -0
model/cfm.py CHANGED
@@ -142,7 +142,10 @@ class CFM(nn.Module):
142
  cond_mask = rearrange(cond_mask, '... -> ... 1')
143
  step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
144
 
145
- mask = lens_to_mask(duration)
 
 
 
146
 
147
  # test for no ref audio
148
  if no_ref_audio:
 
142
  cond_mask = rearrange(cond_mask, '... -> ... 1')
143
  step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
144
 
145
+ if batch > 1:
146
+ mask = lens_to_mask(duration)
147
+ else: # save memory and speed up, as single inference need no mask currently
148
+ mask = None
149
 
150
  # test for no ref audio
151
  if no_ref_audio:
model/utils.py CHANGED
@@ -153,9 +153,11 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
153
  def convert_char_to_pinyin(text_list, polyphone = True):
154
  final_text_list = []
155
  god_knows_why_en_testset_contains_zh_quote = str.maketrans({'β€œ': '"', '”': '"', 'β€˜': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
 
156
  for text in text_list:
157
  char_list = []
158
  text = text.translate(god_knows_why_en_testset_contains_zh_quote)
 
159
  for seg in jieba.cut(text):
160
  seg_byte_len = len(bytes(seg, 'UTF-8'))
161
  if seg_byte_len == len(seg): # if pure alphabets and symbols
 
153
  def convert_char_to_pinyin(text_list, polyphone = True):
154
  final_text_list = []
155
  god_knows_why_en_testset_contains_zh_quote = str.maketrans({'β€œ': '"', '”': '"', 'β€˜': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
156
+ custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
157
  for text in text_list:
158
  char_list = []
159
  text = text.translate(god_knows_why_en_testset_contains_zh_quote)
160
+ text = text.translate(custom_trans)
161
  for seg in jieba.cut(text):
162
  seg_byte_len = len(bytes(seg, 'UTF-8'))
163
  if seg_byte_len == len(seg): # if pure alphabets and symbols