SWivid committed
Commit e54fee3 · 1 Parent(s): 372f6ab

redirect to split hf ckpt repos

Files changed (2):
  1. gradio_app.py      +4 -4
  2. inference-cli.py  +12 -12
gradio_app.py CHANGED
@@ -62,8 +62,8 @@ speed = 1.0
 fix_duration = None
 
 
-def load_model(exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
+def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
+    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
     # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
@@ -93,10 +93,10 @@ F5TTS_model_cfg = dict(
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
 F5TTS_ema_model = load_model(
-    "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
+    "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
 )
 E2TTS_ema_model = load_model(
-    "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
+    "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
 )
 
 def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
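Note: with the checkpoints split into per-model Hub repos, the repo name becomes an explicit load_model argument. A minimal sketch of how the new paths resolve, using the cached_path package the app already imports (ckpt_for is a hypothetical helper; the repo, experiment, and step values are the ones used in the diff above):

from cached_path import cached_path

def ckpt_for(repo_name, exp_name, ckpt_step):
    # Each model now lives in its own repo (SWivid/F5-TTS, SWivid/E2-TTS)
    # instead of both checkpoints sitting under SWivid/F5-TTS.
    return str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))

# Downloads (and caches) hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
f5_ckpt = ckpt_for("F5-TTS", "F5TTS_Base", 1200000)
# Downloads (and caches) hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors
e2_ckpt = ckpt_for("E2-TTS", "E2TTS_Base", 1200000)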
inference-cli.py CHANGED
@@ -74,7 +74,7 @@ ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
 ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
 gen_text = args.gen_text if args.gen_text else config["gen_text"]
 output_dir = args.output_dir if args.output_dir else config["output_dir"]
-exp_name = args.model if args.model else config["model"]
+model = args.model if args.model else config["model"]
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
 wave_path = Path(output_dir)/"out.wav"
 spectrogram_path = Path(output_dir)/"out.png"
@@ -112,8 +112,8 @@ speed = 1.0
 # fix_duration = 27  # None or float (duration in seconds)
 fix_duration = None
 
-def load_model(exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
+def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
+    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
     # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
@@ -238,11 +238,11 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
 
     return batches
 
-def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence):
-    if exp_name == "F5-TTS":
-        ema_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
-    elif exp_name == "E2-TTS":
-        ema_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
+    if model == "F5-TTS":
+        ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
+    elif model == "E2-TTS":
+        ema_model = load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
 
     audio, sr = ref_audio
     if audio.shape[0] > 1:
@@ -320,7 +320,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence)
     print(spectrogram_path)
 
 
-def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words):
+def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_split_words):
     if not custom_split_words.strip():
         custom_words = [word.strip() for word in custom_split_words.split(',')]
         global SPLIT_WORDS
@@ -372,8 +372,8 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
     for i, gen_text in enumerate(gen_text_batches):
         print(f'gen_text {i}', gen_text)
 
-    print(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
+    print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
 
 
-infer(ref_audio, ref_text, gen_text, exp_name, remove_silence, ",".join(SPLIT_WORDS))
+infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
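Note: the if/elif in infer_batch covers only "F5-TTS" and "E2-TTS", so any other model value would leave ema_model unbound. A hedged sketch of the same dispatch written as a table with an explicit guard (MODELS and load_for are illustrative names, not part of this commit; load_model, DiT, UNetT, and the config dicts are assumed to be the ones defined in inference-cli.py):

# Illustrative only: table-driven version of the infer_batch dispatch,
# with a guard for unrecognized names (the commit's if/elif has none).
MODELS = {
    "F5-TTS": ("F5TTS_Base", DiT, F5TTS_model_cfg),   # names from the diff
    "E2-TTS": ("E2TTS_Base", UNetT, E2TTS_model_cfg),
}

def load_for(model, ckpt_step=1200000):
    if model not in MODELS:
        raise ValueError(f"unknown model {model!r}; expected one of {sorted(MODELS)}")
    exp_name, model_cls, model_cfg = MODELS[model]
    # After this commit, the model name doubles as the HF repo name.
    return load_model(model, exp_name, model_cls, model_cfg, ckpt_step)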