unknown committed on
Commit
e1e3b26
·
1 Parent(s): d864425

fix space problem caused by utf-8-sig

Browse files
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -801,11 +801,11 @@ def vocab_extend(project_name, symbols, model_type):
801
  return "Symbols are okay no need to extend."
802
 
803
  size_vocab = len(vocab)
804
- vocab.pop() # fix empty space leave
805
  for item in miss_symbols:
806
  vocab.append(item)
807
 
808
- with open(file_vocab_project, "w", encoding="utf-8-sig") as f:
809
  f.write("\n".join(vocab))
810
 
811
  if model_type == "F5-TTS":
@@ -813,14 +813,17 @@ def vocab_extend(project_name, symbols, model_type):
813
  else:
814
  ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
815
 
816
- new_ckpt_path = os.path.join(path_project_ckpts, name_project)
 
 
 
817
  os.makedirs(new_ckpt_path, exist_ok=True)
818
  new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")
819
 
820
- size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=len(miss_symbols))
821
 
822
  vocab_new = "\n".join(miss_symbols)
823
- return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {len(miss_symbols)}\nnew symbols :\n{vocab_new}"
824
 
825
 
826
  def vocab_check(project_name):
 
801
  return "Symbols are okay no need to extend."
802
 
803
  size_vocab = len(vocab)
804
+
805
  for item in miss_symbols:
806
  vocab.append(item)
807
 
808
+ with open(file_vocab_project, "w", encoding="utf-8") as f:
809
  f.write("\n".join(vocab))
810
 
811
  if model_type == "F5-TTS":
 
813
  else:
814
  ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
815
 
816
+ vocab_size_new = len(miss_symbols)
817
+
818
+ dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
819
+ new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
820
  os.makedirs(new_ckpt_path, exist_ok=True)
821
  new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")
822
 
823
+ size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
824
 
825
  vocab_new = "\n".join(miss_symbols)
826
+ return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}"
827
 
828
 
829
  def vocab_check(project_name):