Yushen CHEN committed on
Commit
c879353
·
1 Parent(s): e6098be

Update finetune_gradio.py: do not force lower case

Browse files

Do not force text to lower case; otherwise the text used for training and for inference in this script would not match the text handling in the main inference code.

src/f5_tts/train/finetune_gradio.py CHANGED
@@ -178,11 +178,6 @@ def get_audio_duration(audio_path):
178
  return audio.shape[1] / sample_rate
179
 
180
 
181
- def clear_text(text):
182
- """Clean and prepare text by lowering the case and stripping whitespace."""
183
- return text.lower().strip()
184
-
185
-
186
  def get_rms(
187
  y,
188
  frame_length=2048,
@@ -707,7 +702,7 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.
707
 
708
  try:
709
  text = transcribe(file_segment, language)
710
- text = text.lower().strip().replace('"', "")
711
 
712
  data += f"{name_segment}|{text}\n"
713
 
@@ -816,7 +811,7 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
816
  error_files.append([file_audio, "very short text length 3"])
817
  continue
818
 
819
- text = clear_text(text)
820
  text = convert_char_to_pinyin([text], polyphone=True)[0]
821
 
822
  audio_path_list.append(file_audio)
@@ -1234,8 +1229,8 @@ def infer(
1234
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
1235
  tts_api.infer(
1236
  ref_file=ref_audio,
1237
- ref_text=ref_text.lower().strip(),
1238
- gen_text=gen_text.lower().strip(),
1239
  nfe_step=nfe_step,
1240
  speed=speed,
1241
  remove_silence=remove_silence,
 
178
  return audio.shape[1] / sample_rate
179
 
180
 
 
 
 
 
 
181
  def get_rms(
182
  y,
183
  frame_length=2048,
 
702
 
703
  try:
704
  text = transcribe(file_segment, language)
705
+ text = text.strip()
706
 
707
  data += f"{name_segment}|{text}\n"
708
 
 
811
  error_files.append([file_audio, "very short text length 3"])
812
  continue
813
 
814
+ text = text.strip()
815
  text = convert_char_to_pinyin([text], polyphone=True)[0]
816
 
817
  audio_path_list.append(file_audio)
 
1229
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
1230
  tts_api.infer(
1231
  ref_file=ref_audio,
1232
+ ref_text=ref_text.strip(),
1233
+ gen_text=gen_text.strip(),
1234
  nfe_step=nfe_step,
1235
  speed=speed,
1236
  remove_silence=remove_silence,